{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Super resolution on ImageNet" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import fastai\n", "from fastai.vision import *\n", "from fastai.callbacks import *\n", "from fastai.utils.mem import *" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torchvision.models import vgg16_bn" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.cuda.set_device(0)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "path = Path('data/imagenet')\n",
   "path_hr = path/'train'\n",
   "path_lr = path/'small-64/train'\n",
   "path_mr = path/'small-256/train'\n",
   "\n",
   "# note: this notebook relies on models created by lesson7-superres.ipynb\n",
   "path_pets = untar_data(URLs.PETS)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "il = ImageList.from_folder(path_hr)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "def resize_one(fn, i, path, size):\n",
   "    \"Save an RGB copy of image `fn` under `path` (same relative layout as under `path_hr`), downscaled via `resize_to(img, size, use_min=True)`.\"\n",
   "    # `i` is the item index supplied by fastai's `parallel`; it is not needed here\n",
   "    dest = path/fn.relative_to(path_hr)\n",
   "    dest.parent.mkdir(parents=True, exist_ok=True)\n",
   "    # close the source file promptly: this runs over every training image across parallel workers\n",
   "    with PIL.Image.open(fn) as img:\n",
   "        targ_sz = resize_to(img, size, use_min=True)\n",
   "        small = img.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')\n",
   "    small.save(dest, quality=60)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "assert path.exists(), f\"need imagenet dataset @ {path}\"\n",
   "# create smaller image sets the first time this nb is run\n",
   "# (an interrupted run leaves a partial folder that is then skipped: delete it to regenerate)\n",
   "sets = [(path_lr, 64), (path_mr, 256)]\n",
   "# loop variable is `sz`, not `size`, so the global `size` used for training is never rebound here\n",
   "for p,sz in sets:\n",
   "    if not p.exists():\n",
   "        print(f\"resizing to {sz} into {p}\")\n",
   "        parallel(partial(resize_one, path=p, size=sz), il.items)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "using bs=8, size=256, have 8109MB of GPU RAM free\n" ] } ],
   "source": [
   "free = gpu_mem_get_free_no_cache()\n",
   "# the batch size that fits depends on the available GPU RAM (image size is 256 either way)\n",
   "if free > 8200: bs,size=16,256\n",
   "else: bs,size=8,256\n",
   "print(f\"using bs={bs}, size={size}, have {free}MB of GPU RAM free\")\n",
   "\n",
   "arch = models.resnet34\n",
   "# set to a fraction (e.g. 0.1) to train on a random subset of the images; False uses all of them\n",
   "sample = False\n",
   "\n",
   "# note: get_data below builds its own transforms (max_zoom=2.) instead of using tfms\n",
   "tfms = get_transforms()" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "# load from disk and subsample in the same cell, so re-running never filters an already-filtered list\n",
   "src_items = ImageImageList.from_folder(path_lr)\n",
   "if sample: src_items = src_items.filter_by_rand(sample, seed=42)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "# write to a new name so re-running this cell never tries to split an already-split list\n",
   "src = src_items.split_by_rand_pct(0.1, seed=42)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "def get_data(bs,size):\n",
   "    \"Build a DataBunch pairing each low-res image in `path_lr` with the original at the same relative path in `path_hr`.\"\n",
   "    data = (src.label_from_func(lambda x: path_hr/x.relative_to(path_lr))\n",
   "            # tfm_y=True: the target gets the same transforms/size as the input\n",
   "            .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)\n",
   "            # do_y=True: targets are normalized with imagenet_stats as well\n",
   "            .databunch(bs=bs).normalize(imagenet_stats, do_y=True))\n",
   "\n",
   "    # 3 output channels (RGB); presumably read by unet_learner to size its final layer\n",
   "    data.c = 3\n",
   "    return data" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = get_data(bs,size)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature loss" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "def gram_matrix(x):\n",
   "    \"Per-sample Gram matrix of feature maps `x` of shape (n,c,h,w), divided by c*h*w; returns shape (n,c,c).\"\n",
   "    n,c,h,w = x.size()\n",
   "    x = x.view(n, c, -1)\n",
   "    return (x @ x.transpose(1,2))/(c*h*w)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "# pretrained VGG16-BN convolutional layers, frozen, used only as a feature extractor for the loss\n",
   "vgg_m = vgg16_bn(True).features.cuda().eval()\n",
   "requires_grad(vgg_m, False)\n",
   "# index of the layer just before each MaxPool2d, i.e. the last activation of each VGG block\n",
   "blocks = [i-1 for i,o in enumerate(children(vgg_m)) if isinstance(o,nn.MaxPool2d)]" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [
   "base_loss = F.l1_loss\n",
   "\n",
   "class FeatureLoss(nn.Module):\n",
   "    \"Perceptual loss: pixel L1 + per-layer feature L1 + per-layer Gram-matrix (style) L1, from hooked activations of `m_feat`.\"\n",
   "    def __init__(self, m_feat, layer_ids, layer_wgts, gram_wgt=5e3):\n",
   "        super().__init__()\n",
   "        self.m_feat = m_feat\n",
   "        self.loss_features = [self.m_feat[i] for i in layer_ids]\n",
   "        # detach=False: gradients must flow back through the activations computed from `input`\n",
   "        self.hooks = hook_outputs(self.loss_features, detach=False)\n",
   "        self.wgts = layer_wgts\n",
   "        # extra scale on the Gram terms, applied on top of the squared layer weight\n",
   "        self.gram_wgt = gram_wgt\n",
   "        self.metric_names = (['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))]\n",
   "                             + [f'gram_{i}' for i in range(len(layer_ids))])\n",
   "\n",
   "    def make_features(self, x, clone=False):\n",
   "        \"Run `x` through `m_feat` and return the activations captured by the hooks.\"\n",
   "        self.m_feat(x)\n",
   "        # clone when the result must survive the next forward pass (presumably the hooks overwrite `stored` each time)\n",
   "        return [(o.clone() if clone else o) for o in self.hooks.stored]\n",
   "\n",
   "    def forward(self, input, target):\n",
   "        out_feat = self.make_features(target, clone=True)\n",
   "        in_feat = self.make_features(input)\n",
   "        self.feat_losses = [base_loss(input,target)]\n",
   "        self.feat_losses += [base_loss(f_in, f_out)*w\n",
   "                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]\n",
   "        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * self.gram_wgt\n",
   "                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]\n",
   "        # per-component values, reported as separate columns by the LossMetrics callback\n",
   "        self.metrics = dict(zip(self.metric_names, self.feat_losses))\n",
   "        return sum(self.feat_losses)\n",
   "\n",
   "    def __del__(self):\n",
   "        # remove the hooks from m_feat; tolerate __init__ having failed before they were registered\n",
   "        hooks = getattr(self, 'hooks', None)\n",
   "        if hooks is not None: hooks.remove()" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "feat_loss = FeatureLoss(vgg_m, blocks[2:5], [5,15,2])" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Train" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
   "wd = 1e-3\n",
   "learn = unet_learner(data, arch, wd=wd, loss_func=feat_loss, callback_fns=LossMetrics, blur=True, norm_type=NormType.Weight)\n",
   "gc.collect();" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.unfreeze()" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# relies 
on first running lesson7-superres.ipynb which created the following model\n", "learn.load((path_pets/'small-96'/'models'/'2b').absolute());" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 17:52
epoch | \n", "train_loss | \n", "valid_loss | \n", "pixel | \n", "feat_0 | \n", "feat_1 | \n", "feat_2 | \n", "gram_0 | \n", "gram_1 | \n", "gram_2 | \n", "
---|---|---|---|---|---|---|---|---|---|
1 | \n", "2.347123 | \n", "2.385141 | \n", "0.229566 | \n", "0.293816 | \n", "0.322328 | \n", "0.146045 | \n", "0.460049 | \n", "0.638133 | \n", "0.295204 | \n", "