In [None]:
#export
from local.torch_basics import *
from local.test import *
from local.callback.hook import *

In [None]:
from local.notebook.showdoc import *

In [None]:
# default_exp vision.models.unet

# Dynamic UNet

> A U-Net model using PixelShuffle ICNR upsampling that can be built on top of any pretrained architecture

In [None]:
#export 
def _get_sz_change_idxs(sizes):
 "Get the indexes of the layers where the size of the activation changes."
 feature_szs = [size[-1] for size in sizes]
 sz_chg_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
 if feature_szs[0] != feature_szs[1]: sz_chg_idxs = [0] + sz_chg_idxs
 return sz_chg_idxs

In [None]:
#hide
# Last dims are 64,64,32,32,32,16: the size changes right after layer 1 and right after layer 4.
test_eq(_get_sz_change_idxs([[3,64,64], [16,64,64],
                             [32,32,32], [16,32,32], [32,32,32],
                             [16,16]]),
        [1,4])

In [None]:
#export 
class UnetBlock(Module):
    "A quasi-UNet block, using `PixelShuffle_ICNR` upsampling."
    @delegates(ConvLayer.__init__)
    def __init__(self, up_in_c, x_in_c, hook, final_div=True, blur=False, act_cls=defaults.activation,
                 self_attention=False, init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        # `up_in_c`: channels of the incoming decoder activation (upsampled to `up_in_c//2` channels).
        # `x_in_c`: channels of the skip-connection activation read from `hook.stored`.
        # `final_div`: when False, the convs output half of the concatenated channel count.
        # Extra `kwargs` are forwarded to both `ConvLayer`s.
        # NOTE: module registration order fixes the `state_dict` key order (checkpoints and the
        # positional v1 comparison depend on it) -- keep the assignments below in this order.
        self.hook = hook
        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, act_cls=act_cls, norm_type=norm_type)
        self.bn = BatchNorm(x_in_c)
        ni = up_in_c//2 + x_in_c
        nf = ni if final_div else ni//2
        self.conv1 = ConvLayer(ni, nf, act_cls=act_cls, norm_type=norm_type, **kwargs)
        self.conv2 = ConvLayer(nf, nf, act_cls=act_cls, norm_type=norm_type,
                               xtra=SelfAttention(nf) if self_attention else None, **kwargs)
        self.relu = act_cls()
        apply_init(nn.Sequential(self.conv1, self.conv2), init)

    def forward(self, up_in):
        # Skip-connection activation captured by the hook during the encoder's forward pass.
        s = self.hook.stored
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        # The upsampled activation may not match the skip connection's spatial size exactly;
        # resize it so both can be concatenated along the channel dimension.
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, ssh, mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.conv1(cat_x))

In [None]:
#hide
#Check against v1 implementation
#TODO: remove when v2 is the official version
from fastai.vision.models.unet import UnetBlock as OldUnetBlock

source = ConvLayer(5, 10, ks=3)
hook = hook_output(source)
try:
    # Both blocks read the same hook, so they see the same skip-connection activation.
    mod1 = UnetBlock(8, 10, hook)
    mod2 = OldUnetBlock(8, 10, hook, norm_type=None)
    # Copy the v2 weights into v1 positionally (relies on both state dicts listing entries in the same order).
    sd1,sd2 = mod1.state_dict(),mod2.state_dict()
    for k1,k2 in zip(sd1.keys(), sd2.keys()): sd2[k2] = sd1[k1].clone()
    mod2.load_state_dict(sd2)
    x1 = torch.randn(16, 5, 8, 8)
    x2 = torch.randn(16, 8, 16, 16)
    _ = source(x1)  # fills `hook.stored` with the skip-connection activation
    y1 = mod1(x2.clone())
    y2 = mod2(x2.clone())
    test_close(y1, y2)
finally:
    # Always detach the forward hook, even if construction or the comparison fails.
    hook.remove()

In [None]:
#export 
class DynamicUnet(SequentialEx):
    "Create a U-Net from a given architecture."
    def __init__(self, encoder, n_classes, img_size, blur=False, blur_final=True, self_attention=False,
                 y_range=None, last_cross=True, bottle=False, act_cls=defaults.activation,
                 init=nn.init.kaiming_normal_, norm_type=NormType.Batch, **kwargs):
        # `encoder` must support indexing (`encoder[i]`) so its layers can be hooked.
        # `img_size` is the input spatial size used for the dummy forward passes below.
        # Extra `kwargs` go to the middle/final `ConvLayer`s, every `UnetBlock` and the `ResBlock`.
        imsize = img_size
        sizes = model_sizes(encoder, size=imsize)
        # Hook the last layer before each size change, deepest first, to feed the skip connections.
        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        # Dummy pass: fills the hooks and yields the encoder output used to size the decoder.
        x = dummy_eval(encoder, imsize).detach()

        ni = sizes[-1][1]
        # `.eval()` presumably keeps the dummy passes from updating normalization statistics.
        middle_conv = nn.Sequential(ConvLayer(ni, ni*2, act_cls=act_cls, norm_type=norm_type, **kwargs),
                                    ConvLayer(ni*2, ni, act_cls=act_cls, norm_type=norm_type, **kwargs)).eval()
        x = middle_conv(x)
        # NOTE(review): this activation is always nn.ReLU, regardless of `act_cls`.
        layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]

        for i,idx in enumerate(sz_chg_idxs):
            not_final = i!=len(sz_chg_idxs)-1
            up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
            do_blur = blur and (not_final or blur_final)
            # Self-attention is only used in the third-from-last decoder block.
            sa = self_attention and (i==len(sz_chg_idxs)-3)
            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,
                                   act_cls=act_cls, init=init, norm_type=norm_type, **kwargs).eval()
            layers.append(unet_block)
            # Run the block on the dummy activation to get the input channels of the next block.
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sizes[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))
            # Track the extra upsampling on the dummy activation -- but only when the layer is really
            # added; otherwise the size check below would see a size the model never produces and
            # append a spurious resize layer.
            x = PixelShuffle_ICNR(ni)(x)
        if imsize != x.shape[-2:]: layers.append(Lambda(lambda x: F.interpolate(x, imsize, mode='nearest')))
        if last_cross:
            # Dense merge: concatenates `in_channels(encoder)` extra channels (the model input) before the ResBlock.
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(ResBlock(1, ni, ni//2 if bottle else ni, act_cls=act_cls, norm_type=norm_type, **kwargs))
        layers += [ConvLayer(ni, n_classes, ks=1, act_cls=None, norm_type=norm_type, **kwargs)]
        # Re-initialize the middle conv and the layer just before the final 1x1 conv.
        apply_init(nn.Sequential(layers[3], layers[-2]), init)
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        # Remove the encoder hooks when the model is garbage-collected (if they were ever created).
        if hasattr(self, "sfs"): self.sfs.remove()

In [None]:
from local.vision.models import resnet34
# Use every child of resnet34 except the last two (presumably pooling + classifier head) as the encoder.
m = nn.Sequential(*list(resnet34().children())[:-2])
tst = DynamicUnet(m, 5, (128,128), norm_type=None)
x = torch.randn(2, 3, 128, 128)
# Output keeps the input's batch and spatial size, with one channel per class.
test_eq(tst(x).shape, [2, 5, 128, 128])

In [None]:
#hide
#slow
#Check against v1 implementation
#TODO: remove when v2 is the official version
from fastai.vision.models.unet import DynamicUnet as OldDynamicUnet

encoder = nn.Sequential(*list(resnet34(True).children())[:-2])
# Both models wrap the *same* encoder instance.
mod1 = DynamicUnet(encoder, 5, (128,128), norm_type=None)
mod2 = OldDynamicUnet(encoder, 5, (128,128), norm_type=None)
# Copy the v2 weights into v1 positionally (relies on both state dicts listing entries in the same order).
sd1,sd2 = mod1.state_dict(),mod2.state_dict()
for k1,k2 in zip(sd1.keys(), sd2.keys()): sd2[k2] = sd1[k1].clone()
mod2.load_state_dict(sd2)
x = torch.randn(2, 3, 128, 128)

#ResBlock in v2 have the ReLU after the merge so the full outputs differ:
#only check that both full models run and agree on the output shape
test_eq(mod1(x.clone()).shape, mod2(x.clone()).shape)

#The first 10 layers (everything before the ResBlock in this configuration) must match
y1 = SequentialEx(*mod1.layers[:10])(x.clone())
y2 = SequentialEx(*mod2.layers[:10])(x.clone())
test_close(y1, y2, eps=1e-3)

## Export -

In [None]:
#hide
from local.notebook.export import notebook2script
# Convert every notebook (`all_fs=True`); each converted file is listed in the output below.
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 01a_utils.ipynb.
Converted 01b_dispatch.ipynb.
Converted 01c_transform.ipynb.
Converted 02_script.ipynb.
Converted 03_torch_core.ipynb.
Converted 03a_layers.ipynb.
Converted 04_dataloader.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_transforms.ipynb.
Converted 07_data_block.ipynb.
Converted 08_vision_core.ipynb.
Converted 09_vision_augment.ipynb.
Converted 10_pets_tutorial.ipynb.
Converted 11_vision_models_xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 13a_metrics.ipynb.
Converted 14_callback_schedule.ipynb.
Converted 14a_callback_data.ipynb.
Converted 15_callback_hook.ipynb.
Converted 15a_vision_models_unet.ipynb.
Converted 16_callback_progress.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 18_callback_fp16.ipynb.
Converted 19_callback_mixup.ipynb.
Converted 21_vision_learner.ipynb.
Converted 22_tutorial_imagenette.ipynb.
Converted 23_tutorial_transfer_learning.ipynb.
Conve