{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Important: This notebook will only work with fastai-0.7.x. Do not try to run any fastai-1.x code from this path in the repository because it will load fastai-0.7.x**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render plots inline and auto-reload edited modules without restarting the kernel.\n",
    "%matplotlib inline\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.conv_learner import *\n",
    "from fastai.dataset import *\n",
    "from fastai.models.resnet import vgg_resnet50\n",
    "\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.set_device(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let cuDNN benchmark conv algorithms once and reuse the fastest; worthwhile because\n",
    "# every batch has the same spatial size (sz x sz).\n",
    "torch.backends.cudnn.benchmark=True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data root plus the two CSV files read from it.\n",
    "PATH = Path('data/carvana')\n",
    "MASKS_FN, META_FN = 'train_masks.csv', 'metadata.csv'\n",
    "masks_csv = pd.read_csv(PATH / MASKS_FN)  # its 'img' column lists the training images\n",
    "meta_csv = pd.read_csv(PATH / META_FN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_img(im, figsize=None, ax=None, alpha=None):\n",
    "    \"\"\"Draw `im` on `ax` (creating a new figure of `figsize` if no axes given), hide the axes, return `ax`.\"\"\"\n",
    "    # Explicit `is None`: the caller passes an Axes or nothing; don't rely on object truthiness.\n",
    "    if ax is None: _, ax = plt.subplots(figsize=figsize)\n",
    "    ax.imshow(im, alpha=alpha)\n",
    "    ax.set_axis_off()\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_DN = 'train-128'\n",
    "MASKS_DN = 'train_masks-128'\n",
    "sz = 128\n",
    "bs = 64\n",
    "nw = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_DN = 'train'\n",
    "MASKS_DN = 'train_masks_png'\n",
    "sz = 128\n",
    "bs = 64\n",
    "nw = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MatchedFilesDataset(FilesDataset):\n",
    "    \"\"\"FilesDataset whose targets are also files: y[i] is a path (relative to `path`) paired with fnames[i].\"\"\"\n",
    "    def __init__(self, fnames, y, transform, path):\n",
    "        self.y = y\n",
    "        assert len(fnames) == len(y)\n",
    "        super().__init__(fnames, transform, path)\n",
    "\n",
    "    def get_y(self, i):\n",
    "        # The target is opened from disk the same way as an input image.\n",
    "        return open_image(os.path.join(self.path, self.y[i]))\n",
    "\n",
    "    def get_c(self):\n",
    "        # Reports 0 classes: targets here are mask images, not labels.\n",
    "        return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input paths under TRAIN_DN; each mask is '<image name minus its 4-char extension>_mask.png' under MASKS_DN.\n",
    "x_names = np.array([Path(TRAIN_DN) / fn for fn in masks_csv['img']])\n",
    "y_names = np.array([Path(MASKS_DN) / (fn[:-4] + '_mask.png') for fn in masks_csv['img']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out the first N_VALID rows of masks_csv as the validation set.\n",
    "# NOTE(review): presumably chosen so whole cars (all their views) land on one side of the split;\n",
    "# that only holds if masks_csv is ordered by car and N_VALID is a multiple of views-per-car - confirm.\n",
    "N_VALID = 1008\n",
    "val_idxs = list(range(N_VALID))\n",
    "((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, x_names, y_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augmentations; tfm_y=TfmType.CLASS asks fastai to apply them to the mask target as well\n",
    "# (presumably treating it as per-pixel class labels - see fastai 0.7 TfmType).\n",
    "aug_tfms = [\n",
    "    RandomRotate(4, tfm_y=TfmType.CLASS),\n",
    "    RandomFlip(tfm_y=TfmType.CLASS),\n",
    "    RandomLighting(0.05, 0.05, tfm_y=TfmType.CLASS),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CropType.NO: no cropping, so image and mask keep the same full field of view.\n",
    "tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
    "datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)\n",
    "# Use the configured worker count `nw` rather than repeating the literal.\n",
    "md = ImageData(PATH, datasets, bs, num_workers=nw, classes=None)\n",
    "denorm = md.trn_ds.denorm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grab one (augmented) training batch to sanity-check shapes.\n",
    "x,y = next(iter(md.trn_dl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([64, 3, 128, 128]), torch.Size([64, 128, 128]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Images come out as (bs, 3, sz, sz) and masks as (bs, sz, sz).\n",
    "x.shape,y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple upsample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder architecture. model_meta[f] supplies `cut` (the index passed to cut_model in get_base)\n",
    "# and `lr_cut` (where get_layer_groups splits the encoder into layer groups).\n",
    "f = resnet34\n",
    "cut,lr_cut = model_meta[f]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_base():\n",
    "    \"\"\"Return the encoder: architecture `f` built with argument True, truncated by cut_model at `cut`.\"\"\"\n",
    "    # NOTE: f(True) presumably requests pretrained weights (fastai 0.7 model constructors).\n",
    "    return nn.Sequential(*cut_model(f(True), cut))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dice(pred, targs):\n",
    "    pred = (pred>0).float()\n",
    "    return 2. * (pred*targs).sum() / (pred+targs).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StdUpsample(nn.Module):\n",
    "    def __init__(self, nin, nout):\n",
    "        super().__init__()\n",
    "        self.conv = nn.ConvTranspose2d(nin, nout, 2, stride=2)\n",
    "        self.bn = nn.BatchNorm2d(nout)\n",
    "        \n",
    "    def forward(self, x): return self.bn(F.relu(self.conv(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Upsample34(nn.Module):\n",
    "    \"\"\"Plain decoder over encoder `rn`: ReLU, four StdUpsample stages, then a transposed conv to 1 channel.\n",
    "\n",
    "    forward returns channel 0 of the result, i.e. a (batch, H, W) map of logits.\n",
    "    \"\"\"\n",
    "    def __init__(self, rn):\n",
    "        super().__init__()\n",
    "        self.rn = rn\n",
    "        # 512 input channels assumes rn's last feature map has 512 channels (resnet34 encoder).\n",
    "        decoder = [StdUpsample(512,256)] + [StdUpsample(256,256) for _ in range(3)]\n",
    "        self.features = nn.Sequential(rn, nn.ReLU(), *decoder,\n",
    "                                      nn.ConvTranspose2d(256, 1, 2, stride=2))\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.features(x)\n",
    "        return out[:, 0]  # drop the singleton channel dimension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class UpsampleModel():\n",
    "    \"\"\"Adapter handing `model` (plus a `name`) to ConvLearner and defining its layer groups.\"\"\"\n",
    "    def __init__(self, model, name='upsample'):\n",
    "        self.model = model\n",
    "        self.name = name\n",
    "\n",
    "    def get_layer_groups(self, precompute):\n",
    "        # Encoder children split at lr_cut, then everything in `features` after the encoder\n",
    "        # itself (index 0) as the final decoder group. `precompute` is unused.\n",
    "        encoder_groups = split_by_idxs(children(self.model.rn), [lr_cut])\n",
    "        decoder_group = children(self.model.features)[1:]\n",
    "        return [*encoder_groups, decoder_group]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder built by get_base (architecture f truncated at `cut`).\n",
    "m_base = get_base()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the upsampling decoder, move the model to the GPU, and wrap it for ConvLearner.\n",
    "m = to_gpu(Upsample34(m_base))\n",
    "models = UpsampleModel(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = ConvLearner(md, models)\n",
    "learn.opt_fn=optim.Adam\n",
    "learn.crit=nn.BCEWithLogitsLoss()\n",
    "# The model emits logits (hence BCEWithLogitsLoss), so probability 0.5 corresponds to logit 0.\n",
    "# Threshold accuracy at 0 to agree with `dice`, which binarizes with pred>0.\n",
    "learn.metrics=[accuracy_thresh(0.),dice]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze layer group 0 (the first encoder group); later groups stay trainable.\n",
    "learn.freeze_to(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e4667e9fa899453da7cbf0a0524fb426",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 86%|█████████████████████████████████████████████████████████████          | 55/64 [00:22<00:03,  2.46it/s, loss=3.21]"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEOCAYAAAB4nTvgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAAIABJREFUeJzt3Xl4VOXd//H3NxthSQiBsIUlrLIosoRdFGtVXArWDVApiIhoqVufPg/t09/TltZqW9u6K4uiWBEUl6IgiDs7BJA1bIYtrGHfCUnu3x8z1hhDMkAmZybzeV3XXJlzn/vMfOcQ5pOz3cecc4iIiJQkyusCREQk9CksRESkVAoLEREplcJCRERKpbAQEZFSKSxERKRUQQ0LM+tjZuvNbJOZjSpm/j/N7Gv/Y4OZHSo0b7CZbfQ/BgezThERKZkF6zoLM4sGNgBXA9nAEmCgc27tWfr/AujgnBtqZslABpAOOGAp0Mk5dzAoxYqISImCuWXRBdjknMtyzuUCk4F+JfQfCLzpf34tMNs5d8AfELOBPkGsVUREShDMsEgFtheazva3/YCZNQaaAJ+dy7JmNtzMMvyP4WVStYiI/EBMEF/bimk72z6vAcBU51z+uSzrnBsLjAWoVauWS09PH3M+hYqIRKqlS5fuc86llNYvmGGRDTQsNN0A2HmWvgOAnxdZtneRZb8o6c3S0tLIyMg45yJFRCKZmW0NpF8wd0MtAVqYWRMzi8MXCNOKdjKzi4AawIJCzbOAa8yshpnVAK7xt4mIiAeCtmXhnMszs5H4vuSjgVecc2vMbDSQ4Zz7NjgGApNdodOynHMHzOyP+AIHYLRz7kCwahURkZIF7dTZ8paenu60G0pE5NyY2VLnXHpp/XQFt4iIlEphISIipVJYiIhIqSI+LJxzfLhyJyuzD3HoRK7X5YiIhKRgXmcRFvYePc3IScv/M50YH0OjmlVolFyFRslVaZRchcY1q9C8djVqJ1TCrLjrBUVEKraID4vkqnHMfLgX2/afYNsB32Pr/hOs23WU2Wv3cCb/u7PFEuNjaFEngZZ1qtG8dgItalejZZ0E6iQqRESkYov4sIiNjqJV3URa1U38wbz8AsfuI6fYsu84m/YeY8Oeo2zce4yZq3dz8MR3Q1clxsfQtn51LmlQnYtTq3NJanUaJ1chKkoBIiIVQ8SHRUmio4zUpMqkJlWmZ/Na35u379hpNu45xsa9R8ncdZQ1Ow/z6rwt5OYXAJBQKYa2qYlcXL867Rsl0bVJTVISKnnxMURELpguyitDuXkFbNjjC45VOw6zascRMncdITfPFyDNUqrSrWlNujatSbcmydROjPe0XhGRQC/KU1gE2Zn8AtbsPMKirP0szNrPki0HOXY6D4CmKVXp2qQmvVrUoleLWiTEx3pcrYhEGoVFiMrLL2DtriMszNrPoqwDLN58gKOn84iNNro1rclVrWpzVes6NEyu4nWpIhIBFBZhIi+/gGXbDvFp5h4+ydzDNznHAWhZpxpXta7Dj1vXpkPDGjpYLiJBobAIU1v2HeeTzD18mrmXJVsOkFfgaFCjMrd1asgtnVJpUENbHCJSdhQWFcDhk2f4fN1e3lmWzdxN+wC4rHktbu3UgGvb1iU+NtrjCkUk3CksKpjsgyd4Z+kO3l66neyDJ0mMj6Ff+1RuT2/IJQ2qe12eiIQphUUFVVDgWJi1n7cytvPR6t2cziugY6Mk7rmsKde2rUNMdMQP9yUi50BhEQEOnzzDu8uyeXX+FrbuP0FqUmWG9Eijf5eGJOo0XBEJgMIiguQXOD7N3MP4uZtZvPkAVeOiuS29IXf3TKNxzapelyciIUxhEaFW7zjMy3M388GKneQ7x7Vt6vLI1S25qG6C16WJSAhSWES4PUdOMXHBFibO38qx3Dxu6diAR65uSWpSZa9LE5EQorAQAA4ez+WFLzbx2oKtAAzu3pgH
ejenRtU4jysTkVCgsJDv2XHoJP+cvYF3l2VTNS6GEb2bcXfPNKrEaeBhkUimsJBibdhzlL/OXM8nmXuonVCJR69uye3pDTWciEiECjQsdFJ+hGlZJ4Hxg9OZOqI7jZKrMOrdVfQfu4BNe496XZqIhDCFRYRKT0vm7RHd+dut7di49xjXPT2Hf87ewOm8fK9LE5EQpLCIYGbGbekN+eTRK7jhkno8/elGrnt6Douy9ntdmoiEGIWFUKtaJZ4a0IHXhnbhTH4B/ccuZNQ7Kzl84ozXpYlIiFBYyH9c0TKFjx++gvuuaMrbS7O56h9f8tGqXV6XJSIhIKhhYWZ9zGy9mW0ys1Fn6XO7ma01szVmNqlQe76Zfe1/TAtmnfKdynHR/Pq61kwb2ZN61eO5/41lPPrW1xw5pa0MkUgWtFNnzSwa2ABcDWQDS4CBzrm1hfq0AN4CfuScO2hmtZ1ze/3zjjnnqgX6fjp1tuydyS/g2c828fznm6ibGM8/+7enS5Nkr8sSkTIUCqfOdgE2OeeynHO5wGSgX5E+9wLPO+cOAnwbFBIaYqOjePTqlrw9ojsx0Ub/sQt44qN15OYVeF2aiJSzYIZFKrC90HS2v62wlkBLM5tnZgvNrE+hefFmluFvv6m4NzCz4f4+GTk5OWVbvfxHx0Y1mPFgLwZ0bshLX37DTc/PY8MeXZchEkmCGRbFXRJcdJ9XDNAC6A0MBMabWZJ/XiP/ptEdwFNm1uwHL+bcWOdcunMuPSUlpewqlx+oWimGx29ux7ifpbPnyClufHYuE+ZtpqCgYowAICIlC2ZYZAMNC003AHYW0+ffzrkzzrnNwHp84YFzbqf/ZxbwBdAhiLVKgK5uU4eZD1/OZc1r8YcP1jJsYgaHTuR6XZaIBFkww2IJ0MLMmphZHDAAKHpW0/vAlQBmVgvfbqksM6thZpUKtfcE1iIhISWhEi8PTmd0v7bM2ZjDDc/M5evth7wuS0SCKGhh4ZzLA0YCs4BM4C3n3BozG21mff3dZgH7zWwt8DnwK+fcfqA1kGFmK/ztTxQ+i0q8Z2b8rHsaU0f0AOC2l+bz2vwtVJSBKUXk+zTqrFywQydy+eVbK/h03V5uaFePv9zSjmqVNPS5SDgIhVNnJUIkVYlj3M/S+Z8+rZi5ejd9n53Lut1HvC5LRMqQwkLKRFSUcX/vZkwa1pVjp/O46fl5vJ2xvfQFRSQsKCykTHVtWpPpD/aiY6Ma/GrqSv73vVW6iE+kAlBYSJlLSajE6/d0ZcQVzXhj0TbuHL+QnKOnvS5LRC6AwkKCIjrKGHVdK54Z2IFVOw7T97m5rMzW6bUi4UphIUHV99L6TB3RgygzbntpAe8tz/a6JBE5DwoLCbqLU6szbWRPOjRK4pEpK/jTh2vJy9dxDJFworCQclGzmu84xpAeaYyfu5khE5ZomBCRMKKwkHITGx3F7/u25a+3tmPx5gP0fW4e3+Qc87osEQmAwkLK3e3pDZl8XzeOn87jlhfns3TrAa9LEpFSKCzEEx0b1eDdB3pQo0ocd4xbxMzVu70uSURKoLAQzzSuWZWpI7rTpn4i97+xlNfmb/G6JBE5C4WFeKpmtUpMGtaNH7euw++mreHxjzJ1QyWREKSwEM9Vjovmpbs6cVe3Roz5MouHp3zN6bx8r8sSkUI0jrSEhOgo44/9LiY1qQp/mbmOvUdPMWZQOtUrx3pdmoigLQsJIWa+kWv/2f9Slm49yG0vzWfHoZNelyUiKCwkBP20QwNeu7sLuw6d4uYX5rFm52GvSxKJeAoLCUk9mtfi7fu7E2XG7S8t4MsNOV6XJBLRFBYSslrVTeS9B3rSqGZVhr66hClLtnldkkjEUlhISKtbPZ637utGj2Y1+Z93VvGPj9dTUe4bLxJOFBYS8hLiY3llSGduT2/AM59t4pdvrdDd90TKmU6dlbAQGx3FX25pR4MaVfjH7A3sPnKKlwZ1IjFep9aKlAdtWUjYMDMevKoFT952
KYs3H+DmF+azed9xr8sSiQgKCwk7t3ZqwMR7urD/2Gn6PTeXr3SmlEjQKSwkLPVoVotpIy+jflJlhkxYzPg5WTrwLRJECgsJWw2Tq/DO/T24uk0d/jQ9k1++vYJTZzSmlEgwKCwkrFWtFMOLd3bi4R+34N1lO+g/diF7jpzyuiyRCkdhIWEvKsp4+McteemuTmzcc5SfPDuX5dsOel2WSIUS1LAwsz5mtt7MNpnZqLP0ud3M1prZGjObVKh9sJlt9D8GB7NOqRj6XFyXdx/oQaXYKPqPWciMVbu8LkmkwghaWJhZNPA8cB3QBhhoZm2K9GkB/Bro6ZxrCzzsb08Gfgd0BboAvzOzGsGqVSqOVnUTmfbzy7ikQXUefHM5H6/R7VpFykIwtyy6AJucc1nOuVxgMtCvSJ97geedcwcBnHN7/e3XArOdcwf882YDfYJYq1QgNarG8erdnWmbWp2fT1rG5+v3lr6QiJQomGGRCmwvNJ3tbyusJdDSzOaZ2UIz63MOy2Jmw80sw8wycnJ0rr18JyE+lol3d6FlnQRGvL6UeZv2eV2SSFgLZlhYMW1FT4SPAVoAvYGBwHgzSwpwWZxzY51z6c659JSUlAssVyqa6lVief2erjSpVZVhr2WwePMBr0sSCVvBDItsoGGh6QbAzmL6/Ns5d8Y5txlYjy88AllWpFTJVeN4/Z6u1E+K5+4Ji1mms6REzksww2IJ0MLMmphZHDAAmFakz/vAlQBmVgvfbqksYBZwjZnV8B/YvsbfJnLOUhIqMenebtRKqMTgVxazeofuvCdyroIWFs65PGAkvi/5TOAt59waMxttZn393WYB+81sLfA58Cvn3H7n3AHgj/gCZwkw2t8mcl7qJMYz6d5uJMbHctfLi1i3+4jXJYmEFaso4+mkp6e7jIwMr8uQELdt/wn6j11Abl4Bk4d3o0WdBK9LEvGUmS11zqWX1k9XcEtEaVSzCm8M60pUlDFw3CI27T3mdUkiYUFhIRGnaUo13ry3GwB3jFtIVo4CQ6Q0CguJSM1rV+PNe7tS4BwDxy1ki26iJFIihYVErBZ1EnhjWDfO5PsCY+t+BYbI2SgsJKJdVDeBN4Z15dSZfAaOXcj2Aye8LkkkJCksJOK1rpfIv4Z15cSZfAYoMESKpbAQAdrWr86/7unK0VNnGDhuITsOnfS6JJGQorAQ8bs4tTpvDOvG4ZNnGDh2IbsP6457It9SWIgUckmD6rx+T1cOHs/ljnEL2XtUgSECCguRH2jfMIkJd3dm95FT3DV+EQeO53pdkojnFBYixUhPS2b84HS27j/BoJcXcfjEGa9LEvGUwkLkLHo0q8WYQZ3YuOcYgycs5ugpBYZELoWFSAl6X1Sb5+7owOodhxn66hJO5OZ5XZKIJxQWIqW4pm1dnhrQnqVbD3LvxAxOncn3uiSRcqewEAnAje3q8+RtlzL/m/3c/6+lnM5TYEhkUViIBOjmjg34808v4fP1OTz45nLyCyrGvWBEAqGwEDkHA7s04v/d2IZZa/bw2PRMr8sRKTcxXhcgEm7uuawJ2QdP8Mq8zTSuWYXBPdK8Lkkk6BQWIufhtze0YfuBE/zhgzU0TK7Mj1rV8bokkaDSbiiR8xAdZTw9oANt6icyctJy1uw87HVJIkGlsBA5T1UrxfDy4M4kVY5l6KtL2HVYI9VKxaWwELkAdRLjeXlIZ46fzueeVzM4dloX7UnFpLAQuUCt6yXy/J0dWb/nKL+YtIy8/AKvSxIpcwGFhZk9ZGaJ5vOymS0zs2uCXZxIuLiiZQqj+7Xl8/U5jP5wLc7pGgypWALdshjqnDsCXAOkAHcDTwStKpEwdGfXxgy/vCkTF2xl/JzNXpcjUqYCPXXW/D+vByY451aYmZW0gEgkGtWnFTsOneSxGZlERxlDL2vidUkiZSLQsFhqZh8DTYBfm1kCoB2zIkVERRlP9W9PQYFj9IdrKXCOYb2ael2WyAULdDfUPcAooLNz7gQQ
i29XlIgUERsdxTMDO3D9JXX50/RMxs/J8rokkQsWaFh0B9Y75w6Z2V3Ab4FSr0Iysz5mtt7MNpnZqGLmDzGzHDP72v8YVmhefqH2aYF+IJFQEBsdxdMDOnDDJfX40/RMxn2lwJDwFuhuqBeBS83sUuC/gZeBicAVZ1vAzKKB54GrgWxgiZlNc86tLdJ1inNuZDEvcdI51z7A+kRCTmx0FE8NaA8Gj83IxOEYfnkzr8sSOS+BhkWec86ZWT/gaefcy2Y2uJRlugCbnHNZAGY2GegHFA0LkQorNjqKp/u3x4A/z1hHgYMRVygwJPwEuhvqqJn9GhgETPdvNcSWskwqsL3QdLa/rahbzGylmU01s4aF2uPNLMPMFprZTQHWKRJyYqKjeKp/e35yaX2e+GgdL37xjdcliZyzQMOiP3Aa3/UWu/F96f+tlGWKO7W26JVKHwBpzrl2wCfAa4XmNXLOpQN3AE+Z2Q/+HDOz4f5AycjJyQnwo4iUv5joKP55+6X0vbQ+f5m5jlHvrCTn6GmvyxIJWEBh4Q+IN4DqZnYjcMo5N7GUxbKBwlsKDYCdRV53v3Pu2/8x44BOhebt9P/MAr4AOhRT11jnXLpzLj0lJSWQjyLimZjoKP5x+6UMv7wpU5dm0/tvn/Pspxs5matbtEroC3S4j9uBxcBtwO3AIjO7tZTFlgAtzKyJmcUBA4DvndVkZvUKTfYFMv3tNcyskv95LaAnOtYhFUBMdBS/ub41sx+9gl4tUvj77A1c+eQXTF2aTYFu0yohzAIZw8bMVgBXO+f2+qdTgE+cc5eWstz1wFNANPCKc+4xMxsNZDjnppnZ4/hCIg84ANzvnFtnZj2AMfgu/IsCnnLOvVzSe6Wnp7uMjIxSP4tIKFm8+QCPTV/LiuzDtKmXyG9vaE2P5rW8LksiiJkt9e/yL7lfgGGxyjl3SaHpKGBF4TavKSwkXBUUOD5YuZO/zlzPjkMnuapVbUbfdDGpSZW9Lk0iQKBhEegB7plmNst/Ed0QYDow40IKFBGfqCijX/tUPv3lFYy6rhULs/Zz/dNzmL12j9elifxHQFsWAGZ2C75jBwZ85Zx7L5iFnSttWUhFsWXfcUa+uYzVO45wd880Rl3Xikox0V6XJRVUme6GCgcKC6lITufl8/iMdbw6fwsXpyby3MCOpNWq6nVZUgGVyW4oMztqZkeKeRw1syNlV66IFFYpJprf923LmEGd2H7gJDc+O5dpK3aWvqBIkJQYFs65BOdcYjGPBOdcYnkVKRKprm1blxkP9eKiugk8+OZyRr2zUtdliCd0D26REJeaVJnJw7vxQO9mTF6ynX7Pz2XDnqNelyURRmEhEgZio6P47z6tmDi0CweO59L3ublMXrxN9/qWcqOwEAkjl7dMYcZDvejUuAaj3l3Fg5O/5uipM16XJRFAYSESZmonxDNxaFd+de1FzFi1ixuemcvK7ENelyUVnMJCJAxFRxk/v7I5U4Z3Iy+/gFtenM/4OVnaLSVBo7AQCWPpacnMeKgXV15Umz9Nz+Se1zI4cDzX67KkAlJYiIS5pCpxjBnUiT/0bcvcjfv4ybNz2br/uNdlSQWjsBCpAMyMwT3SmHp/d07k5tF/zEI271NgSNlRWIhUIO0aJDHp3m7k5hfQf8wCNu095nVJUkEoLEQqmNb1Epk8vBsFDgaMXagL+KRMKCxEKqCWdRKYPLwbUQYDxy4kc5eGcpMLo7AQqaCa167G5OHdiI2O4o5xC1mz87DXJUkYU1iIVGBNU6ox5b5uVI6N5o5xi1iVrcCQ86OwEKngGtesypT7upMQH8Md4xeyfNtBr0uSMKSwEIkADZOrMOW+7tSoEsdd4xex4Jv9XpckYUZhIRIhUpMq89Z93amfVJkhExbz2Trd41sCp7AQiSB1q8cz5b7utKyTwPCJS3X3PQmYwkIkwiRXjWPSvV3p2LgGD01ezqRF27wuScKAwkIkAiXExzJxaBd6t0zhN++tYuxX
33hdkoQ4hYVIhIqPjWbMoHRuaFePP89Yx98/Xq8hzuWsYrwuQES8ExcTxTMDOpBQKYZnP9vE0VN5/N+NbYiKMq9LkxCjsBCJcNFRxuM3X0JCfAzj5mxm//Fc/nZrO+Jjo70uTUKIwkJEMDN+c31rkqtW4i8z17Hz0EnGDupEzWqVvC5NQoSOWYgI4AuM+3s344U7O7J6x2FuemEeGzVirfgFNSzMrI+ZrTezTWY2qpj5Q8wsx8y+9j+GFZo32Mw2+h+Dg1mniHzn+kvqMeW+7pzMLeDmF+czd+M+r0uSEBC0sDCzaOB54DqgDTDQzNoU03WKc669/zHev2wy8DugK9AF+J2Z1QhWrSLyfe0bJvHvkT1JTarM4AmLdS2GBHXLoguwyTmX5ZzLBSYD/QJc9lpgtnPugHPuIDAb6BOkOkWkGKlJlXl7RHd6tajFb95bxWPT15JfoFNrI1UwwyIV2F5oOtvfVtQtZrbSzKaaWcNzWdbMhptZhpll5OTklFXdIuKXEB/L+J+lM6RHGuPmbOa+15dyIjfP67LEA8EMi+JO1C76Z8kHQJpzrh3wCfDaOSyLc26scy7dOZeekpJyQcWKSPFioqP4fd+2/KFvWz5bt4e7xi/i0Ilcr8uSchbMsMgGGhaabgB8b9Qy59x+59xp/+Q4oFOgy4pI+RrcI40X7uzE6h1HuH3MAnYfPuV1SVKOghkWS4AWZtbEzOKAAcC0wh3MrF6hyb5Apv/5LOAaM6vhP7B9jb9NRDzU5+K6vDq0MzsPneLWl+azed9xr0uSchK0sHDO5QEj8X3JZwJvOefWmNloM+vr7/agma0xsxXAg8AQ/7IHgD/iC5wlwGh/m4h4rEezWrx5bzdO5OZz20vzWb1Dt2qNBFZRBg5LT093GRkZXpchEjG+yTnGz15ezOGTZxj3s3S6N6vpdUlyHsxsqXMuvbR+uoJbRM5Ls5RqTL2/O3WrxzN4wmI+XrPb65IkiBQWInLe6lWvzNv3dadNvURG/Gspb2VsL30hCUsKCxG5IDWqxvHGsK70bF6L/566kqc/2aj7YlRACgsRuWBVK8Xw8uDO3NKxAf/8ZAO/fGsFp/PyvS5LypCGKBeRMhEXE8WTt7UjrWYV/j57A9mHTjLmrk7UqBrndWlSBrRlISJlxsz4xVUteHpAe77efoibX9S1GBWFwkJEyly/9qlMGtaVQydy+ekL81i8WZdJhTuFhYgERXpaMu890JPkKnHcNX4R7y3P9rokuQAKCxEJmrRaVXn3gR50bJzEI1NW8NQnG3SmVJhSWIhIUCVViWPi0K7c0rEBT32ykd9PW0OB7osRdnQ2lIgE3bdnSiVXjWXcnM0cOZXHX29tR2y0/l4NFwoLESkXZsZvrm9N9cqxPPnxBo6eyuO5OzoQHxvtdWkSAMW6iJQbM2Pkj1owul9bPsncw9BXl3DstO68Fw4UFiJS7n7WPY1/3H4pizYf4M7xizh4XHfeC3UKCxHxxM0dG/DinR3J3HWE/mMXsOeI7rwXyhQWIuKZa9rW5dUhnck+eJLbXlrAtv0nvC5JzkJhISKe6tG8FpPu7caRU2e45aX5ZO464nVJUgyFhYh4rn3DJN66rzvRZtw+ZgGLsvZ7XZIUobAQkZDQsk4C7zzQg9oJlRj0iu68F2oUFiISMlKTKvP2iB60/vbOe0t0571QobAQkZCSXDWOScO6clmLFP77nZW88MUmjScVAhQWIhJyqlaKYfzP0ul7aX3+OnM9f/wwU+NJeUzDfYhISIqLieKp/u2pWS2OV+Zt5sDx0/z11kuJi9HfuF5QWIhIyIqKMv7vxjbUqlaJv81az87Dp3jujg7UToj3urSIo4gWkZBmZvz8yuY81b89K7MPceMzc1myRXfeK28KCxEJCzd1SOX9n/ekSlw0A8YuZPycLB34LkcKCxEJG63qJjLtF5dxVava/Gl6JiMnLdeoteVEYSEiYSUxPpYxgzox6rpWfLR6F/2em8vG
PUe9LqvCC2pYmFkfM1tvZpvMbFQJ/W41M2dm6f7pNDM7aWZf+x8vBbNOEQkvZsaIK5rxxrBuHD55hn7Pz2Paip1el1WhBS0szCwaeB64DmgDDDSzNsX0SwAeBBYVmfWNc669/zEiWHWKSPjq3qwm0x/sRZt6iTz45nIefHO5hjoPkmBuWXQBNjnnspxzucBkoF8x/f4I/BXQv7CInLM6ifG8ObwbD13VgplrdnPV379k/JwszuQXeF1ahRLMsEgFCg/sku1v+w8z6wA0dM59WMzyTcxsuZl9aWa9insDMxtuZhlmlpGTk1NmhYtIeImNjuKRq1vy8cOXk55Wgz9Nz+TGZ+Zq9NoyFMywsGLa/nOem5lFAf8EfllMv11AI+dcB+BRYJKZJf7gxZwb65xLd86lp6SklFHZIhKu0mpVZcKQzowd1Iljp/PoP3Yhj0z5mr1HtePiQgUzLLKBhoWmGwCFj0AlABcDX5jZFqAbMM3M0p1zp51z+wGcc0uBb4CWQaxVRCoIM+OatnX55NErGHllc6av3MVVT37JhHmbydf4UuctmGGxBGhhZk3MLA4YAEz7dqZz7rBzrpZzLs05lwYsBPo65zLMLMV/gBwzawq0ALKCWKuIVDCV46L5r2svYubDvWjfKIk/fLCWgeMWkn1Qt249H0ELC+dcHjASmAVkAm8559aY2Wgz61vK4pcDK81sBTAVGOGc0/X9InLOmqZUY+LQLjx526Ws3XmE656aw7vLsivM1d87Dp3kZG5+0N/HKsoKS09PdxkZGV6XISIhbPuBEzwy5Wsyth7khkvq8dhPLyapSpzXZV2QQS8vIufoaT56qBdmxR0qLpmZLXXOpZfWT1dwi0jEaJhchSn3dedX117ErDW7ufapr5i7cZ/XZZ23ldmHmLNxH/3ap55XUJwLhYWIRJToKN8otu//vCcJ8bHc9fIi/vDBGk6dCf6unLL2wuffkBAfw13dGgX9vRQWIhKRLk6tzoe/uIwhPdKYMG8LNzwzhy/W7/W6rIBt2nuMWWt3M7h7GgnxsUF/P4WFiESs+Nhoft+3La8N7UJegWPIhCUMfmVxWAxM+NKX31ApJoq7e6aVy/spLEQk4l3RMoXZj1zBb29ozbJtB+nz9BxHekbHAAALjklEQVT+3/ur2X/stNelFWvHoZO8v3wHAzo3oma1SuXyngoLERF89/we1qspX/7qSu7s2ohJi7fR+8kvGPvVN5zOC63jGeO+8l12du/lTcvtPRUWIiKFJFeNY3S/i5n5UC/SG9fgzzPWcfU/vuL95TvIzfN+cMJ9x07z5uJt/LRDKqlJlcvtfRUWIiLFaFEngQl3d2Hi0C5Ujo3m4Slf0+OJT/nbrHUBXwWeX+BYuvUg4+dksWnvsTKpa8K8zeTmFzCid7Myeb1A6aI8EZFSFBQ45mzax+sLtvLZuj0A/KhVbe7q1pjLW6QQFfXdNQ77j53mq405fL4uh6825nDoxBnAd8ruwC4NefjHLal1nscZjpw6Q88nPqNXi1q8cGenC/9gBH5RXkyZvJuISAUWFWVc0TKFK1qmsOPQSd5ctI3JS7bxSeZeGteswsAujTh9poDP1+9lRfYhnINa1eK4qlUdrmyVQtv61ZkwbzNvLNrGe8t2cH/vZtxzWVMqx0WfUx3/WriVo6fyeKB38yB90rPTloWIyHnIzStg5prd/GvhVhZvPoAZtG+YRO+WtbmyVQoX16/+vS0OgG9yjvGXj9bx8do91E2M59FrWnJLxwZER5V+9fWpM/lc9pfPaF0vkdfv6VpmnyPQLQuFhYjIBdq2/wRVK0UHfBrr4s0HeGxGJiu2H6JV3QR+fX1rLm9Rq8QhOyYu2ML//XsNk4d3o1vTmmVUucaGEhEpN41qVjmn6x26NEnm/Qd68NwdHTiRm8/gVxbT7/l5fLhyJ3nF3A72TH4BY77MomOjJLo2SS7L0gOmYxYiIh4wM25sV5+r29Rh6tJsxs/ZzMhJy2mY
XJlhlzXltvQGVInzfUVP+3onOw6dZHS/tkEfMPCs9Wo3lIiI9/ILHLPX7mHMV9+wfNshalSJZVD3NAZ1a8zAcQuJibLzHoa8JDobSkQkjERHGX0ursu1beuQsfUgY77M4plPN/LC55vIK3A8PaC9Z1sVoLAQEQkpZkbntGQ6pyWzae8xxs/JYt+xXG64pJ6ndSksRERCVPPa1XjilnZelwHobCgREQmAwkJEREqlsBARkVIpLEREpFQKCxERKZXCQkRESqWwEBGRUiksRESkVBVmbCgzywG2AtWBw8V0Ka69aFvR6VrAvjIsszhnq7csly2tX0nzA1lvxbVpXQY273zaImVdltQnUv6fn8ty5/u72cI5V73UV3fOVagHMDbQ9qJtxUxneFVvWS5bWr+S5gey3rQug7sui7ZFyrosqU+k/D8/l+XO93cz0PeoiLuhPjiH9qJtZ1s2mC7kPQNdtrR+Jc0PZL0V16Z1Gdi8C2kLplBYlyX1iZT/5+ey3Pn+bgb0HhVmN1QwmFmGC2DoXimd1mXZ0bosW1qfgamIWxZlaazXBVQgWpdlR+uybGl9BkBbFiIiUiptWYiISKkUFiIiUiqFhYiIlEphISIipVJYXAAzq2pmS83sRq9rCWdm1trMXjKzqWZ2v9f1hDMzu8nMxpnZv83sGq/rCWdm1tTMXjazqV7XEgoiMizM7BUz22tmq4u09zGz9Wa2ycxGBfBS/wO8FZwqw0NZrEvnXKZzbgRwOxCx57uX0bp83zl3LzAE6B/EckNaGa3LLOfcPcGtNHxE5KmzZnY5cAyY6Jy72N8WDWwArgaygSXAQCAaeLzISwwF2uEbUyYe2Oec+7B8qg8tZbEunXN7zawvMAp4zjk3qbzqDyVltS79y/0deMM5t6ycyg8pZbwupzrnbi2v2kNVjNcFeME595WZpRVp7gJscs5lAZjZZKCfc+5x4Ae7mczsSqAq0AY4aWYznHMFQS08BJXFuvS/zjRgmplNByIyLMro99KAJ4CPIjUooOx+L+U7ERkWZ5EKbC80nQ10PVtn59z/ApjZEHxbFhEXFCU4p3VpZr2Bm4FKwIygVhZ+zmldAr8AfgxUN7PmzrmXgllcmDnX38uawGNABzP7tT9UIpbC4jtWTFup++icc6+WfSlh75zWpXPuC+CLYBUT5s51XT4DPBO8csLaua7L/cCI4JUTXiLyAPdZZAMNC003AHZ6VEu407osO1qXZUfr8gIoLL6zBGhhZk3MLA4YAEzzuKZwpXVZdrQuy47W5QWIyLAwszeBBcBFZpZtZvc45/KAkcAsIBN4yzm3xss6w4HWZdnRuiw7WpdlLyJPnRURkXMTkVsWIiJybhQWIiJSKoWFiIiUSmEhIiKlUliIiEipFBYiIlIqhYV4xsyOlcN79A1wuPmyfM/eZtbjPJbrYGbj/c+HmNlzZV/duTOztKJDfRfTJ8XMZpZXTVL+FBYS9vxDTxfLOTfNOfdEEN6zpHHVegPnHBbAb4Bnz6sgjznncoBdZtbT61okOBQWEhLM7FdmtsTMVprZHwq1v++/G+EaMxteqP2YmY02s0VAdzPbYmZ/MLNlZrbKzFr5+/3nL3Qze9XMnjGz+WaWZWa3+tujzOwF/3t8aGYzvp1XpMYvzOzPZvYl8JCZ/cTMFpnZcjP7xMzq+IfFHgE8YmZfm1kv/1/d7/g/35LivlDNLAFo55xbUcy8xmb2qX/dfGpmjfztzcxsof81Rxe3pWa+uzlON7MVZrbazPr72zv718MKM1tsZgn+LYg5/nW4rLitIzOLNrO/Ffq3uq/Q7PeBO4v9B5bw55zTQw9PHsAx/89rgLH4RgWNAj4ELvfPS/b/rAysBmr6px1we6HX2gL8wv/8AWC8//kQfDdUAngVeNv/Hm3w3dsA4FZ8Q6NHAXWBg8CtxdT7BfBCoekafDcKwjDg7/7nvwf+q1C/ScBl/ueNgMxiXvtK4J1C04Xr/gAY7H8+FHjf//xD
YKD/+Yhv12eR170FGFdoujoQB2QBnf1tifhGoK4CxPvbWgAZ/udpwGr/8+HAb/3PKwEZQBP/dCqwyuvfKz2C89AQ5RIKrvE/lvunq+H7svoKeNDMfupvb+hv3w/kA+8UeZ13/T+X4rs/RnHed757j6w1szr+tsuAt/3tu83s8xJqnVLoeQNgipnVw/cFvPksy/wYaGP2nxGyE80swTl3tFCfekDOWZbvXujzvA78tVD7Tf7nk4Ani1l2FfCkmf0F+NA5N8fMLgF2OeeWADjnjoBvKwR4zsza41u/LYt5vWuAdoW2vKrj+zfZDOwF6p/lM0iYU1hIKDDgcefcmO81+m6K9GOgu3PuhJl9ge82tgCnnHP5RV7ntP9nPmf/3T5d6LkV+RmI44WePwv8wzk3zV/r78+yTBS+z3CyhNc9yXefrTQBD+jmnNtgZp2A64HHzexjfLuLinuNR4A9wKX+mk8V08fwbcHNKmZePL7PIRWQjllIKJgFDDWzagBmlmpmtfH91XrQHxStgG5Bev+5wC3+Yxd18B2gDkR1YIf/+eBC7UeBhELTH+Mb7RQA/1/uRWUCzc/yPvPxDacNvmMCc/3PF+LbzUSh+d9jZvWBE865f+Hb8ugIrAPqm1lnf58E/wH76vi2OAqAQfjuTV3ULOB+M4v1L9vSv0UCvi2REs+akvClsBDPOec+xrcbZYGZrQKm4vuynQnEmNlK4I/4vhyD4R18N8ZZDYwBFgGHA1ju98DbZjYH2Feo/QPgp98e4AYeBNL9B4TXUszd15xz6/DdCjWh6Dz/8nf718Mg4CF/+8PAo2a2GN9urOJqvgRYbGZfA/8L/Mk5lwv0B541sxXAbHxbBS8Ag81sIb4v/uPFvN54YC2wzH867Ri+24q7EphezDJSAWiIchHAzKo5546Z777Li4Gezrnd5VzDI8BR59z4APtXAU4655yZDcB3sLtfUIssuZ6vgH7OuYNe1SDBo2MWIj4fmlkSvgPVfyzvoPB7EbjtHPp3wndA2oBD+M6U8oSZpeA7fqOgqKC0ZSEiIqXSMQsRESmVwkJEREqlsBARkVIpLEREpFQKCxERKdX/B8JmHe9veFf/AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Learning-rate range test: plot loss against learning rate to pick `lr` below.\n",
    "learn.lr_find()\n",
    "learn.sched.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=4e-2\n",
    "wd=1e-7\n",
    "lrs = np.array([lr/100,lr/10,lr])/2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0dfb71a3eb494433bd0f58647124a177",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/64 [00:00<?, ?it/s]\n",
      "epoch      trn_loss   val_loss   <lambda>   dice           \n",
      "    0      0.216882   0.133512   0.938017   0.855221  \n",
      "    1      0.169544   0.115158   0.946518   0.878381       \n",
      "    2      0.153114   0.099104   0.957748   0.903353       \n",
      "    3      0.144105   0.093337   0.964404   0.915084       \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.09333742126112893, 0.9644036065964472, 0.9150839788573129]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stage 1: one 4-epoch cycle at `lr` with weight decay `wd`, layer group 0 still frozen.\n",
    "learn.fit(lr,1, wds=wd, cycle_len=4,use_clr=(20,8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint after stage 1 (restorable with learn.load('tmp')).\n",
    "learn.save('tmp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the stage-1 checkpoint.\n",
    "learn.load('tmp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage 2: make all layer groups trainable; bn_freeze(True) presumably keeps the\n",
    "# BatchNorm layers' running statistics fixed during fine-tuning.\n",
    "learn.unfreeze()\n",
    "learn.bn_freeze(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f8176783850a4cf48c3ba20494fac24e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss   <lambda>   dice           \n",
      "    0      0.174897   0.061603   0.976321   0.94382   \n",
      "    1      0.122911   0.053625   0.982206   0.957624       \n",
      "    2      0.106837   0.046653   0.985577   0.965792       \n",
      "    3      0.099075   0.042291   0.986519   0.968925        \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.042291240323157536, 0.986519161670927, 0.9689251193924556]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fine-tune every layer group with the per-group rates `lrs`.\n",
    "# NOTE(review): unlike stage 1, wds=wd is not passed here - confirm this is intentional.\n",
    "learn.fit(lrs,1,cycle_len=4,use_clr=(20,8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the weights trained at sz=128.\n",
    "learn.save('128')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on one validation batch. Switch to eval mode explicitly so BatchNorm uses (and does not\n",
    "# update) its running statistics, whichever cells ran before (e.g. straight after learn.load).\n",
    "x,y = next(iter(md.val_dl))\n",
    "learn.model.eval()\n",
    "py = to_np(learn.model(V(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQUAAAD8CAYAAAB+fLH0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABA5JREFUeJzt3dtt02AAhuG06hSZIkugTMCUTFCxRKfoGISrSH0L6dGJf9vPc4ki4V709edD4O50Ou0Azu7nPgBgLKIAhCgAIQpAiAIQogCEKAAhCkCIAhAPcx/Abrfb/bj/6bVKuLLff37dfeRzlgIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhAPcx8Ay/b4/LQ77g///Nklrz/LeCwFICwFLnrrjP+Vz136rPUwFksBCEthgz5zZr8F62EsorAho8XgLS+PVSBuy+UDEKKwEUtaCcxLFIBwT2Hl1rAQzj+Dewu3YSkAYSms1BoWwmsWw21YCivz+Py0yiC8tPafb26iAIQosEhbWERzEQUgRIFFsximJwor4peDKYgCEKKwAia0lTQlUQBCFIDwmvOCmczlNehpWApAiMJCWQlciygAIQqsjke03yMKQIgCEB5JLoxZzLVZCkCIwsIc9wcv53BVogCEewoL454C12YpsFreV/gaUQDC5cNCOONxK5YCEKLA6llZnyMKQIgCEKKwAOYvtyQKQIgCEKIAhJeXBuZewnT88+8fZykAIQpAiAKb4puT7xMFIEQBCE8fBmTeMidLAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIDw3YeB+M4DI7AUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFNgkL4pdJgpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCiwScf9Ye5DGJYoACEKQIgCEKIAhCgAIQpAiAIQojCQ4/7g+TmzEwUgRAEIUQDiYe4DgFtyz+Z9osAmiMHHuXwAQhSAEAUgRAEIUQDC04cBne+U++/Sv8cTh68RBVZDBKbh8gEIS2Fg/zvzvXVJ8d6Zci2XIxbBdVkKQFgKC/Odm5BTnmGnXB3O/GMRhYWa+xdp7r+f63H5AIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAcXc6neY+BmAglgIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgA8RefEJKlJ3LAMQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Predicted mask for the first validation image: logit > 0.\n",
    "show_img(py[0]>0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQUAAAD8CAYAAAB+fLH0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABA5JREFUeJzt3dFN22AARtFQMQVTsETFBJ2yE6AuwRSM0fQJNbdqCAE79m+f81hVyCDl+rOdwN3xeDwAvPm29AEA6yIKQIgCEKIAhCgAIQpAiAIQogCEKABxv/QBHA6Hw/dvP7ytEmb26/fPu4/8P0sBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAGI+6UPgHV5fn2Z9es/PTzO+vX5OksBCEthR+ZeAVMeg0WxHEsBCFFglZ5fX1axbPZIFHZi1BfYqMc9MlEAQhSAEAUgPJLcuC1ck799Dx5T3oalAIQoMAyPKW/D5cNGefHwWZYCEKIAhCgAIQobs4ebcVv//pYmCkCIAhCiwJD2cJm0FFEAQhQ2xJmTKYgCEKIAhM8+DMBlwXX+/Xn5yPV1LAUgLAWGdvoLWM4tqtN/txousxSAsBRW7NK9hLeznnsOTEkUBiYGf/lZTMflAxCWwoqd3hRzJuRWLAUgLIVB/O9RmvVQHjdOw1IYlCAwF1EAQhTYDL94ZRqiAIQbjYNxJrzMH6T9GksBCEthEBbC9SyGz7EU2DxBvY4oACEKA3Cm45ZEAQhRAMLThxVz2TAdTyI+zlIAQhSAEAUgRIFd8UnKy0QBCFEAQhSAEAUgvHlphdwIY0mWAhCiAIQosEsu0c4TBSBEAQhRAEIUgPA+hRVx84s1sBSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFFbk6eHx8PTwuPRhsHOiAIQoACEKQIgCu+TezXmiAIQoACEKK+TRJEsSBSBEAQhRAEIUgLhf+gA47/Rm4/Pry4JHsh1u4F5mKQBhKQzivTOcFfE+6+A6lgIQlsIGzH0mXOMScfafjyhwkRfgvrh8AEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAuDsej0sfA7AilgIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgA8QfWQY8SR/FLYgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Ground-truth mask for the same image.\n",
    "show_img(y[0]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## U-net (ish)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SaveFeatures():\n",
    "    features=None\n",
    "    def __init__(self, m): self.hook = m.register_forward_hook(self.hook_fn)\n",
    "    def hook_fn(self, module, input, output): self.features = output\n",
    "    def remove(self): self.hook.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class UnetBlock(nn.Module):\n",
    "    def __init__(self, up_in, x_in, n_out):\n",
    "        super().__init__()\n",
    "        up_out = x_out = n_out//2\n",
    "        self.x_conv  = nn.Conv2d(x_in,  x_out,  1)\n",
    "        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)\n",
    "        self.bn = nn.BatchNorm2d(n_out)\n",
    "        \n",
    "    def forward(self, up_p, x_p):\n",
    "        up_p = self.tr_conv(up_p)\n",
    "        x_p = self.x_conv(x_p)\n",
    "        cat_p = torch.cat([up_p,x_p], dim=1)\n",
    "        return self.bn(F.relu(cat_p))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Unet34(nn.Module):\n",
    "    \"\"\"U-Net-style segmentation model built on the encoder `rn`.\n",
    "\n",
    "    Forward hooks on `rn[2]`, `rn[4]`, `rn[5]` and `rn[6]` capture the skip features. The\n",
    "    UnetBlock sizes assume the encoder output has 512 channels and those layers emit 64, 64,\n",
    "    128 and 256 channels respectively (presumably a ResNet-34 body). `forward` returns raw\n",
    "    logits with the singleton channel axis dropped.\n",
    "    \"\"\"\n",
    "    def __init__(self, rn):\n",
    "        super().__init__()\n",
    "        self.rn = rn\n",
    "        # Skip-feature hooks, ordered from shallowest to deepest layer.\n",
    "        self.sfs = [SaveFeatures(rn[idx]) for idx in (2, 4, 5, 6)]\n",
    "        self.up1 = UnetBlock(512, 256, 256)\n",
    "        self.up2 = UnetBlock(256, 128, 256)\n",
    "        self.up3 = UnetBlock(256, 64, 256)\n",
    "        self.up4 = UnetBlock(256, 64, 256)\n",
    "        self.up5 = nn.ConvTranspose2d(256, 1, 2, stride=2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Running the encoder is what makes the hooks record their layers' outputs.\n",
    "        x = F.relu(self.rn(x))\n",
    "        # Decoder stages consume skips from the deepest hook to the shallowest.\n",
    "        stages = (self.up1, self.up2, self.up3, self.up4)\n",
    "        for stage, saved in zip(stages, reversed(self.sfs)):\n",
    "            x = stage(x, saved.features)\n",
    "        return self.up5(x)[:, 0]\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"Remove every forward hook registered in __init__.\"\"\"\n",
    "        for saved in self.sfs:\n",
    "            saved.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class UnetModel():\n",
    "    def __init__(self,model,name='unet'):\n",
    "        self.model,self.name = model,name\n",
    "\n",
    "    def get_layer_groups(self, precompute):\n",
    "        lgs = list(split_by_idxs(children(self.model.rn), [lr_cut]))\n",
    "        return lgs + [children(self.model)[1:]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_base() is defined outside this view; presumably it returns the pretrained encoder body.\n",
    "m_base = get_base()\n",
    "# Attach the U-Net decoder; to_gpu presumably moves the model to the current CUDA device.\n",
    "m = to_gpu(Unet34(m_base))\n",
    "# Wrapper that gives the learner a name and layer groups for this model.\n",
    "models = UnetModel(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `md` is the model data built in an earlier cell.\n",
    "learn = ConvLearner(md, models)\n",
    "learn.opt_fn=optim.Adam\n",
    "# Unet34.forward returns raw logits (no sigmoid), hence the with-logits BCE loss.\n",
    "learn.crit=nn.BCEWithLogitsLoss()\n",
    "# accuracy_thresh(0.5) presumably thresholds predictions at 0.5; `dice` is defined outside this view.\n",
    "learn.metrics=[accuracy_thresh(0.5),dice]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('Conv2d-1',\n",
       "              OrderedDict([('input_shape', [-1, 3, 128, 128]),\n",
       "                           ('output_shape', [-1, 64, 64, 64]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 9408)])),\n",
       "             ('BatchNorm2d-2',\n",
       "              OrderedDict([('input_shape', [-1, 64, 64, 64]),\n",
       "                           ('output_shape', [-1, 64, 64, 64]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 128)])),\n",
       "             ('ReLU-3',\n",
       "              OrderedDict([('input_shape', [-1, 64, 64, 64]),\n",
       "                           ('output_shape', [-1, 64, 64, 64]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('MaxPool2d-4',\n",
       "              OrderedDict([('input_shape', [-1, 64, 64, 64]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-5',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 36864)])),\n",
       "             ('BatchNorm2d-6',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 128)])),\n",
       "             ('ReLU-7',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-8',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 36864)])),\n",
       "             ('BatchNorm2d-9',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 128)])),\n",
       "             ('ReLU-10',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-11',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-12',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 36864)])),\n",
       "             ('BatchNorm2d-13',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 128)])),\n",
       "             ('ReLU-14',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-15',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 36864)])),\n",
       "             ('BatchNorm2d-16',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 128)])),\n",
       "             ('ReLU-17',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-18',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-19',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 36864)])),\n",
       "             ('BatchNorm2d-20',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 128)])),\n",
       "             ('ReLU-21',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-22',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 36864)])),\n",
       "             ('BatchNorm2d-23',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 128)])),\n",
       "             ('ReLU-24',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-25',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 64, 32, 32]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-26',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 73728)])),\n",
       "             ('BatchNorm2d-27',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 256)])),\n",
       "             ('ReLU-28',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-29',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 147456)])),\n",
       "             ('BatchNorm2d-30',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 256)])),\n",
       "             ('Conv2d-31',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 8192)])),\n",
       "             ('BatchNorm2d-32',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 256)])),\n",
       "             ('ReLU-33',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-34',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-35',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 147456)])),\n",
       "             ('BatchNorm2d-36',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 256)])),\n",
       "             ('ReLU-37',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-38',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 147456)])),\n",
       "             ('BatchNorm2d-39',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 256)])),\n",
       "             ('ReLU-40',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-41',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-42',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 147456)])),\n",
       "             ('BatchNorm2d-43',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 256)])),\n",
       "             ('ReLU-44',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-45',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 147456)])),\n",
       "             ('BatchNorm2d-46',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 256)])),\n",
       "             ('ReLU-47',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-48',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-49',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 147456)])),\n",
       "             ('BatchNorm2d-50',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 256)])),\n",
       "             ('ReLU-51',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-52',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 147456)])),\n",
       "             ('BatchNorm2d-53',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 256)])),\n",
       "             ('ReLU-54',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-55',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-56',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 294912)])),\n",
       "             ('BatchNorm2d-57',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('ReLU-58',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-59',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 589824)])),\n",
       "             ('BatchNorm2d-60',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('Conv2d-61',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 32768)])),\n",
       "             ('BatchNorm2d-62',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('ReLU-63',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-64',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-65',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 589824)])),\n",
       "             ('BatchNorm2d-66',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('ReLU-67',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-68',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 589824)])),\n",
       "             ('BatchNorm2d-69',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('ReLU-70',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-71',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-72',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 589824)])),\n",
       "             ('BatchNorm2d-73',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('ReLU-74',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-75',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 589824)])),\n",
       "             ('BatchNorm2d-76',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('ReLU-77',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-78',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-79',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 589824)])),\n",
       "             ('BatchNorm2d-80',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('ReLU-81',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-82',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 589824)])),\n",
       "             ('BatchNorm2d-83',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('ReLU-84',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-85',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-86',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 589824)])),\n",
       "             ('BatchNorm2d-87',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('ReLU-88',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-89',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 589824)])),\n",
       "             ('BatchNorm2d-90',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('ReLU-91',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-92',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-93',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 589824)])),\n",
       "             ('BatchNorm2d-94',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('ReLU-95',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-96',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 589824)])),\n",
       "             ('BatchNorm2d-97',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('ReLU-98',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-99',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-100',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 1179648)])),\n",
       "             ('BatchNorm2d-101',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 1024)])),\n",
       "             ('ReLU-102',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-103',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 2359296)])),\n",
       "             ('BatchNorm2d-104',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 1024)])),\n",
       "             ('Conv2d-105',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 131072)])),\n",
       "             ('BatchNorm2d-106',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 1024)])),\n",
       "             ('ReLU-107',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-108',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-109',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 2359296)])),\n",
       "             ('BatchNorm2d-110',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 1024)])),\n",
       "             ('ReLU-111',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-112',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 2359296)])),\n",
       "             ('BatchNorm2d-113',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 1024)])),\n",
       "             ('ReLU-114',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-115',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-116',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 2359296)])),\n",
       "             ('BatchNorm2d-117',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 1024)])),\n",
       "             ('ReLU-118',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('Conv2d-119',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 2359296)])),\n",
       "             ('BatchNorm2d-120',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('trainable', False),\n",
       "                           ('nb_params', 1024)])),\n",
       "             ('ReLU-121',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('BasicBlock-122',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 512, 4, 4]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('ConvTranspose2d-123',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 128, 8, 8]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 262272)])),\n",
       "             ('Conv2d-124',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 128, 8, 8]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 32896)])),\n",
       "             ('BatchNorm2d-125',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('UnetBlock-126',\n",
       "              OrderedDict([('input_shape', [-1, 512, 4, 4]),\n",
       "                           ('output_shape', [-1, 256, 8, 8]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('ConvTranspose2d-127',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 131200)])),\n",
       "             ('Conv2d-128',\n",
       "              OrderedDict([('input_shape', [-1, 128, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 16, 16]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 16512)])),\n",
       "             ('BatchNorm2d-129',\n",
       "              OrderedDict([('input_shape', [-1, 256, 16, 16]),\n",
       "                           ('output_shape', [-1, 256, 16, 16]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('UnetBlock-130',\n",
       "              OrderedDict([('input_shape', [-1, 256, 8, 8]),\n",
       "                           ('output_shape', [-1, 256, 16, 16]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('ConvTranspose2d-131',\n",
       "              OrderedDict([('input_shape', [-1, 256, 16, 16]),\n",
       "                           ('output_shape', [-1, 128, 32, 32]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 131200)])),\n",
       "             ('Conv2d-132',\n",
       "              OrderedDict([('input_shape', [-1, 64, 32, 32]),\n",
       "                           ('output_shape', [-1, 128, 32, 32]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 8320)])),\n",
       "             ('BatchNorm2d-133',\n",
       "              OrderedDict([('input_shape', [-1, 256, 32, 32]),\n",
       "                           ('output_shape', [-1, 256, 32, 32]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('UnetBlock-134',\n",
       "              OrderedDict([('input_shape', [-1, 256, 16, 16]),\n",
       "                           ('output_shape', [-1, 256, 32, 32]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('ConvTranspose2d-135',\n",
       "              OrderedDict([('input_shape', [-1, 256, 32, 32]),\n",
       "                           ('output_shape', [-1, 128, 64, 64]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 131200)])),\n",
       "             ('Conv2d-136',\n",
       "              OrderedDict([('input_shape', [-1, 64, 64, 64]),\n",
       "                           ('output_shape', [-1, 128, 64, 64]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 8320)])),\n",
       "             ('BatchNorm2d-137',\n",
       "              OrderedDict([('input_shape', [-1, 256, 64, 64]),\n",
       "                           ('output_shape', [-1, 256, 64, 64]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 512)])),\n",
       "             ('UnetBlock-138',\n",
       "              OrderedDict([('input_shape', [-1, 256, 32, 32]),\n",
       "                           ('output_shape', [-1, 256, 64, 64]),\n",
       "                           ('nb_params', 0)])),\n",
       "             ('ConvTranspose2d-139',\n",
       "              OrderedDict([('input_shape', [-1, 256, 64, 64]),\n",
       "                           ('output_shape', [-1, 1, 128, 128]),\n",
       "                           ('trainable', True),\n",
       "                           ('nb_params', 1025)]))])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[torch.Size([3, 64, 64, 64]),\n",
       " torch.Size([3, 64, 32, 32]),\n",
       " torch.Size([3, 128, 16, 16]),\n",
       " torch.Size([3, 256, 8, 8])]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[o.features.size() for o in m.sfs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.freeze_to(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a2661ed6f2c049098c0b37f13ac2e03e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                           | 0/64 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 92%|█████████████████████████████████████████████████████████████████▍     | 59/64 [00:22<00:01,  2.68it/s, loss=2.45]\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.lr_find()\n",
    "learn.sched.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr=4e-2\n",
    "wd=1e-7\n",
    "\n",
    "lrs = np.array([lr/100,lr/10,lr])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dce9710ec3314546b8d1dfc8e1d250f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=8), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss   <lambda>   dice           \n",
      "    0      0.12936    0.03934    0.988571   0.971385  \n",
      "    1      0.098401   0.039252   0.990438   0.974921        \n",
      "    2      0.087789   0.02539    0.990961   0.978927        \n",
      "    3      0.082625   0.027984   0.988483   0.975948        \n",
      "    4      0.079509   0.025003   0.99171    0.981221        \n",
      "    5      0.076984   0.022514   0.992462   0.981881        \n",
      "    6      0.076822   0.023203   0.992484   0.982321        \n",
      "    7      0.075488   0.021956   0.992327   0.982704        \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.021955982234979434, 0.9923273126284281, 0.9827044502137199]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.fit(lr,1,wds=wd,cycle_len=8,use_clr=(5,8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('128urn-tmp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.load('128urn-tmp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.unfreeze()\n",
    "learn.bn_freeze(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86aa94831fa840b7ae471989af971594",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/64 [00:00<?, ?it/s]\n",
      "epoch      trn_loss   val_loss   <lambda>   dice            \n",
      "    0      0.073786   0.023418   0.99297    0.98283   \n",
      "    1      0.073561   0.020853   0.992142   0.982725        \n",
      "    2      0.075227   0.023357   0.991076   0.980879        \n",
      "    3      0.074245   0.02352    0.993108   0.983659        \n",
      "    4      0.073434   0.021508   0.993024   0.983609        \n",
      "    5      0.073092   0.020956   0.993188   0.983333        \n",
      "    6      0.073617   0.019666   0.993035   0.984102        \n",
      "    7      0.072786   0.019844   0.993196   0.98435         \n",
      "    8      0.072256   0.018479   0.993282   0.984277        \n",
      "    9      0.072052   0.019479   0.993164   0.984147        \n",
      "    10     0.071361   0.019402   0.993344   0.984541        \n",
      "    11     0.070969   0.018904   0.993139   0.984499        \n",
      "    12     0.071588   0.018027   0.9935     0.984543        \n",
      "    13     0.070709   0.018345   0.993491   0.98489         \n",
      "    14     0.072238   0.019096   0.993594   0.984825        \n",
      "    15     0.071407   0.018967   0.993446   0.984919        \n",
      "    16     0.071047   0.01966    0.993366   0.984952        \n",
      "    17     0.072024   0.018133   0.993505   0.98497         \n",
      "    18     0.071517   0.018464   0.993602   0.985192        \n",
      "    19     0.070109   0.018337   0.993614   0.9852          \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.018336569653853538, 0.9936137114252362, 0.9852004420189631]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.fit(lrs/4, 1, wds=wd, cycle_len=20,use_clr=(20,10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('128urn-0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.load('128urn-0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x,y = next(iter(md.val_dl))\n",
    "py = to_np(learn.model(V(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQUAAAD8CAYAAAB+fLH0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABA5JREFUeJzt3dFN21AAhtFQMQVTsETFBJ2yE6AuwRSM0fQpEl9VwNBg32uf8wZCwqDw+Xewkpvz+XwCuPi29QEAYxEFIEQBCFEAQhSAEAUgRAEIUQBCFIC43foATqfT6fu3H26rhC/26/fPmyVfZykAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKABxu/UBMJbH56fVvtfD3f1q34vlLAUgLIUDWnMNvOWt47AitmMpAGEpHMgoC2GJl8dqNazLUjiImYLAtkQBCJcPO7eHhXD5GVxGrMNSAEIUmMbj89Muls/oRAEIUQBCFHZqz1N7rz/XKEQBCFFgSnteQlsTBSDcvLQzzp78L0sBCFEAQhSYmiccr08UgBAFdsFauB5R2JGj/2G4lLgOUQDCfQoTeO3sd3nREWfH+tfvwwu0LGcpAGEpTMxC4CtYCkBYCgOzBNiCpTCwh7t7T5CxOlEAQhQmYC2wJlEAQhQm43mGj/P7+hj/fZiEBzZrsRSAsBTYLevqcywFICyFybjLcTlvYf85lgIQlsIkLATWYimwe4L6MaIAhCgAIQpAiMIEXBOzJlHgELwnxHKiAIT7FAbmzHZ97nJ8n6UAhCgAIQpAiAIQogCEKAAhCkCIAhCiwCG5Mex17mgckAcsW7IUgBAFIEQBCFEAQhSAEAUgRAEIUQDCzUsDcdMSI7AUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQhYE83N1741M2JwpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiwCF52bvXicKAvFYjWxIFIEQBCFEAQhSAEAUgbrc+AFiT/+q8TxQG9vIB/Pj8lM9dPmYZMVjO5QMQlsIk/j7TvXbmO/KCsAauw1IAwlLYmbfOlrOuCAtgXaJwIP64WMLlAxCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAxM35fN76GICBWApAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAxB8Zg4ZPO/0D6QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_img(py[0]>0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQUAAAD8CAYAAAB+fLH0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABA5JREFUeJzt3dFN22AARtFQMQVTsETFBJ2yE6AuwRSM0fQJNbdqCAE79m+f81hVyCDl+rOdwN3xeDwAvPm29AEA6yIKQIgCEKIAhCgAIQpAiAIQogCEKABxv/QBHA6Hw/dvP7ytEmb26/fPu4/8P0sBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAGI+6UPgHV5fn2Z9es/PTzO+vX5OksBCEthR+ZeAVMeg0WxHEsBCFFglZ5fX1axbPZIFHZi1BfYqMc9MlEAQhSAEAUgPJLcuC1ck799Dx5T3oalAIQoMAyPKW/D5cNGefHwWZYCEKIAhCgAIQobs4ebcVv//pYmCkCIAhCiwJD2cJm0FFEAQhQ2xJmTKYgCEKIAhM8+DMBlwXX+/Xn5yPV1LAUgLAWGdvoLWM4tqtN/txousxSAsBRW7NK9hLeznnsOTEkUBiYGf/lZTMflAxCWwoqd3hRzJuRWLAUgLIVB/O9RmvVQHjdOw1IYlCAwF1EAQhTYDL94ZRqiAIQbjYNxJrzMH6T9GksBCEthEBbC9SyGz7EU2DxBvY4oACEKA3Cm45ZEAQhRAMLThxVz2TAdTyI+zlIAQhSAEAUgRIFd8UnKy0QBCFEAQhSAEAUgvHlphdwIY0mWAhCiAIQosEsu0c4TBSBEAQhRAEIUgPA+hRVx84s1sBSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFFbk6eHx8PTwuPRhsHOiAIQoACEKQIgCu+TezXmiAIQoACEKK+TRJEsSBSBEAQhRAEIUgLhf+gA47/Rm4/Pry4JHsh1u4F5mKQBhKQzivTOcFfE+6+A6lgIQlsIGzH0mXOMScfafjyhwkRfgvrh8AEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAEAUgRAEIUQBCFIAQBSBEAQhRAEIUgBAFIEQBCFEAQhSAuDsej0sfA7AilgIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgAIQpAiAIQogCEKAAhCkCIAhCiAIQoACEKQIgCEKIAhCgA8QfWQY8SR/FLYgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_img(y[0]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 512x512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sz=512\n",
    "bs=16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
    "datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)\n",
    "md = ImageData(PATH, datasets, bs, num_workers=4, classes=None)\n",
    "denorm = md.trn_ds.denorm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m_base = get_base()\n",
    "m = to_gpu(Unet34(m_base))\n",
    "models = UnetModel(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = ConvLearner(md, models)\n",
    "learn.opt_fn=optim.Adam\n",
    "learn.crit=nn.BCEWithLogitsLoss()\n",
    "learn.metrics=[accuracy_thresh(0.5),dice]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.freeze_to(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.load('128urn-0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2cd9f452b7dc40b58895f0ca5b876936",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=5), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss   <lambda>   dice              \n",
      "    0      0.071421   0.02362    0.996459   0.991772  \n",
      "    1      0.070373   0.014013   0.996558   0.992602          \n",
      "    2      0.067895   0.011482   0.996705   0.992883          \n",
      "    3      0.070653   0.014256   0.996695   0.992771          \n",
      "    4      0.068621   0.013195   0.996993   0.993359          \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.013194938530288046, 0.996993034604996, 0.993358936574724]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.fit(lr,1,wds=wd, cycle_len=5,use_clr=(5,5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('512urn-tmp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.unfreeze()\n",
    "learn.bn_freeze(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.load('512urn-tmp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8768382fe1fe4b8e94e6647899e20575",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=8), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss   <lambda>   dice              \n",
      "    0      0.06605    0.013602   0.997      0.993014  \n",
      "    1      0.066885   0.011252   0.997248   0.993563          \n",
      "    2      0.065796   0.009802   0.997223   0.993817          \n",
      "    3      0.065089   0.009668   0.997296   0.993744          \n",
      "    4      0.064552   0.011683   0.997269   0.993835          \n",
      "    5      0.065089   0.010553   0.997415   0.993827          \n",
      "    6      0.064303   0.009472   0.997431   0.994046          \n",
      "    7      0.062506   0.009623   0.997441   0.994118          \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.009623114736602894, 0.9974409020136273, 0.9941179137381296]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.fit(lrs/4,1,wds=wd, cycle_len=8,use_clr=(20,8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('512urn')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.load('512urn')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x,y = next(iter(md.val_dl))\n",
    "py = to_np(learn.model(V(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABIxJREFUeJzt3NFt40YUQFF64SpUhZsIXEGqTAVCmnAVLiPKR7CA4LsbS7IsDofn/BkwoCFB3nmmJT2dTqcF4NyPtRcAjEcYgBAGIIQBCGEAQhiAEAYghAEIYQDiee0FLMuy/PHjT2+/hG/29z9/PV36uyYGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYgntdeAOs5vr8tr4eX5fj+9rDXfD28POy1uJ0w7NB5CB4Zhc9eTzTGIQw78ugIXOvn+gRifZ4x7MToUTi3pbXOShgYkjisSxh2YKs32VbXPQNhmJybi1sIAxDCMLEZpoUZjmGLhIHhicPjCQMQwjCp2XbZ2Y5ndMIAhDBMyO7KVwkDmyF4jyMMQAjDZGbfVWc/vlEIAxDCMJG97KZ7Oc41CQMQwjAJuyj3JAxACMME9jgt7PGYH8mXwW7Ax5vAl6Xy3UwMg/vVznh8f7NjLr+fGpyfrzMxDMzFfZ2P5+v8Z1PWdUwMQAgDu2D6uo4wbJi/pd3w38UzBjZPHO7PxDAwD8xYi4lhcOdxsDPeTmSvIwwbIhKXez28LMf3N0G4kTAwjY8REIXbecawQaYFvpswbIwo/J5zcz/CwFTE4T6EYUNc9DyKMDAdAf06YdgIFzuPJAwbIArX8zmSrxEGpiYOtxGGwbmwWYMwDEwUWIswMD2BvZ4wACEMQAjDoIy/9+V8XkcYgBAGIIQBCGFgNzxnuJwwACEMA7KzsTZhYFdE9zLCAIQwDMaOxgiEAQhhAEIYgBAGdsdznM8JAxDCAIQwACEMQAgDEMIAhDAAIQxACAMQwgCEMAAhDAPxHn5GIQxACAMQwjCQ18PL2kuAZVmEAfgFYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYRiIL2phFMIwEB+7ZhTCMBATA6MQhoGYGBiFMAAhDEAIA7vjT7bPCQMQwjAYuxkjEAYghIFdMZFdRhiAEIYB2dVYmzAAIQyDMjWwJmEYmDjcl/N5uee1F8D/O7+YffrydqJwHWHYkEsubvH4jxB8jTBM5rMbYqZwuPm/jzDszKg3089gvR5eluP727Dr3AthYAjnIRCF9fmvBBDCAIQwACEMQAgDEMIAhDAAIQxACAMQwgCEMAAhDEAIAxDCAIQwACEMQAgDEMIAhDAAIQxACAMQwgCEMAAhDEAIAxDCAIQwACEMQAgDEMIAxNPpdFp7DcBgTAxACAMQwgCEMAAhDEAIAxDCAIQwACEMQAgDEMIAhDAAIQxACAMQwgCEMAAhDEAIAxDCAIQwACEMQAgDEMIAxL+fjtP4pEiy4gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_img(py[0]>0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABF9JREFUeJzt3NFtG0cUQFHJUBWuwk0EriBVpgIhTagKlRHmIwgi+MYSSe16Z2bP+TNgyMPBzp1HyuDj5XJ5AHjry9ELAMYjDEAIAxDCAIQwACEMQAgDEMIAhDAA8XT0Ah4eHh5++/K7/34JO/vzrz8er/27JgYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCejl4A43h+fdntZ3//+m23n832hOFk9jz8W/y7AjIGbyVO5Kgo3OL59WWKda5OGE7CYeMWwsCQhOxYwnACsx6yWde9AmEAQhgWN/utO/v6ZyUMC3OouJcwMDyB+/WEAQhhWNRqt+xqr2d0wgCEMCxo1dt11dc1ImEAQhgWs/qtuvrrG4UwACEMTMfUsD9hWIgDw1aEAQhhWIRpgS0JA1MSwn0JwwIcErbmW6IH97ND79uU/9kb+7APE8PA3psETAn/799vmbY/n2NiYBk/xuDHP5surmdimNzZb8azv/69CAMQwsD0TA3bE4aJORD/sRfbEoaB+bCMo/itxODexsGteD+RvY0wTEQkbiMG9xMGliEE2/EZw4RMC+xNGIAQhsmYFn7O3mxHGFiKOGxDGCbiob+Offo8YZiEh/029utzhGECHvL72Lf7CQNLE4f7CMPgPNgcQRiAEIaBmRa2YR9vJwxACAOnYGq4jTAMyoPMkYQBCGHgNExh1xMGIIRhQG42jiYMQAgDp2Iau44wACEMg3GjMQJh4HTE92PCAIQwACEMQAgDEMIAhDAAIQwD8Ws0RiEMQAgDp2Q6e58wACEMQAgDEMIAhDAAIQxACAMQwgCEMAAhDEAIAxDCAIQwACEMQAgDEMIAhDAAIQxACAMQwgCEMAAhDEAIA6f0/eu3o5cwNGEAQhiAEAYghGEg3vcyCmEAQhiAEAYghAEIYRiMDyD3Z48/JgxACAMQwgCEMAzIe+D92NvrCAMQwjAoNxtHEoaBicO27Of1no5eAO97+zA/v74cuJK5icJthGEi7z3colFicD9hWMQth2DViAjBdoThhBwgPuLDRyCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEA4vFyuRy9BmAwJgYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEAQhiAEAYghAEIYQBCGIAQBiCEAQhhAEIYgBAGIIQBCGEA4m/jpL5hSyL6CwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_img(y[0]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1024x1024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sz=1024\n",
    "bs=4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS)\n",
    "datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)\n",
    "md = ImageData(PATH, datasets, bs, num_workers=16, classes=None)\n",
    "denorm = md.trn_ds.denorm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m_base = get_base()\n",
    "m = to_gpu(Unet34(m_base))\n",
    "models = UnetModel(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = ConvLearner(md, models)\n",
    "learn.opt_fn=optim.Adam\n",
    "learn.crit=nn.BCEWithLogitsLoss()\n",
    "learn.metrics=[accuracy_thresh(0.5),dice]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.load('512urn')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.freeze_to(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a864769b63b7438fa112cf68e2835a0e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss   <lambda>   dice                 \n",
      "    0      0.007656   0.008155   0.997247   0.99353   \n",
      "    1      0.004706   0.00509    0.998039   0.995437             \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.005090427414942828, 0.9980387706605215, 0.995437301104031]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.fit(lr,1, wds=wd, cycle_len=2,use_clr=(5,4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('1024urn-tmp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.load('1024urn-tmp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.unfreeze()\n",
    "learn.bn_freeze(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lrs = np.array([lr/200,lr/30,lr])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss   <lambda>   dice                 \n",
      "    0      0.005688   0.006135   0.997616   0.994616  \n",
      "    1      0.004412   0.005223   0.997983   0.995349             \n",
      "    2      0.004186   0.004975   0.99806    0.99554              \n",
      "    3      0.004016   0.004899   0.99812    0.995627             \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.004898778487196458, 0.9981196409180051, 0.9956271404784823]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ba1a0b69230449da669623edfd3f6c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss   <lambda>   dice                 \n",
      "    0      0.004169   0.004962   0.998049   0.995517  \n",
      "    1      0.004022   0.004595   0.99823    0.995818             \n",
      "    2      0.003772   0.004497   0.998215   0.995916             \n",
      "    3      0.003618   0.004435   0.998291   0.995991             \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.004434524739663753, 0.9982911745707194, 0.9959913929776539]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.fit(lrs/10,1, wds=wd,cycle_len=4,use_clr=(20,8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.sched.plot_loss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('1024urn')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.load('1024urn')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x,y = next(iter(md.val_dl))\n",
    "py = to_np(learn.model(V(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ8AAAD8CAYAAABpXiE9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABIlJREFUeJzt3dFN21AYgNGAmIIpWKJigk7ZCVCXYArGaPpQkBBQCF/s2Nc+5z3K1VX+z9chJFfH4/EA8F3XSy8AGJN4AIl4AIl4AIl4AIl4AIl4AIl4AIl4AMnN0gv4zI/rnz7+CjP7/efXVXmckweQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQiAeQ3Cy9ANbl4elxkee9v71b5HnpxIPD4bBcNE59fnFZH7ctLB6OU4ywxr0Rj50baShHWuseiMeOjTiMI655q8SD4QjIOojHThlAziUeDEn8liceO2TwmIJ4MCwRXJZ47IyBYyriwdDEcDniASTisSOu0kxJPBieKC5DPHbCgDE18QAS8WATnKwuTzx2wGAxB/FgM0TyssRj4wwUcxEPIBEPNsVJ63J8e/rAPhqU198ybpCYk3gM6LMoPDw97v5nCt7uwVeRpXHbwi49PD06mZ1JPDbKYPxjH+YjHhtkYE5nrzrxYJNeouC9jfl4w5TNcqqYl5PHgFxNp2MvO/EY1P3tnRc+i3LbMri3AXFU/5zgTkc82DzBmIfbFiARjw1xy8IlicdGCMf/2Zt5iMcGGI6v2aPpicfgDMXp7NW0xGNghoEliQe74l/xpyMegzIA57F/5xMPIBEPIBGPATlyT8M+nkc8gEQ8gEQ8BuOoPS372YkHkIgHu+f00YjHQLzIWRPxABLxABLxABLxGIT3O+Zlf79PPIBEPIBEPIBEPIBEPIBEPIBEPAbgz4iskXjAM5H+HvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEYwP3t3dJLgHfEA0jEA0jEA0jEA0jEYwB+BpE1Eg8gEQ8gEY8B+JwHayQeQCIeQCIe8Mzt4feIB5CIxyBcFVkb8QAS8QAS8YCD28JCPIBEPAbi6siaiMdgBIS1EI8B3d/eiciE7GVzs/QC6F6/6H3nRyMcnXhsxEdDICjvicV0xGPDTh2U15F5eczI4RGIyxAPPhw2A8hXvGEKJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJFfH43HpNQADcvIAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAEvEAkr87r79YklPCmgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_img(py[0]>0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ8AAAD8CAYAAABpXiE9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABIRJREFUeJzt3NFRE0EAgOHgUAVV0IRDBVZpBYxNUAVlGB/QGRVR8+eOu937vkdmCJud23/3kpCb8/l8ArjUh60HAIxJPIBEPIBEPIBEPIBEPIBEPIBEPIBEPIDkdusB/M3HD598/BVW9uXr55vye04eQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQCIeQHK79QDYv8fnp9Ue++HufrXHZl3iwel0WjcQS/xdkdkfty1sFo5LjDDGoxGPgxtpUY401iMQD4YiIPshHgc26kIcddyzEY+DsgC5lngwJPHbnngckIXHEsSDYYngtsQDSMTjYGbbrWd7PiMRDyARjwOZdZee9XntnXgAiXgchN2ZpYkHUxDH9yceQCIeB3CUXfkoz3MvxANIxGNydmPWIh5MRSzfj3hMzEJiTb49fVB/CoNvGOc9iceA3jpR/Pj50SPy+Pz0yxwI7TrctnBIj89PbuuuJB6TsjBe/GsezFMnHkxJFNYnHkAiHhOy674wD+sSjwF5p2A55rITj0E93N278NmUz3kM7veAOKr/neAuRzyYmlisx23LRJw6XjMn6xEPIBGPSdhh32Zu1iEeE7A4/s0cLU88BmdRsBXxGJhwXMZ/0i5LPDgcAVmGeAzKAriO+bueeACJeACJeAzIkXsZ5vE64gEk4gEk4jEYR+1lmc9OPIBEPIBEPAbiiL0O89qIB5CIB5CIB5zcuhTiMQgXN3sjHkAiHkAiHkAiHkAiHvCdF6UvIx5AIh5AIh4DcJxmj8QDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMRjAA9391sPAV4RD/hOpC8jHkAiHkAiHkAiHoNwP87eiAeQiAecnOwK8QAS8RiI3ZE9EY/BCMjyzGkjHgN6uLt3wS/EPHa3Ww+A7ucL//H5acORjEk4riMek3hrIYjKrwRjOeIxubJYZgqOWKxHPHjFguN/eMEUSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSMQDSG7O5/PWYwAG5OQBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJOIBJN8ANEq7NH8OpbAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_img(y[0]);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}