{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Trial History and Replay\n", "\n", "This guide gives a brief overview of the history returned by a `Trial` from a call to `run` or `evaluate`.\n", "\n", "**Note**: The easiest way to use this tutorial is as a Colab notebook, which allows you to dive in with no setup.\n", "\n", "## Install Torchbearer\n", "\n", "First we install torchbearer if needed." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.4.0.dev\n" ] } ], "source": [ "try:\n", "    import torchbearer\n", "except ImportError:\n", "    # torchbearer is not installed yet, so install it and retry the import\n", "    !pip install -q torchbearer\n", "    import torchbearer\n", "\n", "print(torchbearer.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## A Simple Example\n", "\n", "We first create some data and a simple model to train on." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "class BasicModel(nn.Module):\n", "    def __init__(self):\n", "        super(BasicModel, self).__init__()\n", "        self.linear1 = nn.Linear(100, 25)\n", "        self.linear2 = nn.Linear(25, 1)\n", "\n", "    def forward(self, x):\n", "        x = self.linear1(x)\n", "        x = self.linear2(x)\n", "        return torch.sigmoid(x).squeeze(1)\n", "\n", "from torch.utils.data import TensorDataset, DataLoader\n", "\n", "\n", "import numpy as np\n", "from sklearn.datasets import make_blobs\n", "\n", "X, Y = make_blobs(n_samples=2048, n_features=100, centers=2, cluster_std=10, random_state=1)\n", "X = (X - X.mean()) / X.std()\n", "Y[np.where(Y == 0)] = -1\n", "X, Y = torch.FloatTensor(X), torch.FloatTensor(Y)\n", "traingen = DataLoader(TensorDataset(X, Y), batch_size=128)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we'll run the model for a few epochs to obtain a history." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3cd46292b0534f65a7514b0ea78ba551", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=20), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "import torch.optim as optim\n", "import torch.nn.functional as F\n", "\n", "from torchbearer import Trial\n", "\n", "model = BasicModel()\n", "\n", "optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n", "trial = Trial(model, optimizer=optimizer, criterion=F.binary_cross_entropy,\n", " metrics=['mse', 'acc', 'loss'])\n", "trial.with_train_generator(traingen)\n", "history = trial.run(epochs=20, verbose=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The History\n", "\n", "The history is a list of metric dictionaries from each epoch of training. History also includes the number of training and validation steps from each epoch. 
Let's take a look." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "20\n", "{'running_mse': 0.9208962321281433, 'running_binary_acc': 0.960156261920929, 'running_loss': -0.019676249474287033, 'mse': 0.8676411509513855, 'binary_acc': 0.96923828125, 'loss': -0.15272381901741028, 'train_steps': 16, 'validation_steps': None}\n" ] } ], "source": [ "print(len(history))\n", "print(history[5])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### With Pandas\n", "\n", "To plot training progress or run similar analyses with pandas, we can load the history into a `DataFrame` as follows:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " binary_acc loss mse running_binary_acc running_loss \\\n", "0 0.896484 0.316458 1.067553 0.887784 0.334222 \n", "1 0.923340 0.225443 1.025509 0.904803 0.286757 \n", "2 0.939453 0.133693 0.984494 0.916061 0.240595 \n", "3 0.953613 0.040475 0.944467 0.933594 0.167838 \n", "4 0.963867 -0.054826 0.905476 0.949063 0.074993 \n", "5 0.969238 -0.152724 0.867641 0.960156 -0.019676 \n", "6 0.973633 -0.253658 0.831138 0.967812 -0.116728 \n", "7 0.978027 -0.358019 0.796175 0.972656 -0.216635 \n", "8 0.979004 -0.466167 0.762976 0.975781 -0.319814 \n", "9 0.980469 -0.578459 0.731756 0.978750 -0.426640 \n", "10 0.981445 -0.695270 0.702708 0.980313 -0.537480 \n", "11 0.984375 -0.817010 0.675986 0.982188 -0.652704 \n", "12 0.983887 -0.944141 0.651693 0.983594 -0.772715 \n", "13 0.983887 -1.077197 0.629881 0.984375 -0.897960 \n", "14 0.984863 -1.216788 0.610547 0.984531 -1.028948 \n", "15 0.985352 -1.363614 0.593634 0.985000 -1.166263 \n", "16 0.985352 -1.518468 0.579039 0.985469 -1.310573 \n", "17 0.985352 -1.682242 0.566622 0.985625 -1.462638 \n", "18 0.985352 -1.855926 0.556210 0.985625 -1.623313 \n", "19 0.984375 -2.040609 0.547611 0.985469 -1.793551 \n", "\n", " running_mse 
train_steps validation_steps \n", "0 1.077145 16 None \n", "1 1.054281 16 None \n", "2 1.033050 16 None \n", "3 1.000690 16 None \n", "4 0.960275 16 None \n", "5 0.920896 16 None \n", "6 0.882645 16 None \n", "7 0.845669 16 None \n", "8 0.810160 16 None \n", "9 0.776331 16 None \n", "10 0.744397 16 None \n", "11 0.714554 16 None \n", "12 0.686969 16 None \n", "13 0.661763 16 None \n", "14 0.639006 16 None \n", "15 0.618716 16 None \n", "16 0.600858 16 None \n", "17 0.585348 16 None \n", "18 0.572062 16 None \n", "19 0.560841 16 None \n" ] } ], "source": [ "import pandas as pd\n", "\n", "frame = pd.DataFrame.from_records(history)\n", "print(frame)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now use all of the built-in pandas functions, such as plotting" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEKCAYAAAD+XoUoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xt0FeW9//H3l1zIhSSQBBATEq4KqCgIiAKKtlq1FqvWqq0VtWg9rR49Lk9ra2up1mWPWmttrf15QbG13rDWS7UWqRegoqAIck+AACEIJIEECLk/vz/2QDchIRuSndnJfF5rZWX2zDN7fzPZ+WT2MzPPmHMOEREJhm5+FyAiIh1HoS8iEiAKfRGRAFHoi4gEiEJfRCRAFPoiIgGi0BcRCRCFvohIgCj0RUQCJN7vAprKzs52AwYM8LsMEZFO5ZNPPil1zvVurV3Mhf6AAQNYtGiR32WIiHQqZrYhknbq3hERCRCFvohIgCj0RUQCJOb69JtTV1dHcXEx1dXVfpfSZSUlJZGbm0tCQoLfpYhIFHWK0C8uLiYtLY0BAwZgZn6X0+U45ygrK6O4uJiBAwf6XY6IRFGn6N6prq4mKytLgR8lZkZWVpY+SYkEQKcIfUCBH2XaviLB0Cm6d0Ska2hodFTurWPn3joqvK+dVbVUetO19Y1+l9gmaUkJ9M9MIT8rhbzMFFK7x17Exl5FItJpOOfYtquGDWVVbCyvYvuuGi/Ma8NC/T8Bv6u6/pDP15k/cDZ3u/HsHt3Jy0wmPyuVvMzQP4L8rBTyslLo3aO7L5+wFfoRKioq4oILLmDZsmUHzJ82bRq33norI0aM8Kky6QqccyzetJMde2rpmZJARnIC6cmh793j43ytraa+geIde9lYXsXGsiov4PeEHpdXUV134N55QpyRkZxIRnI8GckJ9E1P4pi+aWR4P09GcsL+nzHWfta2qqiqY2N5FRvK94S2k/fP8O
P15fzts80H/GNITogL/SPISiHf+z60TxqnDs6Kao0K/TZ64okn2uV56uvriY/XryNoKvbW8cqnxTz70UYKtu1utk1yQtz+gExPTqBnM+GZnpxAz5REuscf+WE656BsT2ivfVN51f6995KKvc2GVX5WKpOG9t7flZGXmULf9CRSEuMCe4woIyWBE1IyOCE346BlLf3zLCrdwwdrtlNT38iovJ688v0JUa2x06XML15fzoqSynZ9zhFHp/Pzrx3Xarv6+nqmTp3K4sWLOeaYY3jmmWc4//zzeeCBBxgzZgw9evTg5ptv5o033iA5OZlXX32Vvn378vrrr/PLX/6S2tpasrKyePbZZ+nbty/Tp0+npKSEoqIisrOz2bRpE7/73e846aSTAJgwYQKPPvooI0eOPKiWjz/+mFtuuYW9e/eSnJzMU089xbHHHktDQwM/+tGPePvttzEzrrvuOm666SYWLlzIzTffzJ49e+jevTtz5swhLS2tXbejRMY5x9LiCp79aAOvLSmhuq6RE/v35L5vjOSYvmn7u0Iqqg7uItm5N7QnuW/e3rqGqNSY3SORvMwUxg3MDPVRh/VT907zp1uis+seH8fg3j0Y3LvHQcsaGx3bd9ewu+bQ3V/todOFvp9Wr17Nk08+yYQJE7j22mv5wx/+cMDyPXv2MH78eO655x5++MMf8vjjj/PTn/6UiRMnsmDBAsyMJ554gvvuu49f//rXAHzyySfMmzeP5ORkZs6cydNPP81DDz3EmjVrqKmpaTbwAYYNG8YHH3xAfHw877zzDj/5yU94+eWXeeyxx1i/fj2LFy8mPj6e8vJyamtrueyyy3jhhRcYO3YslZWVJCcnR317yYH21NTz2pISnv1oA8s2V5KSGMdFo3L59il5HJ9z8J5hJGrqG6jYWxc6OFpVR21D2w6E9kpJjNkDkF1Zt25G3/Qk+nbAa3W632wke+TR0r9/fyZMCH30uvLKK3n44YcPWJ6YmMgFF1wAwMknn8zs2bOB0MVll112GVu2bKG2tvaAC6CmTJmyP4AvvfRS7r77bu6//35mzJjB1Vdf3WItFRUVTJ06lYKCAsyMuro6AN555x1uuOGG/V1FmZmZfP755/Tr14+xY8cCkJ6e3g5bQyK16otK/vLRRl75dDO7auoZdlQad3/9eL5+0tGkJbXtCuju8XH0SYujT1pSO1UrXV2nC30/Nf1I2/RxQkLC/nlxcXHU14c+qt10003ceuutTJkyhffee4/p06fvXyc1NXX/dEpKCmeffTavvvoqL7744iGHmP7Zz37GmWeeySuvvEJRURGTJ08GQl0HTetqbp5EV3VdA28t28KzCzayaMMOEuO7ccEJ/fj2+DxG5/XS70N8o9A/DBs3buTDDz/k1FNP5bnnnmPixIm8/vrrra5XUVFBTk4OADNnzjxk22nTpvG1r32NSZMmkZmZGdFzPv300/vnn3POOfzxj39k8uTJ+7t3hg0bRklJCQsXLmTs2LHs2rWL5ORkHTiOgvWle/jLRxt46ZNidlbVMTA7lZ9+dTiXjM6lV2qi3+WJRBb6ZnYu8FsgDnjCOferJsvzgRlAb6AcuNI5V+wtuw/4KqGrf2cDNzvX3BmtsW/48OHMnDmT733vewwdOpT/+q//iij0p0+fzqWXXkpOTg7jx49n/fr1LbY9+eSTSU9P55prrjnkc/7whz9k6tSpPPjgg5x11ln750+bNo01a9YwcuRIEhISuO6667jxxht54YUXuOmmm/Yf+H3nnXfo0ePgA0rSuqbnpm8sC52+uK50D0uLK4jvZpxzXF+uPCWfUwdr+BCJLdZa/ppZHLAGOBsoBhYCVzjnVoS1eQl4wzk308zOAq5xzn3HzE4D7gdO95rOA37snHuvpdcbM2aMa9qtsXLlSoYPH364P1unVFJSwuTJk1m1ahXdunXsKBlB2s6t2X96XVkVG8r2sLF8Lxu9c6837Tjw3PRuBkf3TCYvM4XTBmfxzT
H96ZOuPnbpWGb2iXNuTGvtItnTHwcUOufWeU/8PHAhsCKszQjgf7zpd4G/edMOSAISAQMSgK2R/ABB9Mwzz3DHHXfw4IMPdnjgB1VdQyNrtu7i8+IKlm6uYN323Wwsq2JLZfVB56bnZ6UwIDuVM44JnZve3ztXPadnMoltOD9epCNFEvo5wKawx8XAKU3aLAEuIdQFdBGQZmZZzrkPzexdYAuh0P+9c25l28vumq666iquuuqqA+Y99dRT/Pa3vz1g3oQJE3jkkUc6srQuoaHRUbhtN0uLd/L55gqWFlewYkvl/vFe0pLiOaZvGqcMytp/ufy+cPfrknmR9hZJ6Df3Tm/aJ3Qb8Hszuxr4ANgM1JvZEGA4kOu1m21mpzvnPjjgBcyuB64HyMvLa7aIoJ6Bcs0117Tav98eOulhlhY1NjrWl+0J7cEXV/D55p0s21y5/2Km1MQ4js/JYOqp+ZyQ25ORORnkZ6UE8j0mwRJJ6BcD/cMe5wIl4Q2ccyXAxQBm1gO4xDlX4YX5Aufcbm/ZW8B4Qv8Ywtd/DHgMQn36TQtISkqirKxMY+pHyb6bqCQldb5+aOcc5XtqQ+OdlFWxYkslS4tDAb/v6sakhG4cf3QGl4/rz8jcDE7I6cmg7FS6ddN7SYInktBfCAw1s4GE9uAvB74V3sDMsoFy51wj8GNCZ/IAbASuM7N7CX1iOAN46HCLzM3Npbi4mO3btx/uqhKhfbdLjEX1DY2U7Kxmw75Bvsr+My7MxvKqAy5dT4zvxoh+6Vw8OocTcjIYmduTwb1TiY9Tn7sIRBD6zrl6M7sReJvQKZsznHPLzewuYJFz7jVgMnCvmTlCe/E/8FafBZwFfE6oS+gfzrnWz3FsIiEhQbfxC4BN5VUsL6nYv9e+L9Q379hLfeN/PgAmxnejf6/QcLXjBmb+Z7ha78CqDqqKtKzVUzY7WnOnbErXtrumnt/MXsPT/y6iwQv3nikJ5Gem7L8hRX5mKnlesB+VnqSuGZEm2vOUTZGocM7x5udfcNcby9m2q4bLx+bxrXF55GWlkJHctjFpRKR5Cn3xxfrSPdz56jLmFpQyol86j155MqPzevldlkiXp9CXDlVd18Af3lvLH99bS/f4bkz/2giuHJ+vA60iHUShLx3m3dXbmP7acjaUVTHlxKP56VeHa7gCkQ6m0JeoK9m5l7vfWMFby75gUO9Unp12ChOGZPtdlkggKfQlauoaGnlq/noeeqeAhkbHbeccw3WnD+r0N78W6cwU+hIVC4vK+ekry1i9dRdnDevDL6YcR//MFL/LEgk8hb60q7LdNdz71ipmfVJMTs9kHvvOyZw9oq+GzxCJEQp9aReNjY7nFm7kvn+sZk9NPTecMZj//tIQUhL1FhOJJfqLlDZbtrmCO/62jCWbdnLKwEx++fXjGdo3ze+yRKQZCn05YpXVdTz4zzU882ERmamJ/OayE/n6STnqyhGJYQp9OWzOOV5bUsIv/76S0t01XHlKPrd95VgNnSDSCSj05bAUbtvNna8u499ryxiZm8GTU8cwMren32WJSIQU+hKRvbUN/P7dAh77YB1JCXHc/fXj+da4POI02qVIp6LQl1bNWbmVn7+2nOIde7l4VA4/Pn84vdO6+12WiBwBhb60qHhHFb94fQWzV2xlaJ8ePH/9eMYPyvK7LBFpA4W+HKS2vpEn563n4TkFANx+3jCunTBQd6QS6QIU+nKAD9eW8bNXl1G4bTdfOa4vd37tOHJ6Jvtdloi0E4W+ALCruo7pr63g5U+L6Z+ZzIyrx3DWsL5+lyUi7UyhLyzbXMEP/vIpxTv2cuOZQ/jBmUNITtRImCJdkUI/wJxzPPPhBu75+0qyeiTy/PXjGTsg0++yRCSKFPoBVbG3jttfXspby77grGF9eODSE8lMTfS7LBGJMoV+AC3ZtJMbn/uULTur+cn5w5g2cRDddJGVSCAo9APEOceM+UX86q2V9ElL4sUbTmV0Xi+/yxKRDq
TQD4idVbX876ylzF6xlS8P78sDl46kZ4q6c0SCRqEfAJ9u3MFNf1nMtl3V/OyCEVw7YYCGPxYJKIV+F9bY6Hh87jruf3s1/XomMeuG0zixv0bEFAkyhX4XVb6nltteWsK/Vm3jvOOP4leXjNR49yKi0O+KFhaV89/PLaZsdy13XXgc3xmfr+4cEQEU+l1KY6Pj0ffX8uDsNeT2Suav3z+N43My/C5LRGKIQr+LqGto5Ht/+oR/rdrGV0f241cXn0BakrpzRORACv0u4vf/KuRfq7Zx5wUjuEZn54hICxT6XcAnG3bw+3cLuXh0DtdOHOh3OSISw3RXjE5ud009t774Gf0ykvjFlOP8LkdEYpz29Du5u19fwabyKp6//lT14YtIq7Sn34m9vfwLXli0iRvOGMy4gRoSWURaF1Hom9m5ZrbazArN7PZmlueb2RwzW2pm75lZbtiyPDP7p5mtNLMVZjag/coPrm2V1dz+8lKOz0nnli8f43c5ItJJtBr6ZhYHPAKcB4wArjCzEU2aPQA845wbCdwF3Bu27BngfufccGAcsK09Cg8y5xz/O2spe+saeOiyUbphuYhELJK0GAcUOufWOedqgeeBC5u0GQHM8abf3bfc++cQ75ybDeCc2+2cq2qXygPsTws28P6a7dxx/nCG9Onhdzki0olEEvo5wKawx8XevHBLgEu86YuANDPLAo4BdprZX81ssZnd731ykCNUuG0X9/x9JZOP7c2V4/P9LkdEOplIQr+5q3xck8e3AWeY2WLgDGAzUE/o7KBJ3vKxwCDg6oNewOx6M1tkZou2b98eefUBU1vfyC0vfEZq93ju+8ZIXYAlIoctktAvBvqHPc4FSsIbOOdKnHMXO+dGAXd48yq8dRd7XUP1wN+A0U1fwDn3mHNujHNuTO/evY/wR+n6HnpnDcs2V3LvxSfQJy3J73JEpBOKJPQXAkPNbKCZJQKXA6+FNzCzbDPb91w/BmaErdvLzPYl+VnAiraXHTwfry/n0ffXcvnY/nzluKP8LkdEOqlWQ9/bQ78ReBtYCbzonFtuZneZ2RSv2WRgtZmtAfoC93jrNhDq2pljZp8T6ip6vN1/ii6usrqO/3nhM/IyU/jZBU1PnBIRiVxEV+Q6594E3mwy786w6VnArBbWnQ2MbEONgTf91eV8UVnNSzecSmp3XUQtIkdOJ3jHuDeWlvDXxZu58cwhjM7r5Xc5ItLJKfRj2JaKvdzxyjJO6t+TG88a4nc5ItIFKPRjVGOj47aXllDX0MhvLjuJhDj9qkSk7ZQkMWrG/PXMLyzjzgtGMDA71e9yRKSLUOjHoFVfVHLfP1Zz9oi+XDa2f+sriIhESKEfY6rrGrjl+c9IT07gVxefoKtuRaRd6fy/GPPrf65m1Re7eOrqsWT16O53OSLSxWhPP4bMLyzl8bnr+c74fM4c1sfvckSkC1Lox4i9tQ3c9tISBvVO5SfnD/e7HBHpotS9EyP+tKCILRXVvPi9U0lO1OjTIhId2tOPAXtq6vnj++uYNDRb97oVkahS6MeAmR8WUb6nllvP1r1uRSS6FPo+21Vdx2MfrOOsYX0YpbF1RCTKFPo+e2p+ETur6vifL2svX0SiT6Hvo4q9dTw+dx1nj+jLCbkZfpcjIgGg0PfRk/PWs6u6nlu+PNTvUkQkIBT6PtlZVcuMees57/ijOO5o7eWLSMdQ6Pvk8bnr2FNbzy3qyxeRDqTQ90H5nlqeml/EBSOP5tij0vwuR0QCRKHvg//3/lqq6xq4+UvqyxeRjqXQ72Dbd9Uw88MiLjwphyF9evhdjogEjEK/g/3x/bXUNTj+W3v5IuIDhX4H2lpZzZ8XbODiUTm6BaKI+EKh34EefW8tDY2Om87SXr6I+EOh30FKdu7lLx9t5NIxueRlpfhdjogElEK/gzzybiEOxw/OHOJ3KSISYAr9DrCpvIoXF23i8rF55PbSXr6I+Eeh3wEeebcQM+P7Zw72uxQRCTiFfp
RtKNvDS58U861xefTLSPa7HBEJOIV+lD08p5D4bsb3J2svX0T8p9CPonXbd/PK4mK+Mz6fPulJfpcjIqLQj6aH5xTQPT6OG7SXLyIxQqEfJYXbdvHqkhKmnjaA7B7d/S5HRARQ6EfNQ+8UkJIQx/WnD/K7FBGR/RT6UbDqi0reWLqFayYMJDM10e9yRET2iyj0zexcM1ttZoVmdnszy/PNbI6ZLTWz98wst8nydDPbbGa/b6/CY9lDswtI6x7PtEkD/S5FROQArYa+mcUBjwDnASOAK8xsRJNmDwDPOOdGAncB9zZZfjfwftvLjX3LNlfwj+Vf8N1JA+mZor18EYktkezpjwMKnXPrnHO1wPPAhU3ajADmeNPvhi83s5OBvsA/215u7HvonQLSk+K5dqL28kUk9kQS+jnAprDHxd68cEuAS7zpi4A0M8sys27Ar4H/bWuhncHS4p28s3Ir158+iPSkBL/LERE5SCShb83Mc00e3wacYWaLgTOAzUA98H3gTefcJg7BzK43s0Vmtmj79u0RlBSbfjN7Db1SErh6gvbyRSQ2xUfQphjoH/Y4FygJb+CcKwEuBjCzHsAlzrkKMzsVmGRm3wd6AIlmtts5d3uT9R8DHgMYM2ZM038oncKyzRW8u3o7Pzz3WHp0j2Szioh0vEjSaSEw1MwGEtqDvxz4VngDM8sGyp1zjcCPgRkAzrlvh7W5GhjTNPC7iifnrSc1MY5vn5LvdykiIi1qtXvHOVcP3Ai8DawEXnTOLTezu8xsitdsMrDazNYQOmh7T5TqjUlfVFTz+pISLhubR0ay+vJFJHZF1A/hnHsTeLPJvDvDpmcBs1p5jqeBpw+7wk5g5odFNDrHNRMG+F2KiMgh6YrcNtpTU8+zCzZw3vH96J+pu2KJSGxT6LfRrE+Kqayu57u6+lZEOgGFfhs0NDqenLee0Xk9GZ3Xy+9yRERapdBvg9krtrKxvIrrJmkkTRHpHBT6bfDkvHX0z0zmnOOO8rsUEZGIKPSP0GebdrKwaAfXThhIXLfmLloWEYk9Cv0j9MTcdaQlxXPpmP6tNxYRiREK/SNQvKOKt5Z9wbdOydOQCyLSqSj0j8DT84sw4OrTBvhdiojIYVHoH6bK6jqeX7iJr47sR7+MZL/LERE5LAr9w/Tiwk3srqln2kSdpikinY9C/zDUNzTy1PwiThmYyQm5GX6XIyJy2BT6h+GtZV+weedeXYwlIp2WQj9CzjmemLuOgdmpnDWsj9/liIgcEYV+hBZt2MGS4gqunTiQbroYS0Q6KYV+hJ6Yu46eKQl8Y3Su36WIiBwxhX4Eikr38M8VW7nylHySE+P8LkdE5Igp9CPw1Pz1JHTrxlWn6f63ItK5KfRbUVFVx4uLiply0tH0SUvyuxwRkTZR6Lfi2Y83sLeugWm6M5aIdAEK/UOorW9k5r+LmDQ0m2FHpftdjohImyn0D+GNpSVsrazhuxO1ly8iXYNCvwWhi7HWM7RPD844prff5YiItAuFfgs+XFfGii2VTJs0EDNdjCUiXYNCvwVPzF1Pdo9ELjwpx+9SRETajUK/GYXbdvOvVdv4zvgBJCXoYiwR6ToU+s14ct56usd348rxeX6XIiLSrhT6TZTtruGvnxZz8ehcsnp097scEZF2pdBv4s8LNlJT36jTNEWkS1Loh6mua+BPC4o4a1gfhvTp4Xc5IiLtTqEf5rXPSijdXcs07eWLSBel0Pc453hi3jqG90vn1MFZfpcjIhIVCn3P6q27WLN1N1edmq+LsUSky1Loe+YVlAIw+VgNuSAiXZdC3zO3oJTBvVPpl5HsdykiIlGj0Adq6hv4eH05k4ZqL19EuraIQt/MzjWz1WZWaGa3N7M838zmmNlSM3vPzHK9+SeZ2Ydmttxbdll7/wDt4dMNO9lb18CEIdl+lyIiElWthr6ZxQGPAOcBI4ArzGxEk2YPAM8450YCdwH3evOrgKucc8cB5wIPmVnP9iq+vcwr3E
5cN2P8oEy/SxERiapI9vTHAYXOuXXOuVrgeeDCJm1GAHO86Xf3LXfOrXHOFXjTJcA2IOb6UOYVljGqf0/SkhL8LkVEJKoiCf0cYFPY42JvXrglwCXe9EVAmpkdcLK7mY0DEoG1R1ZqdFRU1fF58U517YhIIEQS+s2dtO6aPL4NOMPMFgNnAJuB+v1PYNYP+BNwjXOu8aAXMLvezBaZ2aLt27dHXHx7+HBdKY0OJg1V6ItI1xdJ6BcD/cMe5wIl4Q2ccyXOuYudc6OAO7x5FQBmlg78Hfipc25Bcy/gnHvMOTfGOTemd++O7f2ZW1BKj+7xnNg/5g41iIi0u0hCfyEw1MwGmlkicDnwWngDM8s2s33P9WNghjc/EXiF0EHel9qv7PYzr7CU8YMySYjT2asi0vW1mnTOuXrgRuBtYCXwonNuuZndZWZTvGaTgdVmtgboC9zjzf8mcDpwtZl95n2d1N4/xJHaVF7FhrIqJqo/X0QCIj6SRs65N4E3m8y7M2x6FjCrmfX+DPy5jTVGzbzC0NALE9WfLyIBEeg+jXkFpRyVnsTg3ho7X0SCIbCh39jomL+2lIlDszWqpogERmBDf3lJJTur6tSfLyKBEtjQn1sYuh5AF2WJSJAENvTnF5Yy7Kg0eqd197sUEZEOE8jQr65rYGHRDnXtiEjgBDL0P15fTm19o07VFJHACWTozy8sJTGuG+MGaihlEQmWQIb+3IJSRuf3JCUxomvTRES6jMCFfunuGlZsqdStEUUkkAIX+v9eWwboVE0RCabAhf68gu1kJCdwQk6G36WIiHS4QIW+c455BaWcNjiLuG4aekFEgidQob++dA8lFdXq2hGRwApU6O8bSlm3RhSRoApU6M8tKKV/ZjL5Wal+lyIi4ovAhH59QyML1pZp6AURCbTAhP6S4gp21dQzcYjOzxeR4ApM6M8rKMUMThuc5XcpIiK+CUzozy8s5fijM+iVmuh3KSIivglE6O+uqefTjTs0qqaIBF4gQv+jdWXUNzodxBWRwAtE6M8rLKV7fDdOzu/ldykiIr4KRugXlDJuYCZJCXF+lyIi4qsuH/pbK6sp2LZbXTsiIgQg9OcVhIZe0EFcEZEghH5hKVmpiQw/Kt3vUkREfNelQ985x7zCUk4bkk03DaUsItK1Q3/N1t1s31XDJPXni4gAXTz05xZsB2CC+vNFRIAuHvrzC0sZlJ1KTs9kv0sREYkJXTb0a+sb+Wh9uc7aEREJ02VD/9ONO6iqbdD5+SIiYbps6M8vLCWumzFeQymLiOzXZUN/bkEpJ+ZmkJ6U4HcpIiIxI6LQN7NzzWy1mRWa2e3NLM83szlmttTM3jOz3LBlU82swPua2p7Ft6Siqo6lxTvVtSMi0kSroW9mccAjwHnACOAKMxvRpNkDwDPOuZHAXcC93rqZwM+BU4BxwM/NLOpDXX64roxGBxOH6taIIiLhItnTHwcUOufWOedqgeeBC5u0GQHM8abfDVv+FWC2c67cObcDmA2c2/ayD21e4XZSE+MYldcz2i8lItKpRBL6OcCmsMfF3rxwS4BLvOmLgDQzy4pw3XY3r6CUUwZlkRDXZQ9ZiIgckUhSsblBa1yTx7cBZ5jZYuAMYDNQH+G6mNn1ZrbIzBZt3749gpJatqm8iqKyKvXni4g0I5LQLwb6hz3OBUrCGzjnSpxzFzvnRgF3ePMqIlnXa/uYc26Mc25M795t64efXxgaSnmSLsoSETlIJKG/EBhqZgPNLBG4HHgtvIGZZZvZvuf6MTDDm34bOMfMenkHcM/x5kXN3MJS+qZ3Z0ifHtF8GRGRTqnV0HfO1QM3EgrrlcCLzrnlZnaXmU3xmk0GVpvZGqAvcI+3bjlwN6F/HAuBu7x5UdHY6Ph3YSkThmRjpqGURUSaio+kkXPuTeDNJvPuDJueBcxqYd0Z/GfPP6pWbKlkR1WdunZERFrQpU5vmevdGnHCYIW+iEhzulTozy8s5di+afRJT/K7FB
GRmNRlQr+6roGPi8qZoFM1RURa1GVCv7K6jnOPO4ovD+/jdykiIjErogO5nUGftCQevmKU32WIiMS0LrOnLyIirVPoi4gEiEJfRCRAFPoiIgGi0BcRCRCFvohxgTesAAAGpUlEQVRIgCj0RUQCRKEvIhIg5txBN7LylZltBza04SmygdJ2KicaVF/bqL62UX1tE8v15TvnWr0LVcyFfluZ2SLn3Bi/62iJ6msb1dc2qq9tYr2+SKh7R0QkQBT6IiIB0hVD/zG/C2iF6msb1dc2qq9tYr2+VnW5Pn0REWlZV9zTFxGRFnTK0Dezc81stZkVmtntzSzvbmYveMs/MrMBHVhbfzN718xWmtlyM7u5mTaTzazCzD7zvu5s7rmiXGeRmX3uvf6iZpabmT3sbcOlZja6A2s7NmzbfGZmlWZ2S5M2HboNzWyGmW0zs2Vh8zLNbLaZFXjfe7Ww7lSvTYGZTe3A+u43s1Xe7+8VM+vZwrqHfC9Esb7pZrY57Hd4fgvrHvLvPYr1vRBWW5GZfdbCulHffu3KOdepvoA4YC0wCEgElgAjmrT5PvBHb/py4IUOrK8fMNqbTgPWNFPfZOANn7djEZB9iOXnA28BBowHPvLx9/0FoXOQfduGwOnAaGBZ2Lz7gNu96duB/2tmvUxgnfe9lzfdq4PqOweI96b/r7n6InkvRLG+6cBtEfz+D/n3Hq36miz/NXCnX9uvPb86457+OKDQObfOOVcLPA9c2KTNhcBMb3oW8CUzs44ozjm3xTn3qTe9C1gJ5HTEa7ezC4FnXMgCoKeZ9fOhji8Ba51zbblgr82ccx8A5U1mh7/PZgJfb2bVrwCznXPlzrkdwGzg3I6ozzn3T+dcvfdwAZDb3q8bqRa2XyQi+Xtvs0PV52XHN4Hn2vt1/dAZQz8H2BT2uJiDQ3V/G+9NXwFkdUh1YbxupVHAR80sPtXMlpjZW2Z2XIcWFuKAf5rZJ2Z2fTPLI9nOHeFyWv5j83sb9nXObYHQP3uguRs0x8p2vJbQJ7fmtPZeiKYbve6nGS10j8XC9psEbHXOFbSw3M/td9g6Y+g3t8fe9BSkSNpElZn1AF4GbnHOVTZZ/Cmh7ooTgd8Bf+vI2jwTnHOjgfOAH5jZ6U2Wx8I2TASmAC81szgWtmEkYmE73gHUA8+20KS190K0PAoMBk4CthDqQmnK9+0HXMGh9/L92n5HpDOGfjHQP+xxLlDSUhsziwcyOLKPlkfEzBIIBf6zzrm/Nl3unKt0zu32pt8EEswsu6Pq8163xPu+DXiF0MfocJFs52g7D/jUObe16YJY2IbA1n1dXt73bc208XU7egeOLwC+7bwO6KYieC9EhXNuq3OuwTnXCDzewuv6vf3igYuBF1pq49f2O1KdMfQXAkPNbKC3J3g58FqTNq8B+86S+Abwr5be8O3N6/97EljpnHuwhTZH7TvGYGbjCP0eyjqiPu81U80sbd80oQN+y5o0ew24yjuLZzxQsa8rowO1uIfl9zb0hL/PpgKvNtPmbeAcM+vldV+c482LOjM7F/gRMMU5V9VCm0jeC9GqL/wY0UUtvG4kf+/R9GVglXOuuLmFfm6/I+b3keQj+SJ0ZskaQkf17/Dm3UXozQ2QRKhLoBD4GBjUgbVNJPTxcynwmfd1PnADcIPX5kZgOaEzERYAp3Xw9hvkvfYSr4592zC8RgMe8bbx58CYDq4xhVCIZ4TN820bEvrnswWoI7T3+V1Cx4nmAAXe90yv7RjgibB1r/Xei4XANR1YXyGh/vB978N9Z7QdDbx5qPdCB9X3J++9tZRQkPdrWp/3+KC/946oz5v/9L73XFjbDt9+7fmlK3JFRAKkM3bviIjIEVLoi4gEiEJfRCRAFPoiIgGi0BcRCRCFvgSWmf37MNtPNrM3olWPSEdQ6EtgOedO87sGkY6m0JfAMrPd3vfJZvaemc3yxp9/Nuxq33O9efMIXY6/b91Ub5CwhWa22Mwu9ObfamYzvO
kTzGyZmaX48OOJNEuhLxIyCrgFGEHoKssJZpZEaEyYrxEaafGosPZ3EBreYyxwJnC/dxn+Q8AQM7sIeAr4nmthCAQRPyj0RUI+ds4Vu9DgX58BA4BhwHrnXIELXbr+57D25wC3e3dTeo/Q0B953vpXExpi4H3n3PyO+xFEWhfvdwEiMaImbLqB//xttDROiQGXOOdWN7NsKLCb0BgtIjFFe/oiLVsFDDSzwd7jK8KWvQ3cFNb3P8r7ngH8ltDt97LM7BsdWK9IqxT6Ii1wzlUD1wN/9w7kht+y8W4gAVjq3Uz7bm/+b4A/OOfWEBpJ8ldm1twdtUR8oVE2RUQCRHv6IiIBotAXEQkQhb6ISIAo9EVEAkShLyISIAp9EZEAUeiLiASIQl9EJED+PwzKn1mh7judAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "frame.reset_index().plot('index', 'binary_acc')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Replay a Trial\n", "\n", "One of the perks of history is the ability to replay a trial. We'll look at two of the replay options here. First we can just replay the whole training process, this time with a different verbosity." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6eef5ee55c334f119d8a60fd2af57be1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='0/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2008f381538142618cff8fc050eb307f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='1/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "142a32f36308426e9fe59ad5e2092701", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='2/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c8c7783d52ff45dca7874f4f728d0094", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='3/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", 
"output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "15261a70fcb8445d84aae00e59097b9b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='4/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "31d57a94b04d47e08bf29c14dada538a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='5/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1d2d2e2eba8542faba70141ba4302ba4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='6/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9aa166f3fe2940c3b652169bb00020d7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='7/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "73436eeb6a6a46bd8086400d87f93624", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='8/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"fee3517b93b1460bbab3997cc07f747d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='9/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7575c42b2b484e369b6d7cf58e8b7024", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='10/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "221c36be827a405b91b62e7d6cc748cc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='11/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b7869ba8ae444650a8ca9856610f3a75", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='12/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2ae4218c4a8c4ad3b4e83a891cdb798d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='13/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1974270d05414437b5345804820abcda", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, 
description='14/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7a7899659a7d46beaa9a918a2886f3c6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='15/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "211dc56c1dae4c57af5bdca3393dc7c4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='16/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8c92266f9d924abebdc66383f90b3e28", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='17/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fcda5e5aca274e5c8b8d17aa526c9a49", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='18/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "efbd88786f9649108315bb275f1e21ea", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='19/20(t)', max=16), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": 
"stream", "text": [ "\n" ] } ], "source": [ "_ = trial.replay(verbose=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This may be more output than we desire, and so we can instead use the `one_batch` flag to just simulate one batch per epoch." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "af55bc1e2eba480f8c288d9d17d20e67", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='0/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c96ab8cd8e7e41babd0b326ec8f13d5a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='1/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "15e38957d69e4218947f7e64bc3a5860", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='2/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4fea98cef5ea40ccb04fa47ee0dc5cd5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='3/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "13f62e9db6d942e9995f903ab796c72d", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"HBox(children=(IntProgress(value=0, description='4/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8dfb537c62a749f58b075ba5e42861ac", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='5/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fd9c4784b6574fec8c549838e768a28e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='6/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "48626c90db0843c7b59164db6d008cf7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='7/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e2ca960b335344528dd3ae61aa823bc4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='8/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5b2dbd82198f4df78a3538a85d1a3205", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='9/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": 
"stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e47dd50497714db187edf8996de1515e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='10/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "29211e3e9ac8484a958dd85a9e81a671", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='11/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e4cd52e99a1f4451ae13cdef4a007c27", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='12/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f7c4c301a0c74de3bd1103bd0648054b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='13/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c6102be2e3194d94a8926137284ab1a6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='14/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"9ea8b0d4690f48468e07aa2ba8917dc4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='15/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "da98b8d23daf461197613f8ce547dcc7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='16/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c291f91c91ed465ab55fec1ebcf07ea1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='17/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8a37c3f0a2b14538a6e5be77b61e9813", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='18/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8e0fbef3364446c7ac4cb68971f3abfc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='19/20(t)', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "_ = trial.replay(verbose=2, one_batch=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So that's history and replaying in `torchbearer`. 
Be sure to have a look at our other examples at [pytorchbearer.org](http://www.pytorchbearer.org/)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }