{
 "cells": [
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "%matplotlib inline\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "from fastai.conv_learner import *\n",
    "from fastai.dataset import *\n",
    "from fastai.models.resnet import vgg_resnet50\n",
    "\n",
    "import json"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "# Preferred GPU index. Clamp it to the devices actually present so the notebook\n",
    "# also starts on hosts with fewer GPUs; with no CUDA device the index becomes -1,\n",
    "# and torch.cuda.set_device is a no-op for negative indices.\n",
    "GPU_ID = 2\n",
    "torch.cuda.set_device(min(GPU_ID, torch.cuda.device_count() - 1))"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "torch.backends.cudnn.benchmark=True"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "## Data"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "PATH = Path('data/carvana')\n",
    "MASKS_FN = 'train_masks.csv'\n",
    "META_FN = 'metadata.csv'\n",
    "masks_csv = pd.read_csv(PATH/MASKS_FN)\n",
    "meta_csv = pd.read_csv(PATH/META_FN)"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "def show_img(im, figsize=None, ax=None, alpha=None):\n",
    "    \"\"\"Draw `im` on `ax` with the axis decorations hidden and return the axis.\n",
    "\n",
    "    A new figure of size `figsize` is created only when no axis is passed in;\n",
    "    `alpha` is forwarded to `imshow`.\n",
    "    \"\"\"\n",
    "    # Compare against None explicitly: truthiness of an arbitrary axis-like object is not a reliable 'was it passed' test.\n",
    "    if ax is None: _,ax = plt.subplots(figsize=figsize)\n",
    "    ax.imshow(im, alpha=alpha)\n",
    "    ax.set_axis_off()\n",
    "    return ax"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "# Directory/size settings, variant 1.\n",
    "# NOTE: the next cell reassigns every name set here, so on a top-to-bottom run\n",
    "# the next cell's values are the ones in effect.\n",
    "TRAIN_DN = 'train-128'\n",
    "MASKS_DN = 'train_masks-128'\n",
    "sz = 128\n",
    "bs = 64\n",
    "nw = 16"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "# Directory/size settings, variant 2 (assigned last, so these win on Run All).\n",
    "TRAIN_DN = 'train'\n",
    "MASKS_DN = 'train_masks_png'\n",
    "sz = 128\n",
    "bs = 64\n",
    "nw = 16"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "class MatchedFilesDataset(FilesDataset):\n",
    "    \"\"\"FilesDataset whose target for item `i` is the image file `y[i]`.\n",
    "\n",
    "    `fnames` and `y` are parallel sequences of paths relative to `path`.\n",
    "    \"\"\"\n",
    "    def __init__(self, fnames, y, transform, path):\n",
    "        self.y=y\n",
    "        assert len(fnames)==len(y), 'fnames and y must have the same length'\n",
    "        super().__init__(fnames, transform, path)\n",
    "    # The target is loaded as an image from `self.path`/`self.y[i]`.\n",
    "    def get_y(self, i): return open_image(os.path.join(self.path, self.y[i]))\n",
    "    def get_c(self): return 0"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "# o[:-4] drops the last four characters of each name (presumably the '.jpg' extension - confirm) before the mask suffix is appended.\n",
    "x_names = np.array([Path(TRAIN_DN)/o for o in masks_csv['img']])\n",
    "y_names = np.array([Path(MASKS_DN)/f'{o[:-4]}_mask.png' for o in masks_csv['img']])"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "# The first 1008 rows of masks_csv are held out for validation.\n",
    "val_idxs = list(range(1008))\n",
    "((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, x_names, y_names)"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "aug_tfms = [RandomRotate(4, tfm_y=TfmType.CLASS),\n",
    "            RandomFlip(tfm_y=TfmType.CLASS),\n",
    "            RandomLighting(0.05, 0.05, tfm_y=TfmType.CLASS)]"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)\n",
    "datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, path=PATH)\n",
    "# Use the `nw` setting from the config cell rather than a second hard-coded worker count.\n",
    "md = ImageData(PATH, datasets, bs, num_workers=nw, classes=None)\n",
    "denorm = md.trn_ds.denorm"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "x,y = next(iter(md.trn_dl))"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [
    { "data": { "text/plain": [
       "(torch.Size([64, 3, 128, 128]), torch.Size([64, 128, 128]))"
      ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" }
   ], "source": [
    "x.shape,y.shape"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "## Simple upsample"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "f = resnet34\n",
    "cut,lr_cut = model_meta[f]"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "def get_base():\n",
    "    \"\"\"Instantiate `f(True)`, truncate it with `cut_model(..., cut)` and wrap the layers in an nn.Sequential.\"\"\"\n",
    "    layers = cut_model(f(True), cut)\n",
    "    return nn.Sequential(*layers)"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "def dice(pred, targs):\n",
    "    \"\"\"Overlap score 2*sum(pred*targs) / sum(pred+targs), with `pred` binarised at > 0.\"\"\"\n",
    "    # NOTE(review): the denominator is zero when both `pred` and `targs` are all zeros.\n",
    "    pred = (pred>0).float()\n",
    "    return 2. * (pred*targs).sum() / (pred+targs).sum()"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "class StdUpsample(nn.Module):\n",
    "    \"\"\"Upsampling unit: ConvTranspose2d (kernel 2, stride 2) -> ReLU -> BatchNorm2d.\"\"\"\n",
    "    def __init__(self, nin, nout):\n",
    "        super().__init__()\n",
    "        self.conv = nn.ConvTranspose2d(nin, nout, 2, stride=2)\n",
    "        self.bn = nn.BatchNorm2d(nout)\n",
    "\n",
    "    def forward(self, x): return self.bn(F.relu(self.conv(x)))"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "class Upsample34(nn.Module):\n",
    "    \"\"\"Backbone `rn` followed by ReLU, four StdUpsample units and a final 1-channel ConvTranspose2d.\"\"\"\n",
    "    def __init__(self, rn):\n",
    "        super().__init__()\n",
    "        self.rn = rn\n",
    "        self.features = nn.Sequential(\n",
    "            rn, nn.ReLU(),\n",
    "            StdUpsample(512,256),\n",
    "            StdUpsample(256,256),\n",
    "            StdUpsample(256,256),\n",
    "            StdUpsample(256,256),\n",
    "            nn.ConvTranspose2d(256, 1, 2, stride=2))\n",
    "\n",
    "    # The last layer emits a single channel; [:,0] selects it so the output has no channel axis.\n",
    "    def forward(self,x): return self.features(x)[:,0]"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "class UpsampleModel():\n",
    "    \"\"\"Wrapper pairing a model with a name and exposing its layer groups.\"\"\"\n",
    "    def __init__(self,model,name='upsample'):\n",
    "        self.model,self.name = model,name\n",
    "\n",
    "    def get_layer_groups(self, precompute):\n",
    "        # Backbone children split at `lr_cut`, plus everything in `features` after the backbone itself.\n",
    "        lgs = list(split_by_idxs(children(self.model.rn), [lr_cut]))\n",
    "        return lgs + [children(self.model.features)[1:]]"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "m_base = get_base()"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "m = to_gpu(Upsample34(m_base))\n",
    "models = UpsampleModel(m)"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "learn = ConvLearner(md, models)\n",
    "learn.opt_fn=optim.Adam\n",
    "learn.crit=nn.BCEWithLogitsLoss()\n",
    "learn.metrics=[accuracy_thresh(0.5),dice]"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "learn.freeze_to(1)"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [
    { "data": {
      "application/vnd.jupyter.widget-view+json": { "model_id": "e4667e9fa899453da7cbf0a0524fb426", "version_major": 2, "version_minor": 0 },
      "text/html": [ "
Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " 86%|█████████████████████████████████████████████████████████████ | 55/64 [00:22<00:03, 2.46it/s, loss=3.21]" ] }, { "data": { "image/png": "\n", "text/plain": [ "