{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "# WDSR 4x super-resolution on an ImageNet subset\n", "\n", "Trains a WDSR-style network first with a pixel (MSE) loss, then with a VGG16-BN feature (perceptual) loss, and compares the two visually." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import fastai\n",
"from fastai import * # Quick access to most common functionality\n",
"from fastai.vision import * # Quick access to computer vision functionality\n",
"from fastai.callbacks import *\n",
"from torchvision.models import vgg16_bn" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Root of the ImageNet CLS-LOC data; adjust for your machine.\n",
"PATH = Path('/DATA/kaggle/imgnetloc/ILSVRC/Data/CLS-LOC/')\n",
"PATH_TRN = PATH/'train'\n",
"\n",
"# Default sizes; each training stage below redefines sz_lr / bs / sz_hr.\n",
"sz_lr=224//4\n",
"scale,bs = 4,24\n",
"sz_hr = sz_lr*scale" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Collect every file under PATH_TRN/<sub-folder>/.\n",
"classes = list(PATH_TRN.iterdir())\n",
"fnames_full = []\n",
"for class_folder in progress_bar(classes):\n",
"    for fname in class_folder.iterdir():\n",
"        fnames_full.append(fname)\n",
"\n",
"# Keep a reproducible random ~2% subset for fast experiments (set keep_pct = 1. for everything).\n",
"np.random.seed(42)\n",
"keep_pct = 0.02\n",
"keeps = np.random.rand(len(fnames_full)) < keep_pct\n",
"# Plain np.array: building from a list always copies, and NumPy >= 2 raises on copy=False here.\n",
"image_fns = np.array(fnames_full)[keeps]\n",
"len(image_fns)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"valid_pct=0.1\n",
"# Input and target are the same file; get_data resizes them to sz_lr and sz_hr respectively.\n",
"src = (ImageToImageList(image_fns)\n",
"       .random_split_by_pct(valid_pct, seed=42)\n",
"       .label_from_func(lambda x: x))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"def get_data(bs, sz_lr, sz_hr, num_workers=12, **kwargs):\n",
"    '''Build a DataBunch from the global `src`: inputs resized to sz_lr, targets to sz_hr.\n",
"\n",
"    Extra kwargs are forwarded to `databunch`. The data is deliberately left unnormalized,\n",
"    because WDSR normalizes with imagenet_stats inside its own forward pass.\n",
"    '''\n",
"    tfms = get_transforms(flip_vert=True)\n",
"    data = (src\n",
"            .transform(tfms, size=sz_lr)\n",
"            .transform_labels(size=sz_hr)\n",
"            .databunch(bs=bs, num_workers=num_workers, **kwargs))\n",
"    return data\n",
"\n",
"sz_lr = 288//4\n",
"scale,bs = 4,24\n",
"sz_hr = sz_lr*scale\n",
"data = get_data(bs, sz_lr, sz_hr)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"data.train_ds[0:3]" ] },
{ "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [
"x,y = data.dl().one_batch()\n",
"x.shape, y.shape" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Model\n", "\n", "A WDSR-style network: a weight-normalized conv head, a stack of residual blocks with a 6x channel expansion, a conv + PixelShuffle tail, and a 5x5 conv + PixelShuffle skip branch." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"class Block(nn.Module):\n",
"    '''Residual block.\n",
"\n",
"    The branch is: 1x1 conv to n_feats*6 channels, `act`, 1x1 conv to int(n_feats*0.8),\n",
"    then a `kernel_size` conv back to n_feats. The branch output is scaled by `res_scale`\n",
"    and added to the input. `wn` wraps each conv (e.g. weight normalization).\n",
"    '''\n",
"    def __init__(self, n_feats, kernel_size, wn, act=nn.ReLU(True), res_scale=1):\n",
"        super().__init__()\n",
"        self.res_scale = res_scale\n",
"        expand = 6\n",
"        linear = 0.8\n",
"        # Keep this exact layer order: state_dict keys (body.0, body.2, body.3) must match checkpoints saved with learn.save.\n",
"        body = [\n",
"            wn(nn.Conv2d(n_feats, n_feats*expand, 1, padding=1//2)),\n",
"            act,\n",
"            wn(nn.Conv2d(n_feats*expand, int(n_feats*linear), 1, padding=1//2)),\n",
"            wn(nn.Conv2d(int(n_feats*linear), n_feats, kernel_size, padding=kernel_size//2)),\n",
"        ]\n",
"        self.body = nn.Sequential(*body)\n",
"\n",
"    def forward(self, x):\n",
"        res = self.body(x) * self.res_scale\n",
"        res += x\n",
"        return res" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"class WDSR(nn.Module):\n",
"    '''Super-resolution network that upscales spatially by `scale`.\n",
"\n",
"    The input is normalized with imagenet_stats. It is passed through a conv head,\n",
"    `n_resblocks` Blocks, and a conv + PixelShuffle tail. A 5x5 conv + PixelShuffle skip\n",
"    branch computed from the normalized input is added, and the sum is mapped back\n",
"    with the same mean/std.\n",
"    '''\n",
"    def __init__(self, scale, n_resblocks, n_feats, res_scale, n_colors=3):\n",
"        super().__init__()\n",
"        kernel_size = 3\n",
"        act = nn.ReLU(True)\n",
"        wn = lambda x: torch.nn.utils.weight_norm(x)  # use `lambda x: x` to disable weight norm\n",
"\n",
"        # Plain tensors, moved to the input's device/dtype in forward. Variable has been a no-op\n",
"        # wrapper since PyTorch 0.4. Registering these as buffers would add state_dict keys\n",
"        # and break loading checkpoints saved by earlier runs.\n",
"        mean, std = imagenet_stats\n",
"        self.rgb_mean = torch.as_tensor(mean, dtype=torch.float32).view(1, n_colors, 1, 1)\n",
"        self.rgb_std = torch.as_tensor(std, dtype=torch.float32).view(1, n_colors, 1, 1)\n",
"\n",
"        # Modules are created and registered in the order head, body, tail, skip, pad.\n",
"        self.head = nn.Sequential(wn(nn.Conv2d(n_colors, n_feats, 3, padding=3//2)))\n",
"        self.body = nn.Sequential(*[\n",
"            Block(n_feats, kernel_size, act=act, res_scale=res_scale, wn=wn)\n",
"            for _ in range(n_resblocks)])\n",
"        out_feats = scale*scale*n_colors\n",
"        self.tail = nn.Sequential(\n",
"            wn(nn.Conv2d(n_feats, out_feats, 3, padding=3//2)),\n",
"            nn.PixelShuffle(scale))\n",
"        self.skip = nn.Sequential(\n",
"            wn(nn.Conv2d(n_colors, out_feats, 5, padding=5//2)),\n",
"            nn.PixelShuffle(scale))\n",
"        # Currently unused (see forward); it has no parameters, so it does not affect state_dict.\n",
"        self.pad = nn.Sequential(torch.nn.ReplicationPad2d(5//2))\n",
"\n",
"    def forward(self, x):\n",
"        mean = self.rgb_mean.to(x)\n",
"        std = self.rgb_std.to(x)\n",
"        x = (x - mean) / std\n",
"        # if not self.training:\n",
"        #     x = self.pad(x)\n",
"        s = self.skip(x)\n",
"        x = self.tail(self.body(self.head(x)))\n",
"        x += s\n",
"        return x*std + mean" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"scale=4\n",
"n_resblocks=8\n",
"n_feats=64\n",
"res_scale= 1.\n",
"model = WDSR(scale, n_resblocks, n_feats, res_scale).cuda()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Pixel-loss (MSE) training" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"sz_lr = 72\n",
"scale,bs = 4,24\n",
"sz_hr = sz_lr*scale\n",
"data = get_data(bs, sz_lr, sz_hr)\n",
"#loss = CropTargetForLoss(F.l1_loss)\n",
"loss = F.mse_loss\n",
"learn = Learner(data, nn.DataParallel(model), loss_func=loss)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# learn.lr_find(num_it=500, start_lr=1e-5, end_lr=1000)\n",
"# learn.recorder.plot()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"#learn.load('pixel')\n",
"lr = 1e-3\n",
"learn.fit_one_cycle(1, lr)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"learn.save('pixel')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"learn.fit_one_cycle(1, lr/5)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"learn.save('pixel')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"sz_lr = 512\n",
"scale,bs = 4,4\n",
"sz_hr = sz_lr*scale\n",
"data = get_data(bs, sz_lr, sz_hr)\n",
"#loss = CropTargetForLoss(F.l1_loss)\n",
"loss = F.mse_loss\n",
"learn = Learner(data, nn.DataParallel(model), loss_func=loss)\n",
"learn = learn.load('pixel')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"learn.fit_one_cycle(1, lr)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"learn.save('pixel')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Feature (perceptual) loss with a frozen VGG16-BN" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# `.features` is already the conv nn.Sequential; a second `.features` would raise AttributeError.\n",
"m_vgg_feat = vgg16_bn(True).features.cuda().eval()\n",
"requires_grad(m_vgg_feat, False)\n",
"\n",
"# Index of the layer immediately before each MaxPool2d.\n",
"blocks = [i-1 for i,o in enumerate(children(m_vgg_feat))\n",
"          if isinstance(o,nn.MaxPool2d)]\n",
"blocks, [m_vgg_feat[i] for i in blocks]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"class FeatureLoss(nn.Module):\n",
"    '''Returns 100 * (L1(input, target)/100 + sum_k w_k * MSE(feat_k(input), feat_k(target))).\n",
"\n",
"    feat_k are the outputs of m_feat[layer_ids[k]], captured with forward hooks.\n",
"    Per-term values (detached) are stored in `self.metrics` under `self.metric_names`\n",
"    so that ReportLossMetrics can report them.\n",
"    '''\n",
"    def __init__(self, m_feat, layer_ids, layer_wgts):\n",
"        super().__init__()\n",
"        # zip() in forward would silently drop layers that have no weight.\n",
"        assert len(layer_wgts) >= len(layer_ids), 'need one weight per feature layer'\n",
"        self.m_feat = m_feat\n",
"        self.loss_features = [self.m_feat[i] for i in layer_ids]\n",
"        # detach=False so gradients flow from the feature terms back to the prediction.\n",
"        self.hooks = hook_outputs(self.loss_features, detach=False)\n",
"        self.wgts = layer_wgts\n",
"        self.metric_names = ['L1'] + [f'feat_{i}' for i in range(len(layer_ids))]\n",
"        self.metrics = {name: 0. for name in self.metric_names}\n",
"\n",
"    def make_feature(self, bs, o, clone=False):\n",
"        feat = o.view(bs, -1)\n",
"        if clone: feat = feat.clone()\n",
"        return feat\n",
"\n",
"    def make_features(self, x, clone=False):\n",
"        # Running m_feat overwrites hooks.stored; clone when the result must survive the next call.\n",
"        bs = x.shape[0]\n",
"        self.m_feat(x)\n",
"        return [self.make_feature(bs, o, clone) for o in self.hooks.stored]\n",
"\n",
"    def forward(self, input, target):\n",
"        # NOTE(review): images reach VGG without ImageNet normalization (get_data does not\n",
"        # normalize and WDSR denormalizes its output); confirm this is intended.\n",
"        out_feat = self.make_features(target, clone=True)\n",
"        in_feat = self.make_features(input)\n",
"        l1_loss = F.l1_loss(input,target)/100\n",
"        self.feat_losses = [l1_loss]\n",
"        self.feat_losses += [F.mse_loss(f_in, f_out)*w\n",
"                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]\n",
"        # Detached copies for reporting, so the metrics dict does not keep the autograd graph alive.\n",
"        for name, term in zip(self.metric_names, self.feat_losses):\n",
"            self.metrics[name] = term.detach()\n",
"        self.loss = sum(self.feat_losses)\n",
"        return self.loss*100\n",
"\n",
"    def __del__(self):\n",
"        # Hooks live on the shared m_feat; remove them so re-creating the loss does not stack hooks.\n",
"        hooks = getattr(self, 'hooks', None)\n",
"        if hooks is not None: hooks.remove()\n",
"\n",
"class ReportLossMetrics(LearnerCallback):\n",
"    '''Average the loss function's per-term `metrics` over each validation epoch\n",
"    (weighted by batch size) and add them to the recorder.\n",
"    '''\n",
"    _order = -20  # Needs to run before the recorder\n",
"\n",
"    def on_train_begin(self, **kwargs):\n",
"        self.metric_names = self.learn.loss_func.metric_names\n",
"        self.learn.recorder.add_metric_names(self.metric_names)\n",
"\n",
"    def on_epoch_begin(self, **kwargs):\n",
"        self.metrics = {name: 0. for name in self.metric_names}\n",
"        self.nums = 0\n",
"\n",
"    def on_batch_end(self, last_target, train, **kwargs):\n",
"        if train: return\n",
"        bs = last_target.size(0)\n",
"        for name in self.metric_names:\n",
"            self.metrics[name] += bs * self.learn.loss_func.metrics[name]\n",
"        self.nums += bs\n",
"\n",
"    def on_epoch_end(self, **kwargs):\n",
"        if self.nums:\n",
"            self.learn.recorder.add_metrics([self.metrics[name]/self.nums for name in self.metric_names])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"sz_lr = 200\n",
"scale,bs = 4,4\n",
"sz_hr = sz_lr*scale\n",
"data = get_data(bs, sz_lr, sz_hr)\n",
"# blocks[:2] selects two layers, so the third weight (0.30) is unused; use blocks[:3] to include it.\n",
"feat_loss = FeatureLoss(m_vgg_feat, blocks[:2], [0.25,0.45,0.30])\n",
"learn = Learner(data, nn.DataParallel(model), loss_func=feat_loss, callback_fns=[ReportLossMetrics])\n",
"#learn = learn.load('pixel')" ] },
{ "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# learn.lr_find()\n",
"# learn.recorder.plot()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"lr=1e-3\n",
"learn.fit_one_cycle(1, lr)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"learn.save('enhance_feat')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"learn = learn.load('enhance_feat')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Visual comparison: input | prediction | target" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"def make_img(x, idx=0):\n",
"    '''Return item `idx` of batch `x` as a fastai Image, moved to CPU and clamped to [0, 1].'''\n",
"    return Image(x[idx].cpu().clamp(0, 1))\n",
"\n",
"def plot_x_y_pred(x, pred, y, figsize):\n",
"    '''Show one row per item with columns input | prediction | target. `figsize` is in inches.'''\n",
"    rows = x.shape[0]\n",
"    # squeeze=False keeps `axs` 2-D, so a single-row batch still indexes as axs[i, j].\n",
"    fig, axs = plt.subplots(rows, 3, figsize=figsize, squeeze=False)\n",
"    for i in range(rows):\n",
"        for j, batch in enumerate((x, pred, y)):\n",
"            make_img(batch, i).show(ax=axs[i, j])\n",
"    plt.tight_layout()\n",
"\n",
"def plot_some(learn, do_denorm=True, figsize=None):\n",
"    '''Plot input, prediction and target for the first 4 items of one validation batch.\n",
"\n",
"    Runs `learn.model` (not a global) in eval mode without gradients and restores its mode\n",
"    afterwards. `do_denorm` is kept for backward compatibility and is currently unused.\n",
"    `figsize` is in inches; by default each panel is 4x4 inches.\n",
"    '''\n",
"    x, y = next(iter(learn.data.valid_dl))\n",
"    x, y = x[:4], y[:4]\n",
"    net = learn.model\n",
"    was_training = net.training\n",
"    net.eval()\n",
"    with torch.no_grad():\n",
"        y_pred = net(x)\n",
"    net.train(was_training)\n",
"    if figsize is None: figsize = (4*3, 4*x.shape[0])\n",
"    plot_x_y_pred(x, y_pred, y, figsize=figsize)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"sz_lr = 64\n",
"scale,bs = 4,24\n",
"sz_hr = sz_lr*scale\n",
"data = get_data(bs, sz_lr, sz_hr)\n",
"loss = F.mse_loss\n",
"learn = Learner(data, nn.DataParallel(model), loss_func=loss)\n",
"learn = learn.load('enhance_feat')\n",
"\n",
"# figsize is in inches; the default gives 4-inch panels (256 would mean a 256-inch figure).\n",
"plot_some(learn)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"learn = learn.load('pixel')\n",
"plot_some(learn)" ] }
],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } },
"nbformat": 4, "nbformat_minor": 2 }