{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from exp.nb_11 import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Serializing the model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=2920)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = datasets.untar_data(datasets.URLs.IMAGEWOOF_160)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "size = 128\n", "bs = 64\n", "\n", "tfms = [make_rgb, RandomResizedCrop(size, scale=(0.35,1)), np_to_float, PilRandomFlip()]\n", "val_tfms = [make_rgb, CenterCrop(size), np_to_float]\n", "il = ImageList.from_files(path, tfms=tfms)\n", "sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))\n", "ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())\n", "ll.valid.x.tfms = val_tfms\n", "data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=8)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(il)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loss_func = LabelSmoothingCrossEntropy()\n", "opt_func = adam_opt(mom=0.9, mom_sqr=0.99, eps=1e-6, wd=1e-2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def sched_1cycle(lr, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):\n", " phases = create_phases(pct_start)\n", " sched_lr = combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))\n", " sched_mom = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))\n", " return [ParamScheduler('lr', sched_lr),\n", " ParamScheduler('mom', sched_mom)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr = 3e-3\n", "pct_start = 0.5\n", "cbsched = sched_1cycle(lr, pct_start)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(40, cbsched)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "st = learn.model.state_dict()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "type(st)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "', '.join(st.keys())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "st['10.bias']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mdl_path = path/'models'\n", "mdl_path.mkdir(exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's also possible to save the whole model, including the architecture, but it gets quite fiddly and we don't recommend it. Instead, just save the parameters, and recreate the model directly." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.save(st, mdl_path/'iw5')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3127)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pets = datasets.untar_data(datasets.URLs.PETS)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pets.ls()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pets_path = pets/'images'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "il = ImageList.from_files(pets_path, tfms=tfms)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "il" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def random_splitter(fn, p_valid): return random.random() < p_valid" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "random.seed(42)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sd" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n = il.items[0].name; n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "re.findall(r'^(.*)_\\d+.jpg$', n)[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def pet_labeler(fn): return re.findall(r'^(.*)_\\d+.jpg$', fn.name)[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "proc = CategoryProcessor()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ll = label_by_func(sd, pet_labeler, proc_y=proc)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "', '.join(proc.vocab)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ll.valid.x.tfms = val_tfms" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "c_out = len(proc.vocab)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = ll.to_databunch(bs, c_in=3, c_out=c_out, num_workers=8)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(5, cbsched)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Custom head" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3265)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "st = 
torch.load(mdl_path/'iw5')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m = learn.model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m.load_state_dict(st)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cut = next(i for i,o in enumerate(m.children()) if isinstance(o,nn.AdaptiveAvgPool2d))\n", "m_cut = m[:cut]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "xb,yb = get_batch(data.valid_dl, learn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pred = m_cut(xb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pred.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ni = pred.shape[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class AdaptiveConcatPool2d(nn.Module):\n", "    def __init__(self, sz=1):\n", "        super().__init__()\n", "        self.output_size = sz\n", "        self.ap = nn.AdaptiveAvgPool2d(sz)\n", "        self.mp = nn.AdaptiveMaxPool2d(sz)\n", "    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "nh = 40\n", "\n", "m_new = nn.Sequential(\n", "    m_cut, AdaptiveConcatPool2d(), Flatten(),\n", "    nn.Linear(ni*2, data.c_out))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.model = m_new" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(5, cbsched)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## adapt_model and gradual unfreezing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3483)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def adapt_model(learn, data):\n", "    cut = next(i for i,o in enumerate(learn.model.children())\n", "               if isinstance(o,nn.AdaptiveAvgPool2d))\n", "    m_cut = learn.model[:cut]\n", "    xb,yb = get_batch(data.valid_dl, learn)\n", "    pred = m_cut(xb)\n", "    ni = pred.shape[1]\n", "    m_new = nn.Sequential(\n", "        m_cut, AdaptiveConcatPool2d(), Flatten(),\n", "        nn.Linear(ni*2, data.c_out))\n", "    learn.model = m_new" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)\n", "learn.model.load_state_dict(torch.load(mdl_path/'iw5'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "adapt_model(learn, data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for p in learn.model[0].parameters(): p.requires_grad_(False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(3, sched_1cycle(1e-2, 0.5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for p in learn.model[0].parameters(): p.requires_grad_(True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(5, cbsched, reset_opt=True)" ] }, { "cell_type": "markdown", "metadata": {},
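"source": [ "The two stages above are the usual fine-tuning recipe: train only the new head while the body (`learn.model[0]`) is frozen, then unfreeze everything and train the whole model, with `reset_opt=True` so the optimizer is built afresh once all parameters are trainable again. A small helper, not part of the lesson code, to check how many parameter tensors are currently trainable:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Small helper (not lesson code): count parameter tensors that will receive gradients.\n", "def count_trainable(m):\n", "    return sum(1 for p in m.parameters() if p.requires_grad)\n", "\n", "count_trainable(learn.model[0]), count_trainable(learn.model)" ] }, { "cell_type": "markdown", "metadata": {},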
"source": [ "## Batch norm transfer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3567)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)\n", "learn.model.load_state_dict(torch.load(mdl_path/'iw5'))\n", "adapt_model(learn, data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def apply_mod(m, f):\n", " f(m)\n", " for l in m.children(): apply_mod(l, f)\n", "\n", "def set_grad(m, b):\n", " if isinstance(m, (nn.Linear,nn.BatchNorm2d)): return\n", " if hasattr(m, 'weight'):\n", " for p in m.parameters(): p.requires_grad_(b)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "apply_mod(learn.model, partial(set_grad, b=False))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(3, sched_1cycle(1e-2, 0.5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "apply_mod(learn.model, partial(set_grad, b=True))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(5, cbsched, reset_opt=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pytorch already has an `apply` method we can use:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.model.apply(partial(set_grad, b=False));" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Discriminative LR and param groups" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3799)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.model.load_state_dict(torch.load(mdl_path/'iw5'))\n", "adapt_model(learn, data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def bn_splitter(m):\n", " def _bn_splitter(l, g1, g2):\n", " if isinstance(l, nn.BatchNorm2d): g2 += l.parameters()\n", " elif hasattr(l, 'weight'): g1 += l.parameters()\n", " for ll in l.children(): _bn_splitter(ll, g1, g2)\n", " \n", " g1,g2 = [],[]\n", " _bn_splitter(m[0], g1, g2)\n", " \n", " g2 += m[1:].parameters()\n", " return g1,g2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a,b = bn_splitter(learn.model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(len(a)+len(b), len(list(m.parameters())))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Learner.ALL_CBS" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from types import SimpleNamespace\n", "cb_types = SimpleNamespace(**{o:o for o in Learner.ALL_CBS})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cb_types.after_backward" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class DebugCallback(Callback):\n", " 
"    _order = 999\n", "    def __init__(self, cb_name, f=None): self.cb_name,self.f = cb_name,f\n", "    def __call__(self, cb_name):\n", "        if cb_name==self.cb_name:\n", "            if self.f: self.f(self.run)\n", "            else: set_trace()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def sched_1cycle(lrs, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):\n", "    phases = create_phases(pct_start)\n", "    sched_lr = [combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))\n", "                for lr in lrs]\n", "    sched_mom = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))\n", "    return [ParamScheduler('lr', sched_lr),\n", "            ParamScheduler('mom', sched_mom)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "disc_lr_sched = sched_1cycle([0,3e-2], 0.5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = cnn_learner(xresnet18, data, loss_func, opt_func,\n", "                    c_out=10, norm=norm_imagenette, splitter=bn_splitter)\n", "\n", "learn.model.load_state_dict(torch.load(mdl_path/'iw5'))\n", "adapt_model(learn, data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _print_det(o): \n", "    print (len(o.opt.param_groups), o.opt.hypers)\n", "    raise CancelTrainException()\n", "\n", "learn.fit(1, disc_lr_sched + [DebugCallback(cb_types.after_batch, _print_det)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(3, disc_lr_sched)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "disc_lr_sched = sched_1cycle([1e-3,1e-2], 0.3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(5, disc_lr_sched)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!./notebook2script.py 11a_transfer_learning.ipynb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }