{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This requires to install fastprogress (pip install fastprogress)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from nb_004a import *\n", "from fastprogress import master_bar,progress_bar" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test with training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def fit(epochs, model, loss_fn, opt, data, callbacks=None, metrics=None, pbar=None):\n", " cb_handler = CallbackHandler(callbacks)\n", " cb_handler.on_train_begin()\n", " if pbar is None: pbar = master_bar(range(epochs))\n", "\n", " for epoch in pbar:\n", " model.train()\n", " cb_handler.on_epoch_begin()\n", "\n", " for xb,yb in progress_bar(data.train_dl, parent=pbar):\n", " xb, yb = cb_handler.on_batch_begin(xb, yb)\n", " loss,_ = loss_batch(model, xb, yb, loss_fn, opt, cb_handler)\n", " if cb_handler.on_batch_end(loss): break\n", "\n", " if hasattr(data,'valid_dl') and data.valid_dl is not None:\n", " model.eval()\n", " with torch.no_grad():\n", " *val_metrics,nums = zip(*[loss_batch(model, xb, yb, loss_fn, cb_handler=cb_handler, metrics=metrics)\n", " for xb,yb in progress_bar(data.valid_dl, parent=pbar)])\n", " val_metrics = [np.sum(np.multiply(val,nums)) / np.sum(nums) for val in val_metrics]\n", "\n", " else: val_metrics=None\n", " if cb_handler.on_epoch_end(val_metrics): break\n", "\n", " cb_handler.on_train_end()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class Learner():\n", " \"Object that wraps together some data, a model, a loss function and an optimizer\"\n", "\n", " data:DataBunch\n", " model:nn.Module\n", " opt_fn:Callable=AdamW\n", " loss_fn:Callable=F.cross_entropy\n", " metrics:Collection[Callable]=None\n", " true_wd:bool=True\n", " wd:Floats=1e-6\n", " train_bn:bool=True\n", " path:str = 'models'\n", " callback_fns:Collection[Callable]=None\n", " layer_groups:Collection[nn.Module]=None\n", " def __post_init__(self):\n", " self.path = Path(self.path)\n", " self.path.mkdir(parents=True, exist_ok=True)\n", " self.model = self.model.to(self.data.device)\n", " if not self.layer_groups: self.layer_groups = [self.model]\n", " self.callback_fns = listify(self.callback_fns)\n", " self.callbacks = []\n", "\n", " def fit(self, epochs:int, lr:Floats, wd:Floats=None, callbacks:Collection[Callback]=None):\n", " if wd is None: wd = self.wd\n", " self.create_opt(lr, wd)\n", " if callbacks is None: callbacks = []\n", " callbacks += [cb(self) for cb in self.callback_fns]\n", " pbar = master_bar(range(epochs))\n", " self.recorder = Recorder(self.opt, epochs, self.data.train_dl, pbar)\n", " callbacks = [self.recorder] + self.callbacks + callbacks\n", " fit(epochs, self.model, self.loss_fn, self.opt, self.data, callbacks=callbacks, metrics=self.metrics, pbar=pbar)\n", "\n", " def create_opt(self, lr:Floats, wd:Floats=0.):\n", " lrs = listify(lr, self.layer_groups)\n", " opt = self.opt_fn([{'params': trainable_params(l), 'lr':lr} for l,lr in zip(self.layer_groups, lrs)])\n", " self.opt = OptimWrapper(opt, wd=wd, true_wd=self.true_wd)\n", "\n", " \n", " def split(self, split_on):\n", " if isinstance(split_on,Callable): split_on = split_on(self.model)\n", " self.layer_groups = split_model(self.model, split_on)\n", "\n", " def freeze_to(self, n):\n", " for g in self.layer_groups[:n]:\n", " for l in g:\n", " if not self.train_bn or not isinstance(l, bn_types):\n", " for p in l.parameters(): p.requires_grad = False\n", " for g in self.layer_groups[n:]:\n", " for p in g.parameters(): p.requires_grad = True\n", "\n", " def freeze(self):\n", " assert(len(self.layer_groups)>1)\n", " self.freeze_to(-1)\n", " \n", " def unfreeze(self): self.freeze_to(0)\n", " \n", " def save(self, name): torch.save(self.model.state_dict(), self.path/f'{name}.pth')\n", " def load(self, name): self.model.load_state_dict(torch.load(self.path/f'{name}.pth'))\n", "\n", "import nb_004a\n", "nb_004a.Learner = Learner" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class Recorder(Callback):\n", " opt: torch.optim\n", " nb_epoch:int\n", " train_dl: DeviceDataLoader = None\n", " pbar: master_bar = None\n", " \n", " def on_train_begin(self, **kwargs):\n", " self.losses,self.val_losses,self.lrs,self.moms,self.metrics,self.nb_batches = [],[],[],[],[],[]\n", "\n", " def on_batch_begin(self, **kwargs):\n", " self.lrs.append(self.opt.lr)\n", " self.moms.append(self.opt.mom)\n", "\n", " def on_backward_begin(self, smooth_loss, **kwargs):\n", " #We record the loss here before any other callback has a chance to modify it.\n", " self.losses.append(smooth_loss)\n", " if self.pbar is not None and hasattr(self.pbar,'child'):\n", " self.pbar.child.comment = f'{smooth_loss:.4f}'\n", "\n", " def on_epoch_end(self, epoch, num_batch, smooth_loss, last_metrics, **kwargs):\n", " self.nb_batches.append(num_batch)\n", " if last_metrics is not None:\n", " self.val_losses.append(last_metrics[0])\n", " if len(last_metrics) > 1: self.metrics.append(last_metrics[1:])\n", " self.pbar.write(f'{epoch}, {smooth_loss}, {last_metrics}')\n", " else: self.pbar.write(f'{epoch}, {smooth_loss}')\n", "\n", " def plot_lr(self, show_moms=False):\n", " iterations = list(range(len(self.lrs)))\n", " if show_moms:\n", " _, axs = plt.subplots(1,2, figsize=(12,4))\n", " axs[0].plot(iterations, self.lrs)\n", " axs[1].plot(iterations, self.moms)\n", " else: plt.plot(iterations, self.lrs)\n", "\n", " def plot(self, skip_start=10, skip_end=5):\n", " lrs = self.lrs[skip_start:-skip_end] if skip_end > 0 else self.lrs[skip_start:]\n", " losses = self.losses[skip_start:-skip_end] if skip_end > 0 else self.losses[skip_start:]\n", " _, ax = plt.subplots(1,1)\n", " ax.plot(lrs, losses)\n", " ax.set_xscale('log')\n", "\n", " def plot_losses(self):\n", " _, ax = plt.subplots(1,1)\n", " iterations = list(range(len(self.losses)))\n", " ax.plot(iterations, self.losses)\n", " val_iter = self.nb_batches\n", " val_iter = np.cumsum(val_iter)\n", " ax.plot(val_iter, self.val_losses)\n", "\n", " def plot_metrics(self):\n", " assert len(self.metrics) != 0, \"There is no metrics to plot.\"\n", " _, axes = plt.subplots(len(self.metrics[0]),1,figsize=(6, 4*len(self.metrics[0])))\n", " val_iter = self.nb_batches\n", " val_iter = np.cumsum(val_iter)\n", " axes = axes.flatten() if len(self.metrics[0]) != 1 else [axes]\n", " for i, ax in enumerate(axes):\n", " values = [met[i] for met in self.metrics]\n", " ax.plot(val_iter, values)\n", "\n", "import nb_004\n", "nb_004.Recorder = Recorder" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class ShowGraph(Callback):\n", " learn:Learner\n", " \n", " def on_epoch_end(self, last_metrics, **kwargs):\n", " if last_metrics is not None:\n", " rec = learn.recorder\n", " iters = list(range(len(rec.losses)))\n", " val_iter = np.array(rec.nb_batches).cumsum()\n", " x_bounds = (0, (rec.nb_epoch - len(rec.nb_batches)) * rec.nb_batches[-1] + len(rec.losses))\n", " y_bounds = (0, max((max(rec.losses), max(rec.val_losses))))\n", " rec.pbar.update_graph([(iters, rec.losses), (val_iter, rec.val_losses)], x_bounds, y_bounds) " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "DATA_PATH = Path('data')\n", "PATH = DATA_PATH/'cifar10'\n", "\n", "data_mean,data_std = map(tensor, ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261]))\n", "cifar_norm,cifar_denorm = normalize_funcs(data_mean,data_std)\n", "\n", "train_tfms = [flip_lr(p=0.5),\n", " pad(padding=4),\n", " crop(size=32, row_pct=(0,1.), col_pct=(0,1.))]\n", "valid_tfms = []\n", "\n", "bs = 64" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds = ImageDataset.from_folder(PATH/'train', classes=['airplane','dog'])\n", "valid_ds = ImageDataset.from_folder(PATH/'test', classes=['airplane','dog'])\n", "data = DataBunch.create(train_ds, valid_ds, bs=bs, train_tfm=train_tfms, valid_tfm=valid_tfms, num_workers=0, dl_tfms=cifar_norm)\n", "len(data.train_dl), len(data.valid_dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)\n", "learn = Learner(data, model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pdb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(5,0.01)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.callback_fns = [ShowGraph]\n", "learn.fit(5,0.01)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }