{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nb_004a import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ipywidgets import IntProgress, HBox, HTML, VBox\n", "from time import time" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def format_time(t):\n", " t = int(t)\n", " h,m,s = t//3600, (t//60)%60, t%60\n", " if h!= 0: return f'{h}:{m:02d}:{s:02d}'\n", " else: return f'{m:02d}:{s:02d}'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will attempt to build a progress bar that fits our stylistic and informational needs. To build a progress bar in Jupyter Notebook we are going to leverage the ipywidgets package that contains widgets designed to either build interactive features or better display information. In our case the objective will be the latter. For more information on the specific functions we will use, you can find them here: [IntProgress](https://ipywidgets.readthedocs.io/en/stable/examples/Widget%20List.html#IntProgress), [HBox](https://ipywidgets.readthedocs.io/en/stable/examples/Widget%20List.html#HBox), [HTML](https://ipywidgets.readthedocs.io/en/stable/examples/Widget%20List.html#HTML), [VBox](https://ipywidgets.readthedocs.io/en/stable/examples/Widget%20List.html#VBox). \n", "\n", "We will use a HBox (horizontal box) where we will include the IntProgress widget (to display progress in a progress bar) and HTML widget (to write the time and batches information beside the progress bar). We will finally use VBox to fit each HBox, one below the other in a vertical box." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ProgressBar():\n", " update_every = 0.2\n", " \n", " def __init__(self, gen, display=True, leave=True, parent=None):\n", " self._gen,self.total = gen,len(gen)\n", " if parent is None: self.leave,self.display = leave,display\n", " else:\n", " self.leave,self.display=False,False\n", " parent.add_child(self)\n", " self.comment = ''\n", " \n", " def on_iter_begin(self): pass\n", " def on_interrupt(self): pass\n", " def on_iter_end(self): pass\n", " def on_update(self, val, text): pass\n", " \n", " def __iter__(self):\n", " self.on_iter_begin()\n", " self.update(0)\n", " try:\n", " for i,o in enumerate(self._gen):\n", " yield o\n", " self.update(i+1)\n", " except: self.on_interrupt()\n", " self.on_iter_end()\n", " \n", " def update(self, val):\n", " if val == 0:\n", " self.start_t = self.last_t = time()\n", " self.pred_t = 0\n", " self.last_v,self.wait_for = 0,1\n", " self.update_bar(0)\n", " elif val >= self.last_v + self.wait_for:\n", " cur_t = time()\n", " avg_t = (cur_t - self.start_t) / val\n", " self.wait_for = max(int(self.update_every / avg_t),1)\n", " self.pred_t = avg_t * self.total\n", " self.last_v,self.last_t = val,cur_t\n", " self.update_bar(val)\n", " \n", " def update_bar(self, val):\n", " elapsed_t = self.last_t - self.start_t\n", " remaining_t = format_time(self.pred_t - elapsed_t)\n", " elapsed_t = format_time(elapsed_t)\n", " end = '' if len(self.comment) == 0 else f' {self.comment}'\n", " self.on_update(val, f'{100 * val/self.total:.2f}% [{val}/{self.total} {elapsed_t}<{remaining_t}{end}]')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class NBProgressBar(ProgressBar):\n", " \n", " def __init__(self,gen, display=True, leave=True, parent=None):\n", " self.progress,self.text = IntProgress(min=0, max=len(gen)), HTML()\n", " self.box = HBox([self.progress, self.text])\n", " super().__init__(gen, display, leave, parent)\n", " \n", " def on_iter_begin(self): \n", " if self.display: display(self.box)\n", " def on_interrupt(self): self.progress.bar_style = 'danger'\n", " def on_iter_end(self):\n", " if not self.leave: self.box.close()\n", " \n", " def on_update(self, val, text):\n", " self.text.value = text\n", " self.progress.value = val" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ConsoleProgressBar(ProgressBar):\n", " length:int=50\n", " fill:str='█'\n", " \n", " def __init__(self,gen, display=True, leave=True, parent=None):\n", " self.max_len,self.prefix = 0,''\n", " super().__init__(gen, display, leave, parent)\n", " \n", " def on_iter_end(self):\n", " if not self.leave: \n", " print(f'\\r{self.prefix}' + ' ' * (self.max_len - len(f'\\r{self.prefix}')), end = '\\r')\n", " \n", " def on_update(self, val, text):\n", " if self.display:\n", " filled_len = int(self.length * val // self.total)\n", " bar = self.fill * filled_len + '-' * (self.length - filled_len)\n", " to_write = f'\\r{self.prefix} |{bar}| {text}'\n", " if len(to_write) > self.max_len: self.max_len=len(to_write)\n", " print(to_write, end = '\\r') " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MasterBar():\n", " def __init__(self, gen, cls):\n", " self.first_bar = cls(gen, display=False)\n", " \n", " def on_iter_begin(self): pass\n", " \n", " def __iter__(self):\n", " self.on_iter_begin()\n", " for o in self.first_bar:\n", " yield o\n", " \n", " def add_child(self, child): pass\n", " def write(self, line): pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class NBMasterBar(MasterBar):\n", " def __init__(self, gen):\n", " super().__init__(gen, NBProgressBar)\n", " self.text = HTML()\n", " self.vbox = VBox([self.first_bar.box, self.text])\n", " \n", " def on_iter_begin(self): display(self.vbox)\n", " \n", " def add_child(self, child):\n", " self.child = child\n", " self.vbox.children = [self.first_bar.box, self.text, child.box]\n", " \n", " def write(self, line):\n", " self.text.value += line + '

'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ConsoleMasterBar(MasterBar):\n", " def __init__(self, gen):\n", " super().__init__(gen, ConsoleProgressBar)\n", " \n", " def add_child(self, child):\n", " self.child = child\n", " self.child.prefix = f'Epoch {self.first_bar.last_v+1}/{self.first_bar.total} :'\n", " self.child.display = True\n", " \n", " def write(self, line):\n", " print(line)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are defining two types of bar classes: notebook bar and console bar. This is because, if we do not use Jupyter Notebook and instead we use a console bar, the widgets we integrated into the code will not work. Thus, you will see that when building the Console bar, we need to find alternatives to the notebook widgets. For example cannot use IntProgress to build our progress bar and instead we need to literally define the character that will build our bar: '█'.\n", "\n", "Note also that when training we will need a master bar that records our progress during all the training phase and we will also need intermediate progress bars that show the progress with each epoch.\n", "\n", "For each of these two situations we will use classes as building blocks, leveraging the pythonic concept of [inheritance](https://www.python-course.eu/python3_inheritance.php). Notice for example, that both the console and notebook bars will need a progress bar class (i.e. we can use it as a base class for both). Also, the master bar is a specific case of a progress bar (i.e. we can use the NBProgressBar and ConsoleProgressBar as a base class for the corresponding derived master bar class and also use these classes alone when we need to show the progress of each epoch).\n", "\n", "The examples below effectively illustrate both concepts.\n", "\n", "It is important to point out that the way the classes are built allows for the contents of the generator function that is passed as input to be used in the same loop. In this example this concept is represented by the use of 'j' and 'i' in the loop but during training the contents of our generator will be our training batches' data. This functionality is implemented by leveraging the ['yield'](https://stackoverflow.com/questions/231767/what-does-the-yield-keyword-do) keyword in Python.\n", "\n", "Finally, notice that in ProgressBar we are a establishing a minimum time elapsed before we update any bar. This is important since we are interested in seeing progress frequently but not more frequently that we can observe, since printing out slows down the training. Our choice of update_every is in this case 0.2 (seconds)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from time import sleep\n", "mb = NBMasterBar(range(5))\n", "\n", "for j in mb:\n", " for i in NBProgressBar(range(0, 50), parent=mb):\n", " sleep(0.01)\n", " mb.child.comment = str(i)\n", " mb.write(f'Epoch {j+1}: accuracy: {7.5*j+5}%')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from time import sleep\n", "mb = ConsoleMasterBar(range(5))\n", "\n", "for j in mb:\n", " for i in ConsoleProgressBar(range(0, 50), parent=mb):\n", " sleep(0.01)\n", " mb.child.comment = str(i)\n", " mb.write(f'Epoch {j+1}: accuracy: {7.5*j+5}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Inspiration" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def print_progress(iteration:int, total:int, prefix:str='', suffix:str='', decimals:int=1, length:int=50, fill:str='█'):\n", " \"Call in a loop to create terminal progress bar\"\n", " iteration += 1\n", " percent = (\"{0:.\" + str(decimals) + \"f}\").format(100 * (iteration / float(total)))\n", " filled_len = int(length * iteration // total)\n", " bar = fill * filled_len + '-' * (length - filled_len)\n", " print('\\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\\r')\n", " if iteration == total: print()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for i in range(0, 50):\n", " sleep(0.01)\n", " print_progress(i, 50, suffix=f'Epoch {j+1}: accuracy: {7.5*j+5}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test with training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we are going to use the functions we built to actually train a model and see them in action. For this we have to edit several functions we defined earlier.\n", "\n", "In particular, we need to include the progress bar in the training loop and add some printing (loss before each backward pass and epoch, loss and metrics after each epoch is finished)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class DeviceDataLoader():\n", " dl: DataLoader\n", " device: torch.device\n", " half: bool = False\n", " \n", " def __len__(self): return len(self.dl)\n", " def __iter__(self):\n", " self.gen = (to_device(self.device,o) for o in self.dl)\n", " if self.half: self.gen = (to_half(o) for o in self.gen)\n", " return iter(self.gen)\n", "\n", " @classmethod\n", " def create(cls, *args, device=default_device, **kwargs):\n", " return cls(DataLoader(*args, **kwargs), device=device, half=False)\n", "\n", "nb_002b.DeviceDataLoader = DeviceDataLoader" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def fit(epochs, model, loss_fn, opt, data, callbacks=None, metrics=None, pbar=None):\n", " cb_handler = CallbackHandler(callbacks)\n", " cb_handler.on_train_begin()\n", " if pbar is None: pbar = NBMasterBar(range(epochs))\n", " \n", " for epoch in pbar:\n", " model.train()\n", " cb_handler.on_epoch_begin()\n", " \n", " for xb,yb in NBProgressBar(data.train_dl, parent=pbar):\n", " xb, yb = cb_handler.on_batch_begin(xb, yb)\n", " loss,_ = loss_batch(model, xb, yb, loss_fn, opt, cb_handler)\n", " if cb_handler.on_batch_end(loss): break\n", " \n", " if hasattr(data,'valid_dl') and data.valid_dl is not None:\n", " model.eval()\n", " with torch.no_grad():\n", " *val_metrics,nums = zip(*[loss_batch(model, xb, yb, loss_fn, cb_handler=cb_handler, metrics=metrics)\n", " for xb,yb in NBProgressBar(data.valid_dl, parent=pbar)])\n", " val_metrics = [np.sum(np.multiply(val,nums)) / np.sum(nums) for val in val_metrics]\n", " \n", " else: val_metrics=None\n", " if cb_handler.on_epoch_end(val_metrics): break\n", " \n", " cb_handler.on_train_end()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class Learner():\n", " data: DataBunch\n", " model: nn.Module\n", " opt_fn: Callable = optim.SGD\n", " loss_fn: Callable = F.cross_entropy\n", " metrics: Collection[Callable] = None\n", " true_wd: bool = False\n", " def __post_init__(self): self.model = self.model.to(self.data.device)\n", "\n", " def fit(self, epochs, lr, wd=0., callbacks=None):\n", " self.opt = OptimWrapper(self.opt_fn(self.model.parameters(), lr), wd=wd, true_wd=self.true_wd)\n", " pbar = NBMasterBar(range(epochs))\n", " self.recorder = Recorder(self.opt, self.data.train_dl, pbar)\n", " if callbacks is None: callbacks = []\n", " callbacks = [self.recorder]+callbacks\n", " fit(epochs, self.model, self.loss_fn, self.opt, self.data, callbacks=callbacks, metrics=self.metrics, pbar=pbar)\n", " \n", " def lr_find(self, start_lr=1e-5, end_lr=10, num_it=100):\n", " cb = LRFinder(self, start_lr, end_lr, num_it)\n", " a = int(np.ceil(num_it/len(self.data.train_dl)))\n", " self.fit(a, start_lr, callbacks=[cb])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class Recorder(Callback):\n", " opt: torch.optim\n", " train_dl: DeviceDataLoader = None\n", " pbar: MasterBar = None\n", "\n", " def on_train_begin(self, **kwargs):\n", " self.losses,self.val_losses,self.lrs,self.moms,self.metrics,self.nb_batches = [],[],[],[],[],[]\n", " \n", " def on_batch_begin(self, **kwargs):\n", " self.lrs.append(self.opt.lr)\n", " self.moms.append(self.opt.mom)\n", " \n", " def on_backward_begin(self, smooth_loss, **kwargs):\n", " #We record the loss here before any other callback has a chance to modify it.\n", " self.losses.append(smooth_loss)\n", " if self.pbar is not None and hasattr(self.pbar,'child'): \n", " self.pbar.child.comment = f'{smooth_loss:.4f}'\n", " \n", " def on_epoch_end(self, epoch, num_batch, smooth_loss, last_metrics, **kwargs):\n", " self.nb_batches.append(num_batch)\n", " if last_metrics is not None:\n", " self.val_losses.append(last_metrics[0])\n", " if len(last_metrics) > 1: self.metrics.append(last_metrics[1:])\n", " self.pbar.write(f'{epoch}, {smooth_loss}, {last_metrics}')\n", " else: self.pbar.write(f'{epoch}, {smooth_loss}')\n", " \n", " def plot_lr(self, show_moms=False):\n", " iterations = list(range(len(learn.recorder.lrs)))\n", " if show_moms:\n", " _, axs = plt.subplots(1,2, figsize=(12,4))\n", " axs[0].plot(iterations, self.lrs)\n", " axs[1].plot(iterations, self.moms)\n", " else: plt.plot(iterations, self.lrs)\n", " \n", " def plot(self, skip_start=10, skip_end=5):\n", " lrs = self.lrs[skip_start:-skip_end] if skip_end > 0 else self.lrs[skip_start:]\n", " losses = self.losses[skip_start:-skip_end] if skip_end > 0 else self.losses[skip_start:]\n", " _, ax = plt.subplots(1,1)\n", " ax.plot(lrs, losses)\n", " ax.set_xscale('log')\n", " \n", " def plot_losses(self):\n", " _, ax = plt.subplots(1,1)\n", " iterations = list(range(len(self.losses)))\n", " ax.plot(iterations, self.losses)\n", " val_iter = self.nb_batches\n", " val_iter = np.array(val_iter).cumsum()\n", " ax.plot(val_iter, self.val_losses)\n", " \n", " def plot_metrics(self):\n", " assert len(self.metrics) != 0, \"There is no metrics to plot.\"\n", " _, axes = plt.subplots(len(self.metrics[0]),1,figsize=(6, 4*len(self.metrics[0])))\n", " val_iter = self.nb_batches\n", " val_iter = np.array(val_iter).cumsum()\n", " axes = axes.flatten() if len(self.metrics[0]) != 1 else [axes]\n", " for i, ax in enumerate(axes):\n", " values = [met[i] for met in self.metrics]\n", " ax.plot(val_iter, values)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "DATA_PATH = Path('data')\n", "PATH = DATA_PATH/'cifar10'\n", "\n", "data_mean,data_std = map(tensor, ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261]))\n", "cifar_norm = normalize_tfm(mean=data_mean,std=data_std)\n", "\n", "tfms = [flip_lr_tfm(p=0.5),\n", " pad_tfm(padding=4),\n", " crop_tfm(size=32, row_pct=(0,1.), col_pct=(0,1.))]\n", "\n", "bs = 64" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds = ImageDataset.from_folder(PATH/'train', classes=['airplane','dog'])\n", "valid_ds = ImageDataset.from_folder(PATH/'test', classes=['airplane','dog'])\n", "data = DataBunch.create(train_ds, valid_ds, bs=bs, train_tfm=tfms, valid_tfm=[], num_workers=4)\n", "len(data.train_dl), len(data.valid_dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)\n", "learn = Learner(data, model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(5,0.01)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## A bit more fancy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we are going to integrate useful graphs to the progress bar so that, when we finish training, we automatically show how our loss evolved during training. The graph will also be saved in the vertical box of our progress bar object for future reference. \n", "\n", "To implement this we need to add a function within the MasterBar base and derived classes that progressively updates the graph and add a Callback that progressively sends this function the necessary information." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MasterBar():\n", " def __init__(self, gen, cls):\n", " self.first_bar = cls(gen, display=False)\n", " \n", " def on_iter_begin(self): pass\n", " def on_iter_end(self): pass\n", " \n", " def __iter__(self):\n", " self.on_iter_begin()\n", " for o in self.first_bar:\n", " yield o\n", " self.on_iter_end()\n", " \n", " def add_child(self, child): pass\n", " def write(self, line): pass\n", " def update_graph(self, graphs, x_bounds, y_bounds): pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ipywidgets import widgets\n", "from IPython.display import clear_output\n", "from ipywidgets.widgets.interaction import show_inline_matplotlib_plots\n", "\n", "class NBMasterBar(MasterBar):\n", " names = ['train', 'valid']\n", " def __init__(self, gen, show_graph=True):\n", " super().__init__(gen, NBProgressBar)\n", " self.text = HTML()\n", " if show_graph:\n", " self.out = widgets.Output()\n", " self.vbox = VBox([self.first_bar.box, self.text, self.out])\n", " self.fig, self.ax = plt.subplots(1, figsize=(6,4))\n", " else: self.vbox = VBox([self.first_bar.box, self.text])\n", " self.show_graph = show_graph\n", " \n", " def on_iter_begin(self): display(self.vbox)\n", " def on_iter_end(self): \n", " if self.show_graph: self.fig.clear()\n", " \n", " def add_child(self, child):\n", " self.child = child\n", " self.vbox.children = [self.first_bar.box, self.text, child.box, self.out]\n", " \n", " def write(self, line):\n", " self.text.value += line + '

'\n", " \n", " def update_graph(self, graphs, x_bounds, y_bounds):\n", " if not self.show_graph: return\n", " self.out = widgets.Output()\n", " self.ax.clear()\n", " for g,n in zip(graphs,self.names): self.ax.plot(*g, label=n)\n", " self.ax.legend(loc='upper right')\n", " self.ax.set_xlim(*x_bounds)\n", " self.ax.set_ylim(*y_bounds)\n", " with self.out:\n", " clear_output(wait=True)\n", " display(self.ax.figure)\n", " self.vbox.children = [self.first_bar.box, self.text, self.out]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class Recorder(Callback):\n", " opt: torch.optim\n", " nb_epoch:int\n", " train_dl: DeviceDataLoader = None\n", " pbar: MasterBar = None\n", "\n", " def on_train_begin(self, **kwargs):\n", " self.losses,self.val_losses,self.lrs,self.moms,self.metrics,self.nb_batches = [],[],[],[],[],[]\n", " \n", " def on_batch_begin(self, **kwargs):\n", " self.lrs.append(self.opt.lr)\n", " self.moms.append(self.opt.mom)\n", " \n", " def on_backward_begin(self, smooth_loss, **kwargs):\n", " #We record the loss here before any other callback has a chance to modify it.\n", " self.losses.append(smooth_loss)\n", " if self.pbar is not None and hasattr(self.pbar,'child'): \n", " self.pbar.child.comment = f'{smooth_loss:.4f}'\n", " \n", " def on_epoch_end(self, epoch, num_batch, smooth_loss, last_metrics, **kwargs):\n", " self.nb_batches.append(num_batch)\n", " if last_metrics is not None:\n", " self.val_losses.append(last_metrics[0])\n", " if len(last_metrics) > 1: self.metrics.append(last_metrics[1:])\n", " self.pbar.write(f'{epoch}, {smooth_loss}, {last_metrics}')\n", " self.pbar.update_graph(*self.send_graphs())\n", " else: self.pbar.write(f'{epoch}, {smooth_loss}')\n", " \n", " def plot_lr(self, show_moms=False):\n", " iterations = list(range(len(learn.recorder.lrs)))\n", " if show_moms:\n", " _, axs = plt.subplots(1,2, figsize=(12,4))\n", " axs[0].plot(iterations, self.lrs)\n", " axs[1].plot(iterations, self.moms)\n", " else: plt.plot(iterations, self.lrs)\n", " \n", " def plot(self, skip_start=10, skip_end=5):\n", " lrs = self.lrs[skip_start:-skip_end] if skip_end > 0 else self.lrs[skip_start:]\n", " losses = self.losses[skip_start:-skip_end] if skip_end > 0 else self.losses[skip_start:]\n", " _, ax = plt.subplots(1,1)\n", " ax.plot(lrs, losses)\n", " ax.set_xscale('log')\n", " \n", " def plot_losses(self):\n", " _, ax = plt.subplots(1,1)\n", " iterations = list(range(len(self.losses)))\n", " ax.plot(iterations, self.losses)\n", " val_iter = self.nb_batches\n", " val_iter = np.array(val_iter).cumsum()\n", " ax.plot(val_iter, self.val_losses)\n", " \n", " def plot_metrics(self):\n", " assert len(self.metrics) != 0, \"There is no metrics to plot.\"\n", " _, axes = plt.subplots(len(self.metrics[0]),1,figsize=(6, 4*len(self.metrics[0])))\n", " val_iter = np.array(self.nb_batches).cumsum()\n", " axes = axes.flatten() if len(self.metrics[0]) != 1 else [axes]\n", " for i, ax in enumerate(axes):\n", " values = [met[i] for met in self.metrics]\n", " ax.plot(val_iter, values)\n", " \n", " def send_graphs(self):\n", " iters = list(range(len(self.losses))) + [None] * (self.nb_epoch * self.nb_batches[-1] - len(self.losses))\n", " losses = self.losses + [None] * (self.nb_epoch * self.nb_batches[-1] - len(self.losses))\n", " val_iter = np.array(self.nb_batches).cumsum()\n", " val_losses = self.val_losses + [None] * (self.nb_epoch - len(val_iter))\n", " val_iter = list(val_iter) + [None] * (self.nb_epoch - len(val_iter))\n", " x_bounds = (0, (self.nb_epoch - len(self.nb_batches)) * self.nb_batches[-1] + len(self.losses))\n", " y_bounds = (0, max((max(self.losses), max(self.val_losses))))\n", " return [(iters, losses), (val_iter, val_losses)], x_bounds, y_bounds" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class Learner():\n", " data: DataBunch\n", " model: nn.Module\n", " opt_fn: Callable = optim.SGD\n", " loss_fn: Callable = F.cross_entropy\n", " metrics: Collection[Callable] = None\n", " true_wd: bool = False\n", " def __post_init__(self): self.model = self.model.to(self.data.device)\n", "\n", " def fit(self, epochs, lr, wd=0., callbacks=None):\n", " self.opt = OptimWrapper(self.opt_fn(self.model.parameters(), lr), wd=wd, true_wd=self.true_wd)\n", " pbar = NBMasterBar(range(epochs))\n", " self.recorder = Recorder(self.opt, epochs, self.data.train_dl, pbar)\n", " if callbacks is None: callbacks = []\n", " callbacks = [self.recorder]+callbacks\n", " fit(epochs, self.model, self.loss_fn, self.opt, self.data, callbacks=callbacks, metrics=self.metrics, pbar=pbar)\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds = ImageDataset.from_folder(PATH/'train', classes=['airplane','dog'])\n", "valid_ds = ImageDataset.from_folder(PATH/'test', classes=['airplane','dog'])\n", "data = DataBunch.create(train_ds, valid_ds, bs=bs, train_tfm=tfms, valid_tfm=[], num_workers=4)\n", "len(data.train_dl), len(data.valid_dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Darknet([1, 2, 2, 2, 2], num_classes=2, nf=16)\n", "learn = Learner(data, model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(5,0.01)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.recorder.pbar.vbox" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }