{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from local.torch_basics import *\n", "from local.test import *\n", "from local.callback.hook import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from local.notebook.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp vision.models.unet" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Dynamic UNet\n", "\n", "> Unet model using PixelShuffle ICNR upsampling that can be built on top of any pretrained architecture" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export \n", "def _get_sz_change_idxs(sizes):\n", " \"Get the indexes of the layers where the size of the activation changes.\"\n", " feature_szs = [size[-1] for size in sizes]\n", " sz_chg_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])\n", " if feature_szs[0] != feature_szs[1]: sz_chg_idxs = [0] + sz_chg_idxs\n", " return sz_chg_idxs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "test_eq(_get_sz_change_idxs([[3,64,64], [16,64,64], [32,32,32], [16,32,32], [32,32,32], [16,16]]), [1,4])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export \n", "class UnetBlock(Module):\n", " \"A quasi-UNet block, using `PixelShuffle_ICNR upsampling`.\"\n", " @delegates(ConvLayer.__init__)\n", " def __init__(self, up_in_c, x_in_c, hook, final_div=True, blur=False, act_cls=defaults.activation,\n", " self_attention=False, init=nn.init.kaiming_normal_, norm_type=None, **kwargs):\n", " self.hook = hook\n", " self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, act_cls=act_cls, norm_type=norm_type)\n", " self.bn = BatchNorm(x_in_c)\n", " ni = up_in_c//2 + x_in_c\n", " nf = ni if final_div else ni//2\n", " self.conv1 = ConvLayer(ni, nf, act_cls=act_cls, norm_type=norm_type, **kwargs)\n", " self.conv2 = ConvLayer(nf, nf, act_cls=act_cls, norm_type=norm_type, xtra=SelfAttention(nf) if self_attention else None, **kwargs)\n", " self.relu = act_cls()\n", " apply_init(nn.Sequential(self.conv1, self.conv2), init)\n", "\n", " def forward(self, up_in):\n", " s = self.hook.stored\n", " up_out = self.shuf(up_in)\n", " ssh = s.shape[-2:]\n", " if ssh != up_out.shape[-2:]:\n", " up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')\n", " cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))\n", " return self.conv2(self.conv1(cat_x))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Check against v1 implementation\n", "#TODO: remove when v2 is the official version\n", "from fastai.vision.models.unet import UnetBlock as OldUnetBlock\n", "\n", "source = ConvLayer(5, 10, ks=3)\n", "hook = hook_output(source)\n", "\n", "mod1 = UnetBlock(8, 10, hook)\n", "mod2 = OldUnetBlock(8, 10, hook, norm_type=None)\n", "sd1,sd2 = mod1.state_dict(),mod2.state_dict()\n", "for k1,k2 in zip(sd1.keys(), sd2.keys()): sd2[k2] = sd1[k1].clone()\n", "mod2.load_state_dict(sd2)\n", "x1 = torch.randn(16, 5, 8, 8)\n", "x2 = torch.randn(16, 8, 16, 16)\n", "_ = source(x1)\n", "y1 = mod1(x2.clone())\n", "y2 = mod2(x2.clone())\n", "test_close(y1, y2)\n", "hook.remove()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export 
\n", "class DynamicUnet(SequentialEx):\n", " \"Create a U-Net from a given architecture.\"\n", " def __init__(self, encoder, n_classes, img_size, blur=False, blur_final=True, self_attention=False,\n", " y_range=None, last_cross=True, bottle=False, act_cls=defaults.activation,\n", " init=nn.init.kaiming_normal_, norm_type=NormType.Batch, **kwargs):\n", " imsize = img_size\n", " sizes = model_sizes(encoder, size=imsize)\n", " sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))\n", " self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)\n", " x = dummy_eval(encoder, imsize).detach()\n", "\n", " ni = sizes[-1][1]\n", " middle_conv = nn.Sequential(ConvLayer(ni, ni*2, act_cls=act_cls, norm_type=norm_type, **kwargs),\n", " ConvLayer(ni*2, ni, act_cls=act_cls, norm_type=norm_type, **kwargs)).eval()\n", " x = middle_conv(x)\n", " layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]\n", "\n", " for i,idx in enumerate(sz_chg_idxs):\n", " not_final = i!=len(sz_chg_idxs)-1\n", " up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])\n", " do_blur = blur and (not_final or blur_final)\n", " sa = self_attention and (i==len(sz_chg_idxs)-3)\n", " unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,\n", " act_cls=act_cls, init=init, norm_type=norm_type, **kwargs).eval()\n", " layers.append(unet_block)\n", " x = unet_block(x)\n", "\n", " ni = x.shape[1]\n", " if imsize != sizes[0][-2:]: layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))\n", " x = PixelShuffle_ICNR(ni)(x)\n", " if imsize != x.shape[-2:]: layers.append(Lambda(lambda x: F.interpolate(x, imsize, mode='nearest')))\n", " if last_cross:\n", " layers.append(MergeLayer(dense=True))\n", " ni += in_channels(encoder)\n", " layers.append(ResBlock(1, ni, ni//2 if bottle else ni, act_cls=act_cls, norm_type=norm_type, **kwargs))\n", " layers += [ConvLayer(ni, n_classes, ks=1, act_cls=None, norm_type=norm_type, **kwargs)]\n", " apply_init(nn.Sequential(layers[3], layers[-2]), init)\n", " #apply_init(nn.Sequential(layers[2]), init)\n", " if y_range is not None: layers.append(SigmoidRange(*y_range))\n", " super().__init__(*layers)\n", "\n", " def __del__(self):\n", " if hasattr(self, \"sfs\"): self.sfs.remove()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from local.vision.models import resnet34\n", "m = resnet34()\n", "m = nn.Sequential(*list(m.children())[:-2])\n", "tst = DynamicUnet(m, 5, (128,128), norm_type=None)\n", "x = torch.randn(2, 3, 128, 128)\n", "y = tst(x)\n", "test_eq(y.shape, [2, 5, 128, 128])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#slow\n", "#Check against v1 implementation\n", "#TODO: remove when v2 is the official version\n", "from fastai.vision.models.unet import DynamicUnet as OldDynamicUnet\n", "\n", "encoder = nn.Sequential(*list(resnet34(True).children())[:-2])\n", "mod1 = DynamicUnet(encoder, 5, (128,128), norm_type=None)\n", "mod2 = OldDynamicUnet(encoder, 5, (128,128), norm_type=None)\n", "sd1,sd2 = mod1.state_dict(),mod2.state_dict()\n", "for k1,k2 in zip(sd1.keys(), sd2.keys()): sd2[k2] = sd1[k1].clone()\n", "mod2.load_state_dict(sd2)\n", "x = torch.randn(2, 3, 128, 128)\n", "y1 = mod1(x.clone())\n", "y2 = mod2(x.clone())\n", "\n", "#ResBlock in v2 have the ReLU after the merge so don't give the same results\n", "y1 = SequentialEx(*mod1.layers[:10])(x.clone())\n", "y2 = 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#slow\n", "#Check against v1 implementation\n", "#TODO: remove when v2 is the official version\n", "from fastai.vision.models.unet import DynamicUnet as OldDynamicUnet\n", "\n", "encoder = nn.Sequential(*list(resnet34(True).children())[:-2])\n", "mod1 = DynamicUnet(encoder, 5, (128,128), norm_type=None)\n", "mod2 = OldDynamicUnet(encoder, 5, (128,128), norm_type=None)\n", "sd1,sd2 = mod1.state_dict(),mod2.state_dict()\n", "for k1,k2 in zip(sd1.keys(), sd2.keys()): sd2[k2] = sd1[k1].clone()\n", "mod2.load_state_dict(sd2)\n", "x = torch.randn(2, 3, 128, 128)\n", "y1 = mod1(x.clone())\n", "y2 = mod2(x.clone())\n", "\n", "#ResBlock in v2 has the ReLU after the merge, so the full models don't give the same results\n", "y1 = SequentialEx(*mod1.layers[:10])(x.clone())\n", "y2 = SequentialEx(*mod2.layers[:10])(x.clone())\n", "test_close(y1, y2, eps=1e-3)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_test.ipynb.\n", "Converted 01_core.ipynb.\n", "Converted 01a_utils.ipynb.\n", "Converted 01b_dispatch.ipynb.\n", "Converted 01c_transform.ipynb.\n", "Converted 02_script.ipynb.\n", "Converted 03_torch_core.ipynb.\n", "Converted 03a_layers.ipynb.\n", "Converted 04_dataloader.ipynb.\n", "Converted 05_data_core.ipynb.\n", "Converted 06_data_transforms.ipynb.\n", "Converted 07_data_block.ipynb.\n", "Converted 08_vision_core.ipynb.\n", "Converted 09_vision_augment.ipynb.\n", "Converted 10_pets_tutorial.ipynb.\n", "Converted 11_vision_models_xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_learner.ipynb.\n", "Converted 13a_metrics.ipynb.\n", "Converted 14_callback_schedule.ipynb.\n", "Converted 14a_callback_data.ipynb.\n", "Converted 15_callback_hook.ipynb.\n", "Converted 15a_vision_models_unet.ipynb.\n", "Converted 16_callback_progress.ipynb.\n", "Converted 17_callback_tracker.ipynb.\n", "Converted 18_callback_fp16.ipynb.\n", "Converted 19_callback_mixup.ipynb.\n", "Converted 21_vision_learner.ipynb.\n", "Converted 22_tutorial_imagenette.ipynb.\n", "Converted 23_tutorial_transfer_learning.ipynb.\n", "Converted 30_text_core.ipynb.\n", "Converted 31_text_data.ipynb.\n", "Converted 32_text_models_awdlstm.ipynb.\n", "Converted 33_text_models_core.ipynb.\n", "Converted 34_callback_rnn.ipynb.\n", "Converted 35_tutorial_wikitext.ipynb.\n", "Converted 36_text_models_qrnn.ipynb.\n", "Converted 37_text_learner.ipynb.\n", "Converted 38_tutorial_ulmfit.ipynb.\n", "Converted 40_tabular_core.ipynb.\n", "Converted 41_tabular_model.ipynb.\n", "Converted 42_tabular_rapids.ipynb.\n", "Converted 50_data_block_examples.ipynb.\n", "Converted 60_medical_imaging.ipynb.\n", "Converted 65_medical_text.ipynb.\n", "Converted 90_notebook_core.ipynb.\n", "Converted 91_notebook_export.ipynb.\n", "Converted 92_notebook_showdoc.ipynb.\n", "Converted 93_notebook_export2html.ipynb.\n", "Converted 94_notebook_test.ipynb.\n", "Converted 95_index.ipynb.\n", "Converted 96_data_external.ipynb.\n", "Converted 97_utils_test.ipynb.\n", "Converted notebook2jekyll.ipynb.\n" ] } ], "source": [ "#hide\n", "from local.notebook.export import *\n", "notebook2script(all_fs=True)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }