{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from nb_003 import *\n", "from torch import Tensor,tensor\n", "from fastprogress import master_bar,progress_bar\n", "from fastprogress.fastprogress import MasterBar, ProgressBar\n", "import re\n", "from typing import Iterator\n", "\n", "Floats = Union[float, Collection[float]]\n", "PBar = Union[MasterBar, ProgressBar]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import fastprogress.fastprogress as fp2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Hyperparameters and callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "DATA_PATH = Path('data')\n", "PATH = DATA_PATH/'cifar10'\n", "\n", "data_mean,data_std = map(tensor, ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261]))\n", "cifar_norm,cifar_denorm = normalize_funcs(data_mean,data_std)\n", "\n", "tfms = [flip_lr(p=0.5),\n", " pad(padding=4),\n", " crop(size=32, row_pct=(0,1.), col_pct=(0,1.))]\n", "\n", "bs = 64" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds = ImageDataset.from_folder(PATH/'train', classes=['airplane','dog'])\n", "valid_ds = ImageDataset.from_folder(PATH/'test', classes=['airplane','dog'])\n", "data = DataBunch.create(train_ds, valid_ds, bs=bs, train_tfm=tfms, num_workers=4, dl_tfms=cifar_norm)\n", "len(data.train_dl), len(data.valid_dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Darknet([1, 2, 4, 6, 3], num_classes=10, nf=16)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting hyperparameters easily" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "We want an optimizer with an easy way to set hyperparameters: they're all properties and we define custom setters to handle the different names in pytorch optimizers. We will define a Wrapper for all optimizers within which we will define each parameter's setter functions for setting the values we want. This will allow us to set a default value for each hyperparameter but to also easily edit it while experimenting. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class OptimWrapper():\n", " \"Normalize naming of parameters on wrapped optimizers\"\n", " def __init__(self, opt:optim.Optimizer, wd:float=0., true_wd:bool=False):\n", " \"Create wrapper for `opt` and optionally (`true_wd`) set weight decay `wd`\"\n", " self.opt,self.true_wd = opt,true_wd\n", " self.opt_keys = list(self.opt.param_groups[0].keys())\n", " self.opt_keys.remove('params')\n", " self.read_defaults()\n", " self._wd = wd\n", " \n", " #Pytorch optimizer methods\n", " def step(self)->None:\n", " \"Performs a single optimization step \"\n", " # weight decay outside of optimizer step (AdamW)\n", " if self.true_wd:\n", " for pg in self.opt.param_groups:\n", " for p in pg['params']: p.data.mul_(1 - self._wd*pg['lr'])\n", " self.set_val('weight_decay', 0)\n", " self.opt.step()\n", " \n", " def zero_grad(self)->None: \n", " \"Clears the gradients of all optimized `Tensor`s\"\n", " self.opt.zero_grad()\n", " \n", " #Hyperparameters as properties\n", " @property\n", " def lr(self)->float: \n", " \"Learning rate\"\n", " return self._lr\n", "\n", " @lr.setter\n", " def lr(self, val:float)->None: self._lr = self.set_val('lr', val)\n", " \n", " @property\n", " def mom(self)->float: \n", " \"Momentum if present on wrapped opt, else betas\"\n", " return self._mom\n", "\n", " @mom.setter\n", " def mom(self, val:float)->None:\n", " \"Momentum if present on wrapped opt, else betas\"\n", " if 
'momentum' in self.opt_keys: self.set_val('momentum', val)\n", "        elif 'betas' in self.opt_keys: self.set_val('betas', (val, self._beta))\n", "        self._mom = val\n", "    \n", "    @property\n", "    def beta(self)->float:\n", "        \"Second beta (`betas[1]`) if present on wrapped opt, else `alpha`\"\n", "        return self._beta\n", "\n", "    @beta.setter\n", "    def beta(self, val:float)->None:\n", "        \"Second beta (`betas[1]`) if present on wrapped opt, else `alpha`\"\n", "        if 'betas' in self.opt_keys: self.set_val('betas', (self._mom,val))\n", "        elif 'alpha' in self.opt_keys: self.set_val('alpha', val)\n", "        self._beta = val\n", "    \n", "    @property\n", "    def wd(self)->float: \n", "        \"Weight decay for wrapped opt\"\n", "        return self._wd\n", "\n", "    @wd.setter\n", "    def wd(self, val:float)->None:\n", "        \"Weight decay for wrapped opt; only written to the optimizer's 'weight_decay' when not using `true_wd`\"\n", "        if not self.true_wd: self.set_val('weight_decay', val)\n", "        self._wd = val\n", "    \n", "    #Helper functions\n", "    def read_defaults(self):\n", "        \"Reads in the default params from the wrapped optimizer\"\n", "        # Optimizers without momentum/betas (or alpha) keep None so `mom`/`beta` never raise AttributeError\n", "        self._mom,self._beta = None,None\n", "        if 'lr' in self.opt_keys: self._lr = self.opt.param_groups[0]['lr']\n", "        if 'momentum' in self.opt_keys: self._mom = self.opt.param_groups[0]['momentum']\n", "        if 'alpha' in self.opt_keys: self._beta = self.opt.param_groups[0]['alpha']\n", "        if 'betas' in self.opt_keys: self._mom,self._beta = self.opt.param_groups[0]['betas']\n", "        if 'weight_decay' in self.opt_keys: self._wd = self.opt.param_groups[0]['weight_decay']\n", "    \n", "    def set_val(self, key:str, val:Any):\n", "        \"Set parameter `key` to `val` on every param group of the wrapped optimizer and return `val`\"\n", "        for pg in self.opt.param_groups: pg[key] = val\n", "        return val" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "opt_fn = partial(optim.Adam, betas=(0.95,0.99))\n", "opt = OptimWrapper(opt_fn(model.parameters(), 1e-2))\n", "opt.lr, opt.mom, opt.wd, opt.beta" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "opt.lr=0.2\n", "opt.lr, opt.mom, opt.wd, 
opt.beta" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that it's easy to set and change the hyperparameters in the optimizer, we need a scheduler to change them during training. To keep the training loop as readable as possible we don't want to handle all of this inside it, so we'll use callbacks. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Callback():\n", "    \"Base class for callbacks that want to record values, dynamically change learner params, etc\"\n", "    def on_train_begin(self, **kwargs:Any)->None: \n", "        \"To initialize constants in the callback.\"\n", "        pass\n", "    def on_epoch_begin(self, **kwargs:Any)->None:\n", "        \"At the beginning of each epoch\"\n", "        pass\n", "    def on_batch_begin(self, **kwargs:Any)->None: \n", "        \"\"\"To set HP before the step is done.\n", "        May return `(xb, yb)` to replace the input of that step; return None to leave it unchanged\"\"\"\n", "        pass\n", "    def on_loss_begin(self, **kwargs:Any)->None:\n", "        \"\"\"Called after the forward pass but before the loss has been computed.\n", "        May return a replacement for the model output; return None to leave it unchanged\"\"\"\n", "        pass\n", "    def on_backward_begin(self, **kwargs:Any)->None:\n", "        \"\"\"Called after the forward pass and the loss has been computed, but before the back propagation.\n", "        May return a replacement loss (for instance to add a regularization term); return None to leave it unchanged\"\"\"\n", "        pass\n", "    def on_backward_end(self, **kwargs:Any)->None:\n", "        \"\"\"Called after the back propagation has been done (and the gradients computed) but before the step of the optimizer.\n", "        Useful for true weight decay in AdamW\"\"\"\n", "        pass\n", "    def on_step_end(self, **kwargs:Any)->None:\n", "        \"Called after the step of the optimizer but before the gradients are zeroed\"\n", "        pass\n", "    def on_batch_end(self, **kwargs:Any)->None:\n", "        \"Called at the end of the 
batch\"\n", "        pass\n", "    def on_epoch_end(self, **kwargs:Any)->bool:\n", "        \"Called at the end of an epoch; return True to stop training\"\n", "        return False\n", "    def on_train_end(self, **kwargs:Any)->None:\n", "        \"Useful for cleaning up things and saving files/models\"\n", "        pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To be more convenient and make the code of the training loop cleaner, we'll create a class to handle the callbacks. It will keep track of everything the training loop sends it, then pack it in the kwargs of each callback. This way, all the callbacks can access things like the epoch number, the last loss, etc. Notice also that callbacks that are given a parameter have the opportunity to edit that parameter and return a new value.\n", "\n", "Helper class that computes a moving average of the values sent to it. Will be useful for the smooth_loss but can smoothen anything." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class SmoothenValue():\n", "    \"Creates a smooth moving average for a value (loss, etc)\"\n", "    def __init__(self, beta:float)->None:\n", "        \"Create smoother for value, beta should be 0<beta<1\"\n", "        self.beta,self.n,self.mov_avg = beta,0,0\n", "\n", "    def add_value(self, val:float)->None:\n", "        \"Add current value to calculate updated smoothed value \"\n", "        self.n += 1\n", "        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val\n", "        # Divide by (1 - beta**n) to correct the bias towards the initial 0 of the moving average\n", "        self.smooth = self.mov_avg / (1 - self.beta ** self.n)" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "TensorOrNumber = Union[Tensor,Number]\n", "CallbackList = Collection[Callback]\n", "MetricsList = Collection[TensorOrNumber]\n", "TensorOrNumList = Collection[TensorOrNumber]\n", "MetricFunc = Callable[[Tensor,Tensor],TensorOrNumber]\n", "MetricFuncList = Collection[MetricFunc]\n", "\n", "\n", "def _get_init_state(): return {'epoch':0, 'iteration':0, 'num_batch':0}\n", "\n", "@dataclass\n", "class CallbackHandler():\n", "    \"Manages all of the registered callback objects, beta is for smoothing loss\"\n", "    callbacks:CallbackList\n", "    beta:float=0.98\n", "    \n", "    def __post_init__(self)->None:\n", "        \"Initialize smoother and learning stats\"\n", "        # Accept None (e.g. `fit(..., callbacks=None)`) or a single callback as well as a list\n", "        self.callbacks = listify(self.callbacks)\n", "        self.smoothener = SmoothenValue(self.beta)\n", "        self.state_dict:Dict[str,Union[int,float,Tensor]]=_get_init_state()\n", "    \n", "    def __call__(self, cb_name, **kwargs)->Collection[Any]:\n", "        \"Call through to all of the callback handlers and return the list of their results\"\n", "        return [getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs) for cb in self.callbacks]\n", "    \n", "    def on_train_begin(self, epochs:int, pbar:PBar, metrics:MetricFuncList)->None:\n", "        \"About to start learning\"\n", "        self.state_dict = _get_init_state()\n", "        self.state_dict['n_epochs'],self.state_dict['pbar'],self.state_dict['metrics'] = epochs,pbar,metrics\n", "        self('train_begin')\n", "    \n", "    def on_epoch_begin(self)->None: \n", "        \"Handle new epoch\"\n", "        self.state_dict['num_batch'] = 0\n", "        self('epoch_begin')\n", "    \n", "    def on_batch_begin(self, xb:Tensor, yb:Tensor)->Tuple[Tensor,Tensor]:\n", "        \"Handle new batch `xb`,`yb`; a callback may return a replacement `(xb, yb)`\"\n", "        self.state_dict['last_input'], self.state_dict['last_target'] = xb, yb\n", "        for cb in self.callbacks:\n", "            a = cb.on_batch_begin(**self.state_dict)\n", "            if a is not None: self.state_dict['last_input'], self.state_dict['last_target'] = a\n", "        return self.state_dict['last_input'], self.state_dict['last_target']\n", "    \n", "    def on_loss_begin(self, out:Tensor)->Tensor:\n", "        \"Handle start of loss calculation with model output `out`; a callback may return a replacement output\"\n", "        self.state_dict['last_output'] = out\n", "        for cb in self.callbacks:\n", "            a = cb.on_loss_begin(**self.state_dict)\n", "            if a is not None: self.state_dict['last_output'] = a\n", "        return self.state_dict['last_output']\n", "    \n", 
"    def on_backward_begin(self, loss:Tensor)->Tensor:\n", "        \"Handle gradient calculation on `loss`; a callback may return a replacement loss\"\n", "        self.smoothener.add_value(loss.detach())\n", "        self.state_dict['last_loss'], self.state_dict['smooth_loss'] = loss, self.smoothener.smooth\n", "        for cb in self.callbacks:\n", "            a = cb.on_backward_begin(**self.state_dict)\n", "            if a is not None: self.state_dict['last_loss'] = a\n", "        return self.state_dict['last_loss']\n", "    \n", "    def on_backward_end(self)->None: \n", "        \"Handle end of gradient calc\"\n", "        self('backward_end')\n", "    def on_step_end(self)->None: \n", "        \"Handle end of optimization step\"\n", "        self('step_end')\n", "    \n", "    def on_batch_end(self, loss:Tensor)->bool:\n", "        \"Handle end of processing one batch with `loss`; returns True if any callback asked to stop\"\n", "        self.state_dict['last_loss'] = loss\n", "        stop = np.any(self('batch_end'))\n", "        self.state_dict['iteration'] += 1\n", "        self.state_dict['num_batch'] += 1\n", "        return stop\n", "    \n", "    def on_epoch_end(self, val_metrics:MetricsList)->bool:\n", "        \"Epoch is done, process `val_metrics`; returns True if any callback asked to stop\"\n", "        self.state_dict['last_metrics'] = val_metrics\n", "        stop = np.any(self('epoch_end'))\n", "        self.state_dict['epoch'] += 1\n", "        return stop\n", "    \n", "    def on_train_end(self, exception:Union[bool,Exception])->None: \n", "        \"Handle end of training, `exception` is an `Exception` or False if no exceptions during training\"\n", "        self('train_end', exception=exception)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The idea is to call the callback handler between every line of the training loop, so that everything a callback needs to do is handled there and not inside the loop itself. We also compute the metrics right after calculating the loss." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "\n", "OptMetrics = Optional[Collection[Any]]\n", "OptLossFunc = Optional[LossFunction]\n", "OptCallbackHandler = Optional[CallbackHandler]\n", "OptOptimizer = Optional[optim.Optimizer]\n", "OptCallbackList = Optional[CallbackList]\n", "\n", "\n", "def loss_batch(model:Model, xb:Tensor, yb:Tensor, loss_fn:OptLossFunc=None, \n", "               opt:OptOptimizer=None, cb_handler:OptCallbackHandler=None, \n", "               metrics:OptMetrics=None)->Tuple[Union[Tensor,int],...]:\n", "    \"\"\"Calculate loss for a batch, calculate metrics, call out to callbacks as necessary.\n", "    Returns `(loss, *metrics, batch_size)`, or `(output, target)` when no `loss_fn` is given.\n", "    The backward pass and optimizer step only happen when `opt` is given.\"\"\"\n", "    if cb_handler is None: cb_handler = CallbackHandler([])\n", "    if not is_listy(xb): xb = [xb]\n", "    if not is_listy(yb): yb = [yb]\n", "    out = model(*xb)\n", "    out = cb_handler.on_loss_begin(out)\n", "    if not loss_fn: return out.detach(),yb[0].detach()\n", "    loss = loss_fn(out, *yb)\n", "    mets = [f(out,*yb).detach().cpu() for f in metrics] if metrics is not None else []\n", "    \n", "    if opt is not None:\n", "        loss = cb_handler.on_backward_begin(loss)\n", "        loss.backward()\n", "        cb_handler.on_backward_end()\n", "        opt.step()\n", "        cb_handler.on_step_end()\n", "        opt.zero_grad()\n", "    \n", "    return (loss.detach().cpu(),) + tuple(mets) + (yb[0].shape[0],)\n", "\n", "\n", "def validate(model:Model, dl:DataLoader, loss_fn:OptLossFunc=None, \n", "             metrics:OptMetrics=None, cb_handler:OptCallbackHandler=None, \n", "             pbar:Optional[PBar]=None)->Iterator[Tuple[Union[Tensor,int],...]]:\n", "    \"Calculate loss and metrics for the validation set, grouped by position: losses, then each metric, then batch sizes\"\n", "    model.eval()\n", "    with torch.no_grad():\n", "        return zip(*[loss_batch(model, xb, yb, loss_fn, cb_handler=cb_handler, metrics=metrics)\n", "                     for xb,yb in progress_bar(dl, parent=pbar)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def fit(epochs:int, model:Model, loss_fn:LossFunction, opt:optim.Optimizer, \n", "
data:DataBunch, callbacks:OptCallbackList=None, metrics:OptMetrics=None)->None:\n", "    \"Fit the `model` on `data` and learn using `loss_fn` and `opt`\"\n", "    # `callbacks` defaults to None: normalize it so the handler can always iterate over it\n", "    cb_handler = CallbackHandler(listify(callbacks))\n", "    pbar = master_bar(range(epochs))\n", "    cb_handler.on_train_begin(epochs, pbar=pbar, metrics=metrics)\n", "\n", "    exception=False\n", "    try:\n", "        for epoch in pbar:\n", "            model.train()\n", "            cb_handler.on_epoch_begin()\n", "\n", "            for xb,yb in progress_bar(data.train_dl, parent=pbar):\n", "                xb, yb = cb_handler.on_batch_begin(xb, yb)\n", "                loss,_ = loss_batch(model, xb, yb, loss_fn, opt, cb_handler)\n", "                if cb_handler.on_batch_end(loss): break\n", "\n", "            if hasattr(data,'valid_dl') and data.valid_dl is not None:\n", "                *val_metrics,nums = validate(model, data.valid_dl, loss_fn=loss_fn,\n", "                                             cb_handler=cb_handler, metrics=metrics,pbar=pbar)\n", "                # Average each per-batch value weighted by its batch size\n", "                nums = np.array(nums, dtype=np.float32)\n", "                val_metrics = [(torch.stack(val).cpu().numpy() * nums).sum() / nums.sum()\n", "                               for val in val_metrics]\n", "\n", "            else: val_metrics=None\n", "            if cb_handler.on_epoch_end(val_metrics): break\n", "    except Exception as e:\n", "        exception = e\n", "        # bare `raise` re-raises the active exception as-is\n", "        raise\n", "    finally: cb_handler.on_train_end(exception)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First callback: it records the important values, updates the progress bar and prints out the epoch and validation loss as the training progresses. The important values we save during training, such as losses and hyper-parameters, will be used for future plots (lr_finder, plot of the LR/mom schedule). We will also add the plotting tools that will be used over and over again when training models." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_camel_re1 = re.compile('(.)([A-Z][a-z]+)')\n", "_camel_re2 = re.compile('([a-z0-9])([A-Z])')\n", "def camel2snake(name:str)->str:\n", "    \"Convert a CamelCase `name` to snake_case (e.g. `LRFinder` -> `lr_finder`)\"\n", "    s1 = re.sub(_camel_re1, r'\\1_\\2', name)\n", "    return re.sub(_camel_re2, r'\\1_\\2', s1).lower()\n", "\n", "@dataclass\n", "class LearnerCallback(Callback):\n", "    \"Base class for creating callbacks for the `Learner`\"\n", "    learn: Learner\n", "    def __post_init__(self):\n", "        \"Register this callback on `learn` as an attribute named `cb_name`\"\n", "        if self.cb_name: setattr(self.learn, self.cb_name, self)\n", "\n", "    @property\n", "    def cb_name(self): return camel2snake(self.__class__.__name__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Recorder(LearnerCallback):\n", "    \"A `LearnerCallback` that records epoch,loss,opt and metric data during training\"\n", "    def __init__(self, learn:Learner):\n", "        \"Register on `learn` and keep references to its optimizer and training dataloader\"\n", "        super().__init__(learn)\n", "        self.opt = self.learn.opt\n", "        self.train_dl = self.learn.data.train_dl\n", "    \n", "    def on_train_begin(self, pbar:PBar, metrics:MetricFuncList, **kwargs:Any)->None:\n", "        \"Initialize recording status at beginning of training\"\n", "        self.pbar = pbar\n", "        self.names = ['epoch', 'train loss', 'valid loss'] + [fn.__name__ for fn in metrics]\n", "        self.pbar.write(' '.join(self.names))\n", "        self.losses,self.val_losses,self.lrs,self.moms,self.metrics,self.nb_batches = [],[],[],[],[],[]\n", "\n", "    def on_batch_begin(self, **kwargs:Any)->None:\n", "        \"Record learning rate and momentum at beginning of batch\"\n", "        self.lrs.append(self.opt.lr)\n", "        self.moms.append(self.opt.mom)\n", "\n", "    def on_backward_begin(self, smooth_loss:Tensor, **kwargs:Any)->None:\n", "        \"Record the loss before any other callback has a chance to modify it.\"\n", "        self.losses.append(smooth_loss)\n", "        if self.pbar is not None and hasattr(self.pbar,'child'):\n", "            self.pbar.child.comment = f'{smooth_loss:.4f}'\n", "\n", "    def 
on_epoch_end(self, epoch:int, num_batch:int, smooth_loss:Tensor, \n", "                     last_metrics:MetricsList=None, **kwargs:Any)->bool:\n", "        \"Save epoch info: num_batch, smooth_loss, metrics\"\n", "        self.nb_batches.append(num_batch)\n", "        if last_metrics is not None:\n", "            self.val_losses.append(last_metrics[0])\n", "            if len(last_metrics) > 1: self.metrics.append(last_metrics[1:])\n", "            # `list` so that a tuple of metrics can be concatenated as well\n", "            self.format_stats([epoch, smooth_loss] + list(last_metrics))\n", "        else: self.format_stats([epoch, smooth_loss])\n", "        return False\n", "\n", "    def format_stats(self, stats:TensorOrNumList)->None:\n", "        \"Write `stats` as one line of the progress bar, each value padded to the width of its column name\"\n", "        str_stats = []\n", "        for name,stat in zip(self.names,stats):\n", "            t = str(stat) if isinstance(stat, int) else f'{stat:.6f}'\n", "            t += ' ' * (len(name) - len(t))\n", "            str_stats.append(t)\n", "        self.pbar.write(' '.join(str_stats))\n", "    \n", "    def plot_lr(self, show_moms=False)->None:\n", "        \"Plot learning rate, `show_moms` to include momentum\"\n", "        iterations = list(range(len(self.lrs)))\n", "        if show_moms:\n", "            _, axs = plt.subplots(1,2, figsize=(12,4))\n", "            axs[0].plot(iterations, self.lrs)\n", "            axs[1].plot(iterations, self.moms)\n", "        else: plt.plot(iterations, self.lrs)\n", "\n", "    def plot(self, skip_start:int=10, skip_end:int=5)->None:\n", "        \"Plot learning rate and losses, trimmed between `skip_start` and `skip_end`\"\n", "        lrs = self.lrs[skip_start:-skip_end] if skip_end > 0 else self.lrs[skip_start:]\n", "        losses = self.losses[skip_start:-skip_end] if skip_end > 0 else self.losses[skip_start:]\n", "        _, ax = plt.subplots(1,1)\n", "        ax.plot(lrs, losses)\n", "        ax.set_xscale('log')\n", "\n", "    def plot_losses(self)->None:\n", "        \"Plot training and validation losses\"\n", "        _, ax = plt.subplots(1,1)\n", "        iterations = list(range(len(self.losses)))\n", "        ax.plot(iterations, self.losses)\n", "        val_iter = self.nb_batches\n", "        val_iter = np.cumsum(val_iter)\n", "        ax.plot(val_iter, self.val_losses)\n", "\n", "    def plot_metrics(self)->None:\n", "        \"Plot metrics collected during training\"\n", "        
assert len(self.metrics) != 0, \"There are no metrics to plot.\"\n", "        _, axes = plt.subplots(len(self.metrics[0]),1,figsize=(6, 4*len(self.metrics[0])))\n", "        val_iter = self.nb_batches\n", "        val_iter = np.cumsum(val_iter)\n", "        axes = axes.flatten() if len(self.metrics[0]) != 1 else [axes]\n", "        for i, ax in enumerate(axes):\n", "            values = [met[i] for met in self.metrics]\n", "            ax.plot(val_iter, values)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def accuracy(out:Tensor, yb:Tensor)->TensorOrNumber:\n", "    \"Fraction of samples whose argmax over dim 1 of `out` equals the target `yb`\"\n", "    preds = torch.argmax(out, dim=1)\n", "    return (preds==yb).float().mean()\n", "\n", "# Adam with these betas; the decoupled (AdamW-style) weight decay itself is applied by `OptimWrapper` when its `true_wd` is True\n", "AdamW = partial(optim.Adam, betas=(0.9,0.99))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class Learner():\n", "    \"\"\"Trains `model` with `data` using `loss_fn` and `opt_fn`, collects `metrics` along the way\n", "    `true_wd` along with `wd` turn on weight decay, `path` specifies where models are stored\n", "    `callback_fns` is used to add custom callbacks beyond Recorder which is added by default\"\"\"\n", "    data:DataBunch\n", "    model:nn.Module\n", "    opt_fn:Callable=AdamW\n", "    loss_fn:Callable=F.cross_entropy\n", "    metrics:Collection[Callable]=None\n", "    true_wd:bool=True\n", "    wd:Floats=1e-2\n", "    path:str = 'models'\n", "    callback_fns:Collection[Callable]=None\n", "    callbacks:Collection[Callback]=field(default_factory=list)\n", "    def __post_init__(self):\n", "        \"Sets up internal learner variables\"\n", "        self.path = Path(self.path)\n", "        self.metrics=listify(self.metrics)\n", "        self.path.mkdir(parents=True, exist_ok=True)\n", "        self.model = self.model.to(self.data.device)\n", "        self.callbacks = listify(self.callbacks)\n", "        self.callback_fns = [Recorder] + listify(self.callback_fns)\n", "\n", "    def fit(self, epochs:int, lr:Optional[Floats], 
wd:Optional[Floats]=None, callbacks:OptCallbackList=None)->None:\n", "        \"Fit the model in this learner with `lr` learning rate and `wd` weight decay\"\n", "        if wd is None: wd = self.wd\n", "        self.create_opt(lr, wd)\n", "        callbacks = [cb(self) for cb in self.callback_fns] + listify(callbacks)\n", "        fit(epochs, self.model, self.loss_fn, self.opt, self.data, metrics=self.metrics,\n", "            callbacks=self.callbacks+callbacks)\n", "\n", "    def create_opt(self, lr:Floats, wd:Floats=0.)->None:\n", "        \"Binds a new optimizer each time `fit` is called with `lr` learning rate and `wd` weight decay\"\n", "        # Forward `wd` and `true_wd`, otherwise the learner's weight decay never reaches the optimizer\n", "        self.opt = OptimWrapper(self.opt_fn(self.model.parameters(),lr), wd=wd, true_wd=self.true_wd)\n", "    \n", "    def save(self, name:PathOrStr)->None:\n", "        \"Save the model bound to this learner in the `path` folder with `name`\"\n", "        torch.save(self.model.state_dict(), self.path/f'{name}.pth')\n", "    def load(self, name:PathOrStr): \n", "        \"Load the model bound to this learner with the `name` model params in the `path` folder\"\n", "        self.model.load_state_dict(torch.load(self.path/f'{name}.pth'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "metrics=[accuracy]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)\n", "learn = Learner(data, model, metrics=metrics)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(1,0.01)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.recorder.plot_losses()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1cycle" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will now build a 1cycle scheduler to train our network. 
To learn more about the 1cycle technique for training neural networks, check out Leslie Smith's [paper](https://arxiv.org/pdf/1803.09820.pdf), and for a more graphical and intuitive explanation check out Sylvain Gugger's [post](https://sgugger.github.io/the-1cycle-policy.html).\n", "\n", "We will first define some annealing functions as options to describe how to progressively move a parameter from a start value to an end value. We will also define a `Stepper` class that applies our annealing function (default is linear) to our learning rate and momentum to get a learning rate and momentum value for each step in the training.\n", "\n", "We will then build a callback that actually implements our one cycle policy and changes the parameters accordingly.\n", "\n", "The one cycle policy has three steps:\n", "\n", "1. We progressively increase our learning rate from *lr_max/div_factor* to *lr_max* and at the same time we progressively decrease our momentum from *mom_max* to *mom_min*.\n", "2. We do the exact opposite: we progressively decrease our learning rate from *lr_max* to *lr_max/div_factor* and at the same time we progressively increase our momentum from *mom_min* to *mom_max*. \n", "3. We further decrease our learning rate from *lr_max/div_factor* to *lr_max/(div_factor x 100)* and we keep momentum steady at *mom_max*.\n", "\n", "We usually do steps 1 and 2 for an equal number of iterations that together make ~90% of total iterations (e.g. 45% each, totalling 90%). The remaining iterations are used for step 3.\n", "\n", "Note: \n", "Each of these transitions (i.e. how we get from one value to another) is described by the annealing function of choice.\n", "\n", "The `OneCycleScheduler` built below implements a simplified two-phase variant. Over the first `pct_start` fraction of the iterations, the learning rate goes linearly from *lr_max/div_factor* to *lr_max*. Over the remaining iterations it follows a cosine annealing down to *lr_max/(div_factor x 10^4)*. Momentum is annealed the same way between `moms[0]` and `moms[1]`, in the opposite direction."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "\n", "AnnealFunc = Callable[[Number,Number,float], Number]\n", "def annealing_no(start:Number, end:Number, pct:float)->Number: \n", "    \"No annealing, always return `start`\"\n", "    return start\n", "def annealing_linear(start:Number, end:Number, pct:float)->Number: \n", "    \"Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0\"\n", "    return start + pct * (end-start)\n", "def annealing_exp(start:Number, end:Number, pct:float)->Number: \n", "    \"Exponentially anneal from `start` to `end` as pct goes from 0.0 to 1.0\"\n", "    return start * (end/start) ** pct\n", "def annealing_cos(start:Number, end:Number, pct:float)->Number:\n", "    \"Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0\"\n", "    cos_out = np.cos(np.pi * pct) + 1\n", "    return end + (start-end)/2 * cos_out\n", "    \n", "def do_annealing_poly(start:Number, end:Number, pct:float, degree:Number)->Number: \n", "    \"Helper function for `annealing_poly`\"\n", "    return end + (start-end) * (1-pct)**degree\n", "def annealing_poly(degree:Number)->AnnealFunc: \n", "    \"Return an `AnnealFunc` that anneals polynomially from `start` to `end` as pct goes from 0.0 to 1.0\"\n", "    # `partial` is already in scope here; `functools` is only imported in a later, non-exported cell\n", "    return partial(do_annealing_poly, degree=degree)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import functools" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "annealings = \"NO LINEAR COS EXP POLY\".split()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a = np.arange(0, 100)\n", "p = np.linspace(0.01,1,100)\n", "\n", "fns = [annealing_no, annealing_linear, annealing_cos, annealing_exp, annealing_poly(0.8)]\n", "for fn, t in zip(fns, annealings):\n", "    plt.plot(a, [fn(2, 1e-2, o) for o in p], label=t)\n", "plt.legend();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ "#export\n", "def is_tuple(x:Any)->bool: return isinstance(x, tuple)\n", "StartOptEnd=Union[float,Tuple[float,float]]\n", "class Stepper():\n", " \"Used to \\\"step\\\" from start,end (`vals`) over `n_iter` iterations on a schedule defined by `func` (defaults to linear)\"\n", " def __init__(self, vals:StartOptEnd, n_iter:int, func:Optional[AnnealFunc]=None):\n", " self.start,self.end = (vals[0],vals[1]) if is_tuple(vals) else (vals,0)\n", " self.n_iter = n_iter\n", " if func is None: self.func = annealing_linear if is_tuple(vals) else annealing_no\n", " else: self.func = func\n", " self.n = 0\n", " \n", " def step(self)->Number:\n", " \"Return next value along annealed schedule\"\n", " self.n += 1\n", " return self.func(self.start, self.end, self.n/self.n_iter)\n", " \n", " @property\n", " def is_done(self)->bool:\n", " \"Schedule completed\"\n", " return self.n >= self.n_iter" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class OneCycleScheduler(Callback):\n", " \"Manages 1-Cycle style traing as outlined in Leslie Smith's [paper](https://arxiv.org/pdf/1803.09820.pdf)\"\n", " learn:Learner\n", " lr_max:float\n", " moms:Floats=(0.95,0.85)\n", " div_factor:float=25.\n", " pct_start:float=0.5\n", " \n", " def __post_init__(self): self.moms=tuple(listify(self.moms,2))\n", "\n", " def steps(self, *steps_cfg:StartOptEnd):\n", " \"Build anneal schedule for all of the parameters\"\n", " return [Stepper(step, n_iter, func=func)\n", " for (step,(n_iter,func)) in zip(steps_cfg, self.phases)]\n", "\n", " def on_train_begin(self, n_epochs:int, **kwargs:Any)->None:\n", " \"Initialize our optimization params based on our annealing schedule\"\n", " n = len(self.learn.data.train_dl) * n_epochs\n", " a1 = int(n * self.pct_start)\n", " a2 = n-a1\n", " self.phases = ((a1, annealing_linear), (a2, annealing_cos))\n", " low_lr = self.lr_max/self.div_factor\n", " self.lr_scheds = 
self.steps((low_lr, self.lr_max), (self.lr_max, low_lr/1e4))\n", " self.mom_scheds = self.steps(self.moms, (self.moms[1], self.moms[0]))\n", " self.opt = self.learn.opt\n", " self.opt.lr,self.opt.mom = self.lr_scheds[0].start,self.mom_scheds[0].start\n", " self.idx_s = 0\n", " \n", " def on_batch_end(self, **kwargs:Any)->None:\n", " \"Take one step forward on the annealing schedule for the optim params\"\n", " if self.idx_s >= len(self.lr_scheds): return Trrue\n", " self.opt.lr = self.lr_scheds[self.idx_s].step()\n", " self.opt.mom = self.mom_scheds[self.idx_s].step()\n", " # when the current schedule is complete we move onto the next \n", " # schedule. (in 1-cycle there are two schedules)\n", " if self.lr_scheds[self.idx_s].is_done:\n", " self.idx_s += 1\n", "\n", "def one_cycle_scheduler(lr_max:float, **kwargs:Any)->OneCycleScheduler:\n", " return partial(OneCycleScheduler, lr_max=lr_max, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)\n", "sched = one_cycle_scheduler(0.1, pct_start=0.3, div_factor=5, moms=[0.95,0.85])\n", "learn = Learner(data, model, metrics=metrics, callback_fns=sched)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(1,0.1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.recorder.plot_lr(show_moms=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def fit_one_cycle(learn:Learner, cyc_len:int, max_lr:float, moms:Tuple[float,float]=(0.95,0.85),\n", " div_factor:float=10., pct_start:float=0.5, wd:Optional[float]=None):\n", " \"Fits a model following the 1cycle policy\"\n", " cbs = [OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor,\n", " pct_start=pct_start)]\n", " learn.fit(cyc_len, max_lr, wd=wd, callbacks=cbs)" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)\n", "learn = Learner(data, model, metrics=metrics)\n", "fit_one_cycle(learn, 1, 0.1)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "More generally, here is an API that lets you create your own schedules." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class TrainingPhase():\n", "    \"Schedule lr,mom according to `lr_anneal` and `mom_anneal` across a `length` schedule\"\n", "    length:int\n", "    lrs:Floats\n", "    moms:Floats\n", "    lr_anneal:Callable=None\n", "    mom_anneal:Callable=None\n", "    \n", "    def __post_init__(self)->None:\n", "        # dataclass fields are instance attributes, so they must be read through `self`\n", "        self.lr_step = Stepper(self.lrs, self.length, self.lr_anneal)\n", "        self.mom_step = Stepper(self.moms, self.length, self.mom_anneal)\n", "\n", "@dataclass\n", "class GeneralScheduler(Callback):\n", "    \"Schedule multiple `TrainingPhase` for a `learner`\"\n", "    learn:Learner\n", "    phases:Collection[TrainingPhase]\n", "    \n", "    def on_train_begin(self, n_epochs:int, **kwargs:Any)->None:\n", "        \"Initialize our lr and mom schedules for training\"\n", "        self.lr_scheds = [p.lr_step for p in self.phases]\n", "        self.mom_scheds = [p.mom_step for p in self.phases]\n", "        self.opt = self.learn.opt\n", "        self.opt.lr,self.opt.mom = self.lr_scheds[0].start,self.mom_scheds[0].start\n", "        self.idx_s = 0\n", "    \n", "    def on_batch_end(self, **kwargs:Any)->None:\n", "        \"Take a step in lr,mom sched, start next sched when current is complete\"\n", "        # every phase is exhausted: returning True interrupts the current epoch\n", "        if self.idx_s >= len(self.lr_scheds): return True\n", "        self.opt.lr = self.lr_scheds[self.idx_s].step()\n", "        self.opt.mom = self.mom_scheds[self.idx_s].step()\n", "        if self.lr_scheds[self.idx_s].is_done:\n", "            self.idx_s += 1" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## LR Finder" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class LRFinder(LearnerCallback):\n", "    \"Explore lr vs loss relationship for a learner\"\n", "    def __init__(self, learn:Learner, start_lr:float=1e-5, end_lr:float=10, num_it:int=200)->None:\n", "        \"Initialize schedule of learning rates\"\n", "        super().__init__(learn)\n", "        self.data = learn.data\n", "        self.sched = Stepper((start_lr, end_lr), num_it, annealing_exp)\n", "        #To avoid validating if the train_dl has less than num_it batches, we put aside the valid_dl and remove it\n", "        #during the call to fit.\n", "        self.valid_dl = learn.data.valid_dl\n", "        self.data.valid_dl = None\n", "    \n", "    def on_train_begin(self, **kwargs:Any)->None:\n", "        \"init optimizer and learn params\"\n", "        # checkpoint the weights so `on_train_end` can undo the exploration\n", "        self.learn.save('tmp')\n", "        self.opt = self.learn.opt\n", "        self.opt.lr = self.sched.start\n", "        self.stop,self.best_loss = False,0.\n", "    \n", "    def on_batch_end(self, iteration:int, smooth_loss:TensorOrNumber, **kwargs:Any)->None:\n", "        \"Determine if loss has runaway and we should stop\"\n", "        if iteration==0 or smooth_loss < self.best_loss: self.best_loss = smooth_loss\n", "        self.opt.lr = self.sched.step()\n", "        if self.sched.is_done or smooth_loss > 4*self.best_loss:\n", "            #We use the smoothed loss to decide on the stopping since it's less shaky.\n", "            self.stop=True\n", "            return True\n", "    \n", "    def on_epoch_end(self, **kwargs:Any)->bool:\n", "        \"Tell Learner if we need to stop\"\n", "        return self.stop\n", "    \n", "    def on_train_end(self, **kwargs:Any)->None:\n", "        \"Cleanup learn model weights disturbed during LRFind exploration\"\n", "        # restore the valid_dl we turned off in `__init__`\n", "        self.data.valid_dl = self.valid_dl\n", "        self.learn.load('tmp')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def lr_find(learn:Learner, start_lr:float=1e-5, end_lr:float=10, num_it:int=100, **kwargs:Any):\n", "    \"Explore lr from `start_lr` to `end_lr` over `num_it` iterations of `learn`\"\n", "    cb = LRFinder(learn, start_lr, end_lr, num_it)\n", "    # enough epochs to cover `num_it` batches; LRFinder stops the run once its schedule is done\n", "    a = int(np.ceil(num_it/len(learn.data.train_dl)))\n", "    learn.fit(a, start_lr, callbacks=[cb], **kwargs)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)\n", "learn = Learner(data, model, metrics=metrics)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr_find(learn)\n", "learn.recorder.plot()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(2, 5e-3, callbacks=[OneCycleScheduler(learn, 0.1)])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.recorder.plot_losses()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.recorder.plot_metrics()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Show graph" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ShowGraph(LearnerCallback):\n", "    \"Updates a graph of learner stats and metrics after each epoch\"\n", "    def on_epoch_end(self, n_epochs:int, last_metrics:MetricsList, **kwargs)->bool:\n", "        \"If we have metrics plot them in our pbar graph\"\n", "        if last_metrics is not None:\n", "            # use the learner this callback is attached to, not a global `learn`\n", "            rec = self.learn.recorder\n", "            iters = list(range(len(rec.losses)))\n", "            val_iter = np.array(rec.nb_batches).cumsum()\n", "            # extrapolate the total iteration count assuming remaining epochs match the last one's batch count\n", "            x_bounds = (0, (n_epochs - len(rec.nb_batches)) * rec.nb_batches[-1] + len(rec.losses))\n", "            y_bounds = (0, max((max(tensor(rec.losses)), max(tensor(rec.val_losses)))))\n", "            rec.pbar.update_graph([(iters, rec.losses), (val_iter, rec.val_losses)], x_bounds, y_bounds)\n", "        return False" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)\n", "learn = Learner(data, model, metrics=metrics, callback_fns=ShowGraph)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(3, 5e-3)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Eye of Sauron" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "To grasp the potential of callbacks, here's a full example:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class EyeOfSauron(Callback):\n", "\n", "    def __init__(self, learn)->None:\n", "        #By passing the learner, this callback will have access to everything:\n", "        #All the inputs/outputs as they go, the losses, but also the data loaders, the optimizer.\n", "        self.learn = learn\n", "\n", "    #At any time:\n", "    #Changing self.learn.data.train_dl or self.learn.data.valid_dl will change them inside the fit function\n", "    #(we just need to pass the data object to the fit function and not data.train_dl/data.valid_dl)\n", "    #Changing self.learn.opt.opt (we have an OptimWrapper on top of the actual optimizer) will change it\n", "    #inside the fit function.\n", "    #Changing self.learn.data or self.learn.opt directly WILL NOT change the data or the optimizer inside the fit function.\n", "\n", "    #In any of the callbacks you can unpack in the kwargs:\n", "    #- n_epochs, contains the number of epochs the training will take in total\n", "    #- epoch, contains the number of the current epoch\n", "    #- iteration, contains the number of iterations done since the beginning of training\n", "    #- num_batch, contains the number of the batch we're at in the dataloader\n", "    #- last_input, contains the last input that got through the model (possibly updated by a callback)\n", "    #- last_target, contains the last target that got through the model (possibly updated by a callback)\n", "    #- last_output, contains the last output produced by the model (possibly updated by a callback)\n", "    #- last_loss, contains the last loss computed (possibly updated by a callback)\n", "    #- smooth_loss, contains the smoothed version of the loss\n", "    #- last_metrics, contains the last validation loss and metrics computed\n", "    #- pbar, the progress bar\n", "\n", "    def on_train_begin(self, **kwargs)->None:\n", "        #Here we can initialize anything we need.\n", "        self.opt = self.learn.opt\n", "        #The optimizer has now been initialized. We can change any hyper-parameters by typing\n", "        #self.opt.lr = new_lr, self.opt.mom = new_mom, self.opt.wd = new_wd or self.opt.beta = new_beta\n", "\n", "    def on_epoch_begin(self, **kwargs)->None: pass\n", "    #This is not technically useful since we have on_train_begin for epoch 0 and on_epoch_end for all the other epochs\n", "    #yet it makes writing code that needs to be done at the beginning of every epoch easy and more readable.\n", "\n", "    def on_batch_begin(self, **kwargs)->None: pass\n", "    #Here is the perfect place to prepare everything before the model is called.\n", "    #Example: change the values of the hyperparameters (if we don't do it on_batch_end instead)\n", "\n", "    #If we return something, that will be the new value for xb,yb.\n", "\n", "    def on_loss_begin(self, **kwargs)->None: pass\n", "    #Here is the place to run some code that needs to be executed after the output has been computed but before the\n", "    #loss computation.\n", "    #Example: putting the output back in FP32 when training in mixed precision.\n", "\n", "    #If we return something, that will be the new value for the output.\n", "\n", "    def on_backward_begin(self, **kwargs)->None: pass\n", "    #Here is the place to run some code that needs to be executed after the loss has been computed but before the\n", "    #gradient computation.\n", "    #Example: reg_fn in RNNs.\n", "\n", "    #If we return something, that will be the new value for loss. Since the recorder is always called first,\n", "    #it will have the raw loss.\n", "\n", "    def on_backward_end(self, **kwargs)->None: pass\n", "    #Here is the place to run some code that needs to be executed after the gradients have been computed but\n", "    #before the optimizer is called.\n", "    #Example: deal with weight_decay in AdamW\n", "\n", "    def on_step_end(self, **kwargs)->None: pass\n", "    #Here is the place to run some code that needs to be executed after the optimizer step but before the gradients\n", "    #are zeroed\n", "    #Example: can't think of any that couldn't be done in on_batch_end but maybe someone will need this one day.\n", "\n", "    def on_batch_end(self, **kwargs)->None: pass\n", "    #Here is the place to run some code that needs to be executed after a batch is fully done.\n", "    #Example: change the values of the hyperparameters (if we don't do it on_batch_begin instead)\n", "\n", "    #If we return true, the current epoch is interrupted (example: lr_finder stops the training when the loss explodes)\n", "\n", "    def on_epoch_end(self, **kwargs)->bool: return False\n", "    #Here is the place to run some code that needs to be executed at the end of an epoch.\n", "    #Example: Save the model if we have a new best validation loss/metric.\n", "\n", "    #If we return true, the training stops (example: early stopping)\n", "\n", "    def on_train_end(self, **kwargs)->None: pass\n", "    #Here is the place to tidy everything. It's always executed even if there was an error during the training loop,\n", "    #and has an extra kwarg named exception to check if there was an exception or not.\n", "    #Examples: save log_files, load best model found during training" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The idea is that each piece of functionality is implemented entirely inside one callback, so that it's easy to read. This way, different techniques end up in different callbacks, each of which clearly states all the interventions that technique makes during training. For instance, the LRFinder callback above, on top of running the fit function with exponentially growing lrs, needs to handle some preparation and clean-up; all this code lives in the same callback so we know exactly what LRFinder is doing and where to look if we need to change something." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Tests for change of optimizers/dataloaders" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Directly changing `opt.opt` or `data.train_dl`/`data.valid_dl` changes the corresponding object inside the fit function." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = DataBunch.create(train_ds, valid_ds, bs=bs, train_tfm=tfms, dl_tfms=cifar_norm)\n", "data1 = DataBunch.create(train_ds, valid_ds, bs=32, train_tfm=tfms, dl_tfms=cifar_norm)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CbTest():\n", "    def __init__(self, learn, new_data):\n", "        self.learn,self.new_data = learn,new_data\n", "    \n", "    def call_me(self):\n", "        self.learn.data.train_dl = self.new_data.train_dl\n", "        self.learn.data.valid_dl = self.new_data.valid_dl" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.data = data" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cb = CbTest(learn, data1)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test(data, cb):\n", "    x,y = next(iter(data.train_dl))\n", "    print(x.size())\n", "    cb.call_me()\n", "    x,y = next(iter(data.train_dl))\n", "    print(x.size())" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test(learn.data, cb)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.opt = OptimWrapper(optim.SGD(model.parameters(), 1e-2))" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CbTest():\n", "    \"Swap the optimizer wrapped inside `learn.opt` for `new_opt` when `call_me` runs\"\n", "    def __init__(self, learn, new_opt):\n", "        self.learn = learn\n", "        self.new_opt = new_opt\n", "\n", "    def call_me(self):\n", "        # mutate the wrapper in place so every holder of `learn.opt` sees the change\n", "        self.learn.opt.opt = self.new_opt" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cb = CbTest(learn, optim.Adam)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test(opt, cb):\n", "    \"Print the optimizer wrapped by `opt` before and after `cb.call_me()`\"\n", "    print(opt.opt)\n", "    cb.call_me()\n", "    print(opt.opt)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test(learn.opt,cb)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Changing directly opt or data doesn't change anything inside the fit function." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = DataBunch.create(train_ds, valid_ds, bs=bs, train_tfm=tfms, dl_tfms=cifar_norm)\n", "data1 = DataBunch.create(train_ds, valid_ds, bs=32, train_tfm=tfms, dl_tfms=cifar_norm)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CbTest():\n", "    \"Rebind `learn.data` to `new_data` when `call_me` runs\"\n", "    def __init__(self, learn, new_data):\n", "        self.learn = learn\n", "        self.new_data = new_data\n", "\n", "    def call_me(self):\n", "        # rebinding the attribute leaves any previously captured `data` object untouched\n", "        self.learn.data = self.new_data" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.data = data" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cb = CbTest(learn, data1)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test(data, cb):\n", "    \"Print the size of a training input batch from `data` before and after `cb.call_me()`\"\n", "    xb_before, _ = next(iter(data.train_dl))\n", "    print(xb_before.size())\n", "    cb.call_me()\n", "    xb_after, _ = next(iter(data.train_dl))\n", "    print(xb_after.size())" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test(learn.data, cb)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.opt = optim.SGD" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CbTest():\n", "    \"Rebind `learn.opt` to `new_opt` when `call_me` runs\"\n", "    def __init__(self, learn, new_opt):\n", "        self.learn = learn\n", "        self.new_opt = new_opt\n", "\n", "    def call_me(self):\n", "        self.learn.opt = self.new_opt" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cb = CbTest(learn, optim.Adam)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test(opt, cb):\n", "    \"Print `opt` before and after `cb.call_me()`\"\n", "    print(opt)\n", "    cb.call_me()\n", "    print(opt)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test(learn.opt,cb)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Fin" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }