# %% Imports
import fastai
from fastai import *          # Quick access to most common functionality
from fastai.vision import *   # Quick access to computer vision functionality
from fastai.callbacks import *
from torchvision.models import vgg16_bn

# %% Configuration
PATH = Path('/DATA/kaggle/imgnetloc/ILSVRC/Data/CLS-LOC/')
PATH_TRN = PATH/'train'

sz_lr = 72             # size passed to the input (low-res) transforms
scale, bs = 4, 24      # upscaling factor, batch size
sz_hr = sz_lr*scale    # size passed to the label (high-res) transforms

# %% Collect the training files and keep a seeded random subset
# Directory listing order is filesystem-dependent, so both levels are sorted;
# otherwise the fixed seed below would select a different subset on each machine.
classes = sorted(p for p in PATH_TRN.iterdir() if p.is_dir())
fnames_full = []
for class_folder in progress_bar(classes):
    fnames_full.extend(sorted(class_folder.iterdir()))

np.random.seed(42)
keep_pct = 0.02  # fraction of files to keep; set to 1. to use everything
keeps = np.random.rand(len(fnames_full)) < keep_pct
# np.asarray instead of np.array(..., copy=False): the latter raises on NumPy >= 2
# whenever a copy is unavoidable, which is always the case for a Python list.
image_fns = np.asarray(fnames_full)[keeps]
len(image_fns)

# %% Item list: every image is its own label (image-to-image task)
valid_pct = 0.1
src = (ImageToImageList(image_fns)
       .random_split_by_pct(valid_pct, seed=42)
       .label_from_func(lambda x: x))


# %% Data pipeline
def get_data(bs, sz_lr, sz_hr, num_workers=12, **kwargs):
    """Build the data bunch of (low-res input, high-res target) image pairs.

    Reads the module-level item list `src`.

    Args:
        bs: batch size.
        sz_lr: size applied to the inputs.
        sz_hr: size applied to the labels.
        num_workers: forwarded to `databunch`.
        **kwargs: forwarded to `databunch`.

    Returns:
        The data bunch, with inputs and labels (`do_y=True`) normalised using
        `imagenet_stats`.
    """
    # Two transform lists are handed to `transform`; by fastai convention the
    # first is for the training set and the second for the validation set.
    tfms = [[dihedral_affine(p=0.75), crop_pad(row_pct=0.5, col_pct=0.5)],
            [crop_pad(row_pct=0.5, col_pct=0.5)]]
    return (src
            .transform(tfms, size=sz_lr)
            .transform_labels(size=sz_hr)
            .databunch(bs=bs, num_workers=num_workers, **kwargs)
            .normalize(imagenet_stats, do_y=True))


# %% Sizes come from the configuration cell above
data = get_data(bs, sz_lr, sz_hr)

# %%
data.train_ds[0:3]

# %%
x, y = data.dl().one_batch()
x.shape, y.shape


# %%
def make_img(x, idx=0):
    """Return item `idx` of the normalised batch `x` as an `Image`.

    The batch is moved to the CPU, de-normalised with the module-level `data`
    and clamped to [0, 1] before the item is selected.
    """
    return Image(torch.clamp(data.denorm(x.cpu()), 0, 1)[idx])


# %%
idx = 5
x_img = make_img(x, idx)
y_img = make_img(y, idx)
x_img.show(), y_img.show()

# %% [markdown]
# ## Model

# %%
# Short alias for weight normalisation, applied to every convolution below.
wn = torch.nn.utils.weight_norm


# %%
def conv(ni, nf, kernel_size=3, actn=True):
    """Weight-normalised `Conv2d(ni, nf)` with `kernel_size//2` padding.

    An in-place ReLU follows the convolution when `actn` is true. The result is
    always wrapped in an `nn.Sequential`, so the convolution itself is element 0.
    """
    layers = [wn(nn.Conv2d(ni, nf, kernel_size, padding=kernel_size//2))]
    if actn:
        layers.append(nn.ReLU(True))
    return nn.Sequential(*layers)


# %%
class ResSequential(nn.Module):
    """Residual wrapper: `forward(x)` returns `x + m(x) * res_scale`."""

    def __init__(self, layers, res_scale=1.0):
        super().__init__()
        self.res_scale = res_scale
        self.m = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.m(x) * self.res_scale


# %%
def res_block(nf, res_scale=0.1):
    """Residual block of two `nf`->`nf` convs (ReLU after the first only).

    `res_scale` multiplies the residual branch; the default keeps the value
    that used to be hard-coded.
    """
    return ResSequential(
        [conv(nf, nf), conv(nf, nf, actn=False)],
        res_scale)


# %%
def upsample(ni, nf, scale):
    """Stack of `conv -> PixelShuffle(2)` stages, one per factor of two in `scale`.

    Each stage's conv produces `nf*4` channels, which `PixelShuffle(2)` folds
    into `nf` channels at twice the resolution. The stage count is
    `int(log2(scale))`, so a `scale` that is not a power of two is rounded down.
    """
    layers = []
    # math.log2 avoids relying on the rounding of the quotient log(scale)/log(2).
    for _ in range(int(math.log2(scale))):
        layers += [conv(ni, nf*4), nn.PixelShuffle(2)]
        ni = nf  # the next stage receives the nf channels left by PixelShuffle
    return nn.Sequential(*layers)
# %% Super-resolution network
class SrResnet(nn.Module):
    """Residual super-resolution network built from `conv`, `res_block` and `upsample`.

    Layout of `self.features`: one 3->nf conv, `n_res` residual blocks, one
    nf->nf conv, the `upsample` stack and a final nf->3 conv without activation.
    With the default `n_res=8` the upsample stack is `features[10]` and the
    final conv is `features[11]`.

    Args:
        nf: number of feature channels (previously ignored and fixed at 64).
        scale: upscaling factor handed to `upsample`.
        n_res: number of residual blocks.
    """

    def __init__(self, nf, scale, n_res=8):
        super().__init__()
        features = [conv(3, nf)]
        features += [res_block(nf) for _ in range(n_res)]
        features += [conv(nf, nf),
                     upsample(nf, nf, scale),
                     conv(nf, 3, actn=False)]
        self.features = nn.Sequential(*features)

    def forward(self, x): return self.features(x)


# %%
def icnr(x, scale, init=nn.init.kaiming_normal_):
    """Build an ICNR-style kernel with the same shape as the conv weight `x`.

    Only `x.shape` is read. A sub-kernel with `x.shape[0] // scale**2` output
    channels is initialised with `init`, and each of its filters is repeated
    `scale**2` times so that consecutive output channels share the same values.

    Args:
        x: tensor whose shape is (out_channels, in_channels, *kernel_dims).
        scale: factor of the pixel shuffle that consumes this convolution.
        init: initialiser applied to the zero-filled sub-kernel; its return
            value is used, so it must return the tensor.

    Returns:
        A new tensor of shape `x.shape`, created on the default device.

    Raises:
        ValueError: if `x.shape[0]` is not a multiple of `scale**2`.
    """
    n_out, n_in = x.shape[0], x.shape[1]
    n_copies = scale ** 2
    if n_out % n_copies:
        raise ValueError(
            f'out_channels ({n_out}) must be a multiple of scale**2 ({n_copies})')
    n_sub = n_out // n_copies

    subkernel = init(torch.zeros([n_sub] + list(x.shape[1:])))
    # (n_sub, n_in, ...) -> (n_in, n_sub, prod(kernel_dims))
    subkernel = subkernel.transpose(0, 1).contiguous().view(n_in, n_sub, -1)
    # Repeating along the last axis and re-viewing as (n_in, n_out, ...) places
    # the n_copies duplicates of each filter in consecutive output channels.
    kernel = subkernel.repeat(1, 1, n_copies)
    kernel = kernel.contiguous().view([n_in, n_out] + list(x.shape[2:]))
    return kernel.transpose(0, 1)


# %%
model = SrResnet(64, scale)

# %% Learner for pixel-wise (MSE) training
sz_lr = 288
scale, bs = 4, 12
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
learn = Learner(data, nn.DataParallel(model), loss_func=F.mse_loss)
# %%
learn.lr_find()
learn.recorder.plot()

# %% Resume from a previous run, if there is one
# On a fresh setup the checkpoint does not exist yet (it is first written by
# the `learn.save('pixel_v2')` cell below), so the load is skipped in that case.
if (learn.path/learn.model_dir/'pixel_v2.pth').exists():
    learn = learn.load('pixel_v2')

# %% Three one-cycle runs with decreasing learning rate
for lr in (1e-3, 1e-3, 2e-4):
    learn.fit_one_cycle(1, lr)

# %%
learn.save('pixel_v2')

# %% Reload the pixel-loss model on smaller images for inspection
sz_lr = 72
scale, bs = 4, 4
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
learn = Learner(data, nn.DataParallel(model), loss_func=F.mse_loss)

learn = learn.load('pixel_v2')


# %%
def plot_x_y_pred(x, pred, y, figsize):
    """Show input, prediction and target side by side, one row per batch item.

    Args:
        x, pred, y: normalised batches with the same number of items.
        figsize: matplotlib figure size, in inches.
    """
    rows = x.shape[0]
    # squeeze=False keeps `axs` two-dimensional even when there is a single row.
    fig, axs = plt.subplots(rows, 3, figsize=figsize, squeeze=False)
    for i in range(rows):
        for ax, batch in zip(axs[i], (x, pred, y)):
            make_img(batch, i).show(ax=ax)
    plt.tight_layout()


# %%
x, y = next(iter(learn.data.valid_dl))
with torch.no_grad():  # inference only: do not build an autograd graph
    y_pred = model(x)
x[0:3].shape

# %%
make_img(y_pred, 2).show(), make_img(y, 2).show()


# %%
def plot_some(learn, do_denorm=True):
    """Plot input / prediction / target for up to three validation images.

    Uses the first batch of `learn.data.valid_dl` and the model held by `learn`.

    Args:
        learn: learner providing the model and the validation loader.
        do_denorm: accepted for compatibility but has no effect; `make_img`
            always de-normalises.
    """
    x, y = next(iter(learn.data.valid_dl))
    with torch.no_grad():
        y_pred = learn.model(x)
    n_show = min(3, x.shape[0])
    # figsize is in inches, so it is derived from the grid layout rather than
    # from the image size in pixels.
    plot_x_y_pred(x[:n_show], y_pred[:n_show], y[:n_show],
                  figsize=(12, 4*n_show))


# %%
plot_some(learn)

# %% Frozen VGG-16 (batch-norm) feature extractor for the perceptual loss
m_vgg_feat = vgg16_bn(True).cuda().eval().features
requires_grad(m_vgg_feat, False)

# Index of the layer immediately preceding each max-pool.
blocks = [i-1 for i, o in enumerate(children(m_vgg_feat))
          if isinstance(o, nn.MaxPool2d)]
blocks, [m_vgg_feat[i] for i in blocks]


# %%
class FeatureLoss(nn.Module):
    """L1 pixel loss plus weighted MSE between hooked feature activations.

    Args:
        m_feat: indexable feature network; its layers at `layer_ids` are hooked.
        layer_ids: indices into `m_feat` of the layers to compare.
        layer_wgts: one weight per hooked layer (extra weights are ignored).

    Attributes set by `forward`:
        feat_losses: `[l1_loss, weighted feature loss per layer...]`.
        metrics: dict from `metric_names` to the detached loss terms.
        loss: sum of `feat_losses`.

    Raises:
        ValueError: if fewer weights than layers are given.
    """

    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        if len(layer_wgts) < len(layer_ids):
            raise ValueError(
                f'need one weight per layer: got {len(layer_wgts)} weights '
                f'for {len(layer_ids)} layers')
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: presumably so the stored activations stay attached to
        # autograd and the feature terms can backpropagate into `input`.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['L1'] + [f'feat_{i}' for i in range(len(layer_ids))]
        self.metrics = {name: 0. for name in self.metric_names}

    def __del__(self):
        # Release the forward hooks registered on m_feat's layers. `hooks` is
        # looked up defensively in case __init__ failed before creating it.
        hooks = self.__dict__.get('hooks')
        if hooks is not None:
            hooks.remove()

    def make_feature(self, bs, o, clone=False):
        """Flatten activation `o` to shape (bs, -1), optionally cloning it."""
        feat = o.view(bs, -1)
        if clone:
            feat = feat.clone()
        return feat

    def make_features(self, x, clone=False):
        """Run `x` through `m_feat` and return the flattened hooked activations."""
        bs = x.shape[0]
        self.m_feat(x)
        return [self.make_feature(bs, o, clone) for o in self.hooks.stored]

    def forward(self, input, target):
        # The target activations are cloned before the hooks are reused for
        # `input` (presumably to avoid aliasing the stored tensors).
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        l1_loss = F.l1_loss(input, target)/100  # pixel term, scaled by 1/100
        self.feat_losses = [l1_loss]
        self.feat_losses += [F.mse_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        # Detached copies: reporting must not keep the autograd graph alive.
        for name, term in zip(self.metric_names, self.feat_losses):
            self.metrics[name] = term.detach()
        self.loss = sum(self.feat_losses)
        return self.loss
# %%
class ReportLossMetrics(LearnerCallback):
    """Report the loss function's per-term metrics, averaged over validation batches.

    Expects `learn.loss_func` to expose `metric_names` (list of str) and
    `metrics` (dict keyed by those names, refreshed on every forward pass).
    """
    _order = -20  # Needs to run before the recorder

    def on_train_begin(self, **kwargs):
        self.metric_names = self.learn.loss_func.metric_names
        self.learn.recorder.add_metric_names(self.metric_names)

    def on_epoch_begin(self, **kwargs):
        # Reset the running sums and the sample count.
        self.metrics = {name: 0. for name in self.metric_names}
        self.nums = 0

    def on_batch_end(self, last_target, train, **kwargs):
        if train:
            return  # only validation batches are accumulated
        n_items = last_target.size(0)
        batch_metrics = self.learn.loss_func.metrics
        for name in self.metric_names:
            # Weight by batch size so the epoch figure is a per-sample mean.
            self.metrics[name] += n_items * batch_metrics[name]
        self.nums += n_items

    def on_epoch_end(self, **kwargs):
        if self.nums:
            metrics = [self.metrics[name]/self.nums for name in self.metric_names]
            self.learn.recorder.add_metrics(metrics)


# %% Feature-loss learner, warm-started from the pixel-loss checkpoint
sz_lr = 32
scale, bs = 4, 24
sz_hr = sz_lr*scale

data = get_data(bs, sz_lr, sz_hr)

feat_loss = FeatureLoss(m_vgg_feat, blocks[0:4], [0.25, 0.25, 0.25, 0.25])

model = SrResnet(64, scale)
learn = Learner(data, nn.DataParallel(model), loss_func=feat_loss,
                callback_fns=[ReportLossMetrics])
learn.load('pixel_v2')

model = learn.model.module  # unwrap DataParallel
nres = 8                    # number of residual blocks in SrResnet (its n_res default)


# %% ICNR re-initialisation of the pixel-shuffle convolutions
def icnr_init_(conv_layer, shuffle_scale):
    """Overwrite `conv_layer`'s kernel in place with an ICNR kernel.

    `torch.nn.utils.weight_norm` recomputes `conv_layer.weight` from
    `weight_g` and `weight_v` before every forward pass, so writing to
    `.weight` would be discarded. For weight-normalised layers the kernel is
    written to `weight_v` and its per-filter L2 norm to `weight_g`, which makes
    the effective weight equal to the ICNR kernel.

    Args:
        conv_layer: convolution module, with or without weight normalisation.
        shuffle_scale: factor of the PixelShuffle that follows this convolution.
    """
    kernel = icnr(conv_layer.weight, scale=shuffle_scale)
    with torch.no_grad():
        if hasattr(conv_layer, 'weight_v'):
            conv_layer.weight_v.copy_(kernel)
            filter_norms = kernel.contiguous().view(kernel.shape[0], -1).norm(dim=1)
            conv_layer.weight_g.copy_(filter_norms.view_as(conv_layer.weight_g))
        else:
            conv_layer.weight.copy_(kernel)


# features[nres+2] is the upsample stack: alternating conv blocks and
# PixelShuffle(2) layers. Each stage shuffles by 2, so ICNR is applied with a
# factor of 2 per stage, not with the overall `scale`.
for stage in model.features[nres+2]:
    if isinstance(stage, nn.PixelShuffle):
        continue
    conv_shuffle = stage[0]  # the Conv2d inside the conv() Sequential
    icnr_init_(conv_shuffle, shuffle_scale=2)
# %% Freeze everything except the upsample stack and the final conv
# 999 is larger than the number of layer groups, so (presumably) every group
# is frozen; the last two entries of `features` are then re-enabled.
learn.freeze_to(999)
for i in range(nres+2, nres+4):  # features[10] (upsample) and features[11] (final conv)
    requires_grad(model.features[i], True)

# %%
learn.lr_find()
learn.recorder.plot()

# %% Feature-loss training at a larger image size
# NOTE(review): this builds a fresh SrResnet, so the ICNR-initialised, partly
# frozen model prepared above is dropped without having been trained - confirm
# that this is intended.
sz_lr = 128
scale, bs = 4, 4
sz_hr = sz_lr*scale

data = get_data(bs, sz_lr, sz_hr)
model = SrResnet(64, scale)
learn = Learner(data, nn.DataParallel(model), loss_func=feat_loss,
                callback_fns=[ReportLossMetrics])
# 'enhance_feat_v2' is first written a few cells below; only resume from it
# when an earlier run has left the checkpoint on disk.
if (learn.path/learn.model_dir/'enhance_feat_v2.pth').exists():
    learn = learn.load('enhance_feat_v2')

# %%
lr = 1e-3
learn.unfreeze()
learn.fit_one_cycle(1, lr)

# %%
learn.save('enhance_feat_v2')

# %%
learn = learn.load('enhance_feat_v2')

# %%
learn.unfreeze()

# %%
learn.lr_find()
learn.recorder.plot()

# %% Final low-learning-rate pass
lr = 1e-5
learn.fit_one_cycle(1, lr)

# %%
# NOTE(review): this checkpoint is named 'enhance_feat2', while the comparison
# below loads 'enhance_feat_v2' (the weights from before this last pass) -
# confirm which one is meant to be plotted.
learn.save('enhance_feat2')

# %%
learn = learn.load('enhance_feat2')

# %% Compare feature-loss and pixel-loss checkpoints on the same validation batch
sz_lr = 72
scale, bs = 4, 4
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
for checkpoint in ('enhance_feat_v2', 'pixel_v2'):
    learn = Learner(data, nn.DataParallel(model), loss_func=F.mse_loss)
    learn = learn.load(checkpoint)
    plot_some(learn)