In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [None]:
#export
from exp.nb_11 import *

## Serializing the model

[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=2920)

In [None]:
path = datasets.untar_data(datasets.URLs.IMAGEWOOF_160)

In [None]:
size = 128
bs = 64

# Training transforms: includes the random crop and random flip steps.
tfms = [make_rgb, RandomResizedCrop(size, scale=(0.35,1)), np_to_float, PilRandomFlip()]
# Validation transforms: a center crop instead of the random crop, and no flip.
val_tfms = [make_rgb, CenterCrop(size), np_to_float]
il = ImageList.from_files(path, tfms=tfms)
# NOTE(review): presumably files under a grandparent folder named 'val' become the validation set — confirm in grandparent_splitter.
sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))
ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())
# The validation items were created with `tfms`; swap in the validation transforms.
ll.valid.x.tfms = val_tfms
# NOTE(review): c_out=10 is hard-coded; presumably the number of classes in this dataset — confirm.
data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=8)

In [None]:
len(il)

In [None]:
loss_func = LabelSmoothingCrossEntropy()
opt_func = adam_opt(mom=0.9, mom_sqr=0.99, eps=1e-6, wd=1e-2)

In [None]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)

In [None]:
def sched_1cycle(lr, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):
    """Build the two `ParamScheduler` callbacks of a 1cycle schedule for one learning rate.

    The 'lr' schedule is built from `cos_1cycle_anneal(lr/10, lr, lr/1e5)` and the
    'mom' schedule from `cos_1cycle_anneal(mom_start, mom_mid, mom_end)`, both combined
    over the phases returned by `create_phases(pct_start)`.
    Returns `[lr_scheduler, mom_scheduler]`.
    """
    phases = create_phases(pct_start)
    lr_fn = combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))
    mom_fn = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))
    # Keep the original order: learning-rate scheduler first, momentum second.
    named_scheds = [('lr', lr_fn), ('mom', mom_fn)]
    return [ParamScheduler(pname, fn) for pname, fn in named_scheds]

In [None]:
lr = 3e-3
pct_start = 0.5
cbsched = sched_1cycle(lr, pct_start)

In [None]:
learn.fit(40, cbsched)

In [None]:
st = learn.model.state_dict()

In [None]:
type(st)

In [None]:
', '.join(st.keys())

In [None]:
st['10.bias']

In [None]:
mdl_path = path/'models'
mdl_path.mkdir(exist_ok=True)

It's also possible to save the whole model, including the architecture, but it gets quite fiddly and we don't recommend it. Instead, just save the parameters, and recreate the model directly.

In [None]:
torch.save(st, mdl_path/'iw5')

## Pets

[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3127)

In [None]:
pets = datasets.untar_data(datasets.URLs.PETS)

In [None]:
pets.ls()

In [None]:
pets_path = pets/'images'

In [None]:
il = ImageList.from_files(pets_path, tfms=tfms)

In [None]:
il

In [None]:
#export
def random_splitter(fn, p_valid):
    "Return `True` when a fresh `random.random()` draw is below `p_valid`; `fn` is accepted but not inspected."
    draw = random.random()
    return draw < p_valid

In [None]:
random.seed(42)

In [None]:
sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))

In [None]:
sd

In [None]:
n = il.items[0].name; n

In [None]:
# Everything before the trailing `_<digits>.jpg` is the label; the dot is escaped so
# that only a literal `.jpg` extension matches.
re.findall(r'^(.*)_\d+\.jpg$', n)[0]

In [None]:
def pet_labeler(fn):
    """Return the label encoded in an image file name.

    `fn` is a path-like object whose `.name` looks like `great_pyrenees_173.jpg`;
    everything before the trailing `_<digits>.jpg` is returned.
    Raises `IndexError` when the name does not follow that pattern.
    """
    # The dot is escaped so that only a literal `.jpg` extension matches.
    return re.findall(r'^(.*)_\d+\.jpg$', fn.name)[0]

In [None]:
proc = CategoryProcessor()

In [None]:
ll = label_by_func(sd, pet_labeler, proc_y=proc)

In [None]:
', '.join(proc.vocab)

In [None]:
ll.valid.x.tfms = val_tfms

In [None]:
c_out = len(proc.vocab)

In [None]:
data = ll.to_databunch(bs, c_in=3, c_out=c_out, num_workers=8)

In [None]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, norm=norm_imagenette)

In [None]:
learn.fit(5, cbsched)

## Custom head

[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3265)

In [None]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)

In [None]:
st = torch.load(mdl_path/'iw5')

In [None]:
m = learn.model

In [None]:
m.load_state_dict(st)

In [None]:
cut = next(i for i,o in enumerate(m.children()) if isinstance(o,nn.AdaptiveAvgPool2d))
m_cut = m[:cut]

In [None]:
xb,yb = get_batch(data.valid_dl, learn)

In [None]:
pred = m_cut(xb)

In [None]:
pred.shape

In [None]:
ni = pred.shape[1]

In [None]:
#export
class AdaptiveConcatPool2d(nn.Module):
    "Concatenate adaptive max pooling and adaptive average pooling of the input along dim 1."
    def __init__(self, sz=1):
        super().__init__()
        self.output_size = sz
        self.ap = nn.AdaptiveAvgPool2d(sz)
        self.mp = nn.AdaptiveMaxPool2d(sz)

    def forward(self, x):
        # Max-pooled features come first, followed by the average-pooled ones.
        max_feats = self.mp(x)
        avg_feats = self.ap(x)
        return torch.cat((max_feats, avg_feats), dim=1)

In [None]:
# New model: the cut body, concat pooling, flatten, then a fresh linear head.
# AdaptiveConcatPool2d stacks max- and average-pooled features, hence `ni*2` inputs.
m_new = nn.Sequential(
    m_cut, AdaptiveConcatPool2d(), Flatten(),
    nn.Linear(ni*2, data.c_out))

In [None]:
learn.model = m_new

In [None]:
learn.fit(5, cbsched)

## adapt_model and gradual unfreezing

[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3483)

In [None]:
def adapt_model(learn, data):
    """Replace the head of `learn.model` with a new one sized for `data.c_out`.

    The model is cut just before its first `nn.AdaptiveAvgPool2d` child. The kept part
    becomes the first element of a new `nn.Sequential`, followed by
    `AdaptiveConcatPool2d`, `Flatten` and a fresh `nn.Linear`. `learn.model` is
    reassigned to that new model; nothing is returned.

    Raises `StopIteration` if the model has no `nn.AdaptiveAvgPool2d` child.
    """
    cut = next(i for i,o in enumerate(learn.model.children())
               if isinstance(o, nn.AdaptiveAvgPool2d))
    m_cut = learn.model[:cut]
    xb,_ = get_batch(data.valid_dl, learn)
    # Only the channel count of the body's output is needed, so run the probe
    # without recording an autograd graph.
    with torch.no_grad(): ni = m_cut(xb).shape[1]
    m_new = nn.Sequential(
        m_cut, AdaptiveConcatPool2d(), Flatten(),
        # AdaptiveConcatPool2d concatenates max and average pooling, hence `ni*2`.
        nn.Linear(ni*2, data.c_out))
    learn.model = m_new

In [None]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))

In [None]:
adapt_model(learn, data)

In [None]:
for p in learn.model[0].parameters(): p.requires_grad_(False)

In [None]:
learn.fit(3, sched_1cycle(1e-2, 0.5))

In [None]:
for p in learn.model[0].parameters(): p.requires_grad_(True)

In [None]:
learn.fit(5, cbsched, reset_opt=True)

## Batch norm transfer

[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3567)

In [None]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)

In [None]:
def apply_mod(m, f):
    "Call `f` on `m`, then recursively on each of `m.children()` (parent before children)."
    f(m)
    for child in m.children():
        apply_mod(child, f)

def set_grad(m, b):
    """Set `requires_grad` to `b` on the parameters of `m`.

    Does nothing for `nn.Linear` and `nn.BatchNorm2d` modules, nor for modules
    without a `weight` attribute.
    """
    skipped = isinstance(m, (nn.Linear, nn.BatchNorm2d))
    if skipped or not hasattr(m, 'weight'): return
    for param in m.parameters():
        param.requires_grad_(b)

In [None]:
apply_mod(learn.model, partial(set_grad, b=False))

In [None]:
learn.fit(3, sched_1cycle(1e-2, 0.5))

In [None]:
apply_mod(learn.model, partial(set_grad, b=True))

In [None]:
learn.fit(5, cbsched, reset_opt=True)

PyTorch's `nn.Module` already has an `apply` method we can use:

In [None]:
learn.model.apply(partial(set_grad, b=False));

## Discriminative LR and param groups

[Jump_to lesson 12 video](https://course19.fast.ai/videos/?lesson=12&t=3799)

In [None]:
learn = cnn_learner(xresnet18, data, loss_func, opt_func, c_out=10, norm=norm_imagenette)

In [None]:
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)

In [None]:
def bn_splitter(m):
    """Split the parameters of `m` into two lists `(g1, g2)`.

    Walking `m[0]` recursively, parameters of `nn.BatchNorm2d` modules go to the
    second list and parameters of any other module with a `weight` attribute go to
    the first. All parameters of `m[1:]` are then appended to the second list.
    """
    body_params, bn_and_head_params = [], []

    def _collect(layer):
        if isinstance(layer, nn.BatchNorm2d):
            bn_and_head_params.extend(layer.parameters())
        elif hasattr(layer, 'weight'):
            body_params.extend(layer.parameters())
        for child in layer.children():
            _collect(child)

    _collect(m[0])
    bn_and_head_params.extend(m[1:].parameters())
    return body_params, bn_and_head_params

In [None]:
a,b = bn_splitter(learn.model)

In [None]:
# Every parameter of the model that was split must land in exactly one of the two groups.
# Compare against `learn.model` (what `bn_splitter` received), not the stale global `m`.
test_eq(len(a)+len(b), len(list(learn.model.parameters())))

In [None]:
Learner.ALL_CBS

In [None]:
#export
from types import SimpleNamespace
# Namespace with one attribute per entry of `Learner.ALL_CBS`, each attribute holding
# that same entry, so event names can be written as `cb_types.<name>`.
cb_types = SimpleNamespace(**{o:o for o in Learner.ALL_CBS})

In [None]:
cb_types.after_backward

In [None]:
#export
class DebugCallback(Callback):
    "On the event named `cb_name`, call `f(self.run)` if `f` was given, otherwise call `set_trace()`."
    _order = 999
    def __init__(self, cb_name, f=None):
        self.cb_name = cb_name
        self.f = f

    def __call__(self, cb_name):
        # Ignore every event except the one this callback was created for.
        if cb_name != self.cb_name: return
        if not self.f:
            set_trace()
        else:
            self.f(self.run)

In [None]:
#export
def sched_1cycle(lrs, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):
    """Build the two `ParamScheduler` callbacks of a 1cycle schedule.

    `lrs` is either an iterable of learning rates, in which case one LR schedule is
    built per entry and the list is handed to `ParamScheduler('lr', ...)`, or a single
    number, in which case one LR schedule is built and handed over on its own.
    Each LR schedule comes from `cos_1cycle_anneal(lr/10, lr, lr/1e5)`; the momentum
    schedule comes from `cos_1cycle_anneal(mom_start, mom_mid, mom_end)`. Both are
    combined over the phases returned by `create_phases(pct_start)`.
    Returns `[lr_scheduler, mom_scheduler]`.
    """
    phases = create_phases(pct_start)

    def _lr_sched(lr):
        return combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5))

    # A bare number is accepted too, so callers written for the scalar-LR signature
    # keep working instead of failing when `lrs` is iterated.
    if isinstance(lrs, (int, float)): sched_lr = _lr_sched(lrs)
    else: sched_lr = [_lr_sched(lr) for lr in lrs]
    sched_mom = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))
    return [ParamScheduler('lr', sched_lr),
            ParamScheduler('mom', sched_mom)]

In [None]:
disc_lr_sched = sched_1cycle([0,3e-2], 0.5)

In [None]:
# `splitter=bn_splitter` passes the two-group parameter split to the learner.
# NOTE(review): the two learning rates in `disc_lr_sched` are presumably matched to those two groups — confirm in cnn_learner.
learn = cnn_learner(xresnet18, data, loss_func, opt_func,
 c_out=10, norm=norm_imagenette, splitter=bn_splitter)

# The model is built with c_out=10 so the saved 'iw5' state dict can be loaded;
# adapt_model then replaces the head with one sized for `data.c_out`.
learn.model.load_state_dict(torch.load(mdl_path/'iw5'))
adapt_model(learn, data)

In [None]:
def _print_det(o):
    "Print the number of optimizer param groups and the optimizer hypers of `o`, then raise `CancelTrainException`."
    n_groups = len(o.opt.param_groups)
    print(n_groups, o.opt.hypers)
    raise CancelTrainException()

learn.fit(1, disc_lr_sched + [DebugCallback(cb_types.after_batch, _print_det)])

In [None]:
learn.fit(3, disc_lr_sched)

In [None]:
disc_lr_sched = sched_1cycle([1e-3,1e-2], 0.3)

In [None]:
learn.fit(5, disc_lr_sched)

## Export

In [None]:
!./notebook2script.py 11a_transfer_learning.ipynb