{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Trial History and Replay\n", "\n", "This guide gives a brief overview of the history returned by a `Trial` from a call to `run`.\n", "\n", "**Note**: The easiest way to use this tutorial is as a Colab notebook, which allows you to dive in with no setup.\n", "\n", "## Install Torchbearer\n", "\n", "First, we install torchbearer if needed." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.4.0.dev\n" ] } ], "source": [ "try:\n", "    import torchbearer\n", "except ImportError:\n", "    !pip install -q torchbearer\n", "    import torchbearer\n", "\n", "print(torchbearer.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## A Simple Example\n", "\n", "We first create some data and a simple model to train on." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "from sklearn.datasets import make_blobs\n", "from torch.utils.data import TensorDataset, DataLoader\n", "\n", "\n", "class BasicModel(nn.Module):\n", "    def __init__(self):\n", "        super(BasicModel, self).__init__()\n", "        self.linear1 = nn.Linear(100, 25)\n", "        self.linear2 = nn.Linear(25, 1)\n", "\n", "    def forward(self, x):\n", "        x = self.linear1(x)\n", "        x = self.linear2(x)\n", "        # Squash to a probability and drop the trailing unit dimension\n", "        return torch.sigmoid(x).squeeze(1)\n", "\n", "\n", "# Two Gaussian blobs in 100 dimensions, standardized with the global mean and std\n", "X, Y = make_blobs(n_samples=2048, n_features=100, centers=2, cluster_std=10, random_state=1)\n", "X = (X - X.mean()) / X.std()\n", "Y[np.where(Y == 0)] = -1\n", "X, Y = torch.FloatTensor(X), torch.FloatTensor(Y)\n", "traingen = DataLoader(TensorDataset(X, Y), batch_size=128)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we'll run the model for a few epochs to obtain a history." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3cd46292b0534f65a7514b0ea78ba551", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=20), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "import torch.optim as optim\n", "import torch.nn.functional as F\n", "\n", "from torchbearer import Trial\n", "\n", "model = BasicModel()\n", "\n", "# Only pass trainable parameters to the optimizer\n", "optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n", "trial = Trial(model, optimizer=optimizer, criterion=F.binary_cross_entropy,\n", "              metrics=['mse', 'acc', 'loss'])\n", "trial.with_train_generator(traingen)\n", "history = trial.run(epochs=20, verbose=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The History\n", "\n", "The history is a list containing one dictionary of metrics for each epoch of training. Each dictionary also records the number of training and validation steps taken in that epoch. 
Let's take a look:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "20\n", "{'running_mse': 0.9208962321281433, 'running_binary_acc': 0.960156261920929, 'running_loss': -0.019676249474287033, 'mse': 0.8676411509513855, 'binary_acc': 0.96923828125, 'loss': -0.15272381901741028, 'train_steps': 16, 'validation_steps': None}\n" ] } ], "source": [ "print(len(history))\n", "print(history[5])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### With Pandas\n", "\n", "Suppose we want to use pandas to plot our training progress or analyze it further. We can load the history into a `DataFrame` as follows:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " binary_acc loss mse running_binary_acc running_loss \\\n", "0 0.896484 0.316458 1.067553 0.887784 0.334222 \n", "1 0.923340 0.225443 1.025509 0.904803 0.286757 \n", "2 0.939453 0.133693 0.984494 0.916061 0.240595 \n", "3 0.953613 0.040475 0.944467 0.933594 0.167838 \n", "4 0.963867 -0.054826 0.905476 0.949063 0.074993 \n", "5 0.969238 -0.152724 0.867641 0.960156 -0.019676 \n", "6 0.973633 -0.253658 0.831138 0.967812 -0.116728 \n", "7 0.978027 -0.358019 0.796175 0.972656 -0.216635 \n", "8 0.979004 -0.466167 0.762976 0.975781 -0.319814 \n", "9 0.980469 -0.578459 0.731756 0.978750 -0.426640 \n", "10 0.981445 -0.695270 0.702708 0.980313 -0.537480 \n", "11 0.984375 -0.817010 0.675986 0.982188 -0.652704 \n", "12 0.983887 -0.944141 0.651693 0.983594 -0.772715 \n", "13 0.983887 -1.077197 0.629881 0.984375 -0.897960 \n", "14 0.984863 -1.216788 0.610547 0.984531 -1.028948 \n", "15 0.985352 -1.363614 0.593634 0.985000 -1.166263 \n", "16 0.985352 -1.518468 0.579039 0.985469 -1.310573 \n", "17 0.985352 -1.682242 0.566622 0.985625 -1.462638 \n", "18 0.985352 -1.855926 0.556210 0.985625 -1.623313 \n", "19 0.984375 -2.040609 0.547611 0.985469 -1.793551 \n", "\n", " running_mse 
train_steps validation_steps \n", "0 1.077145 16 None \n", "1 1.054281 16 None \n", "2 1.033050 16 None \n", "3 1.000690 16 None \n", "4 0.960275 16 None \n", "5 0.920896 16 None \n", "6 0.882645 16 None \n", "7 0.845669 16 None \n", "8 0.810160 16 None \n", "9 0.776331 16 None \n", "10 0.744397 16 None \n", "11 0.714554 16 None \n", "12 0.686969 16 None \n", "13 0.661763 16 None \n", "14 0.639006 16 None \n", "15 0.618716 16 None \n", "16 0.600858 16 None \n", "17 0.585348 16 None \n", "18 0.572062 16 None \n", "19 0.560841 16 None \n" ] } ], "source": [ "import pandas as pd\n", "\n", "frame = pd.DataFrame.from_records(history)\n", "print(frame)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now use any of pandas' built-in functionality, such as plotting:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "frame.reset_index().plot('index', 'binary_acc')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Replay a Trial\n", "\n", "One of the benefits of the history is the ability to replay a trial: the recorded epochs are stepped through again without retraining the model. We'll look at two of the replay options here. First, we can replay the whole training process, this time with a different verbosity. With `verbose=2`, a progress bar is shown for every epoch rather than a single bar for the whole run." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6eef5ee55c334f119d8a60fd2af57be1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='0/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2008f381538142618cff8fc050eb307f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='1/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "142a32f36308426e9fe59ad5e2092701", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='2/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c8c7783d52ff45dca7874f4f728d0094", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='3/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", 
"output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "15261a70fcb8445d84aae00e59097b9b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='4/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "31d57a94b04d47e08bf29c14dada538a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='5/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1d2d2e2eba8542faba70141ba4302ba4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='6/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9aa166f3fe2940c3b652169bb00020d7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='7/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "73436eeb6a6a46bd8086400d87f93624", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='8/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"fee3517b93b1460bbab3997cc07f747d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='9/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7575c42b2b484e369b6d7cf58e8b7024", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='10/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "221c36be827a405b91b62e7d6cc748cc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='11/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b7869ba8ae444650a8ca9856610f3a75", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='12/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2ae4218c4a8c4ad3b4e83a891cdb798d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='13/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1974270d05414437b5345804820abcda", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, 
description='14/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7a7899659a7d46beaa9a918a2886f3c6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='15/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "211dc56c1dae4c57af5bdca3393dc7c4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='16/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8c92266f9d924abebdc66383f90b3e28", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='17/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fcda5e5aca274e5c8b8d17aa526c9a49", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='18/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "efbd88786f9649108315bb275f1e21ea", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='19/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": 
"stream", "text": [ "\n" ] } ], "source": [ "_ = trial.replay(verbose=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This may be more output than we want, so we can instead set the `one_batch` flag to simulate just one batch per epoch." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "af55bc1e2eba480f8c288d9d17d20e67", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='0/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c96ab8cd8e7e41babd0b326ec8f13d5a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='1/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "15e38957d69e4218947f7e64bc3a5860", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='2/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4fea98cef5ea40ccb04fa47ee0dc5cd5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='3/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "13f62e9db6d942e9995f903ab796c72d", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"HBox(children=(IntProgress(value=0, description='4/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8dfb537c62a749f58b075ba5e42861ac", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='5/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fd9c4784b6574fec8c549838e768a28e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='6/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "48626c90db0843c7b59164db6d008cf7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='7/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e2ca960b335344528dd3ae61aa823bc4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='8/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5b2dbd82198f4df78a3538a85d1a3205", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='9/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": 
"stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e47dd50497714db187edf8996de1515e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='10/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "29211e3e9ac8484a958dd85a9e81a671", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='11/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e4cd52e99a1f4451ae13cdef4a007c27", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='12/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f7c4c301a0c74de3bd1103bd0648054b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='13/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c6102be2e3194d94a8926137284ab1a6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='14/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"9ea8b0d4690f48468e07aa2ba8917dc4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='15/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "da98b8d23daf461197613f8ce547dcc7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='16/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c291f91c91ed465ab55fec1ebcf07ea1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='17/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8a37c3f0a2b14538a6e5be77b61e9813", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='18/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8e0fbef3364446c7ac4cb68971f3abfc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='19/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "_ = trial.replay(verbose=2, one_batch=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So that's history and replaying in `torchbearer`. 
Be sure to have a look at our other examples at [pytorchbearer.org](http://www.pytorchbearer.org/)." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }