{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"#skip\n",
"! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"from fastai.basics import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from nbdev.showdoc import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#default_exp callback.progress"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Progress and logging callbacks\n",
"\n",
"> Callback and helper function to track progress of training or log results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai.test_utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ProgressCallback -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"@docs\n",
"class ProgressCallback(Callback):\n",
" \"A `Callback` to handle the display of progress bars\"\n",
" order,_stateattrs = 60,('mbar','pbar')\n",
"\n",
" def before_fit(self):\n",
" assert hasattr(self.learn, 'recorder')\n",
" if self.create_mbar: self.mbar = master_bar(list(range(self.n_epoch)))\n",
" if self.learn.logger != noop:\n",
" self.old_logger,self.learn.logger = self.logger,self._write_stats\n",
" self._write_stats(self.recorder.metric_names)\n",
" else: self.old_logger = noop\n",
"\n",
" def before_epoch(self):\n",
" if getattr(self, 'mbar', False): self.mbar.update(self.epoch)\n",
"\n",
" def before_train(self): self._launch_pbar()\n",
" def before_validate(self): self._launch_pbar()\n",
" def after_train(self): self.pbar.on_iter_end()\n",
" def after_validate(self): self.pbar.on_iter_end()\n",
" def after_batch(self):\n",
" self.pbar.update(self.iter+1)\n",
" if hasattr(self, 'smooth_loss'): self.pbar.comment = f'{self.smooth_loss:.4f}'\n",
"\n",
" def _launch_pbar(self):\n",
" self.pbar = progress_bar(self.dl, parent=getattr(self, 'mbar', None), leave=False)\n",
" self.pbar.update(0)\n",
"\n",
" def after_fit(self):\n",
" if getattr(self, 'mbar', False):\n",
" self.mbar.on_iter_end()\n",
" delattr(self, 'mbar')\n",
" if hasattr(self, 'old_logger'): self.learn.logger = self.old_logger\n",
"\n",
" def _write_stats(self, log):\n",
" if getattr(self, 'mbar', False): self.mbar.write([f'{l:.6f}' if isinstance(l, float) else str(l) for l in log], table=True)\n",
"\n",
" _docs = dict(before_fit=\"Setup the master bar over the epochs\",\n",
" before_epoch=\"Update the master bar\",\n",
" before_train=\"Launch a progress bar over the training dataloader\",\n",
" before_validate=\"Launch a progress bar over the validation dataloader\",\n",
" after_train=\"Close the progress bar over the training dataloader\",\n",
" after_validate=\"Close the progress bar over the validation dataloader\",\n",
" after_batch=\"Update the current progress bar\",\n",
" after_fit=\"Close the master bar\")\n",
"\n",
"if not hasattr(defaults, 'callbacks'): defaults.callbacks = [TrainEvalCallback, Recorder, ProgressCallback]\n",
"elif ProgressCallback not in defaults.callbacks: defaults.callbacks.append(ProgressCallback)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 8.960221 | \n",
" 8.501486 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 1 | \n",
" 7.650368 | \n",
" 5.475908 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 2 | \n",
" 6.193127 | \n",
" 3.202425 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 3 | \n",
" 4.902714 | \n",
" 1.781969 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 4 | \n",
" 3.847687 | \n",
" 0.968699 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = synth_learner()\n",
"learn.fit(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"@patch\n",
"@contextmanager\n",
"def no_bar(self:Learner):\n",
" \"Context manager that deactivates the use of progress bars\"\n",
" has_progress = hasattr(self, 'progress')\n",
" if has_progress: self.remove_cb(self.progress)\n",
" try: yield self\n",
" finally:\n",
" if has_progress: self.add_cb(ProgressCallback())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 16.774219512939453, 16.614517211914062, '00:00']\n",
"[1, 14.62364387512207, 11.538640975952148, '00:00']\n",
"[2, 12.198295593261719, 7.462512016296387, '00:00']\n",
"[3, 9.962362289428711, 4.619643688201904, '00:00']\n",
"[4, 8.045241355895996, 2.791717052459717, '00:00']\n"
]
}
],
"source": [
"learn = synth_learner()\n",
"with learn.no_bar(): learn.fit(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#hide\n",
"#Check validate works without any training\n",
"def tst_metric(out, targ): return F.mse_loss(out, targ)\n",
"learn = synth_learner(n_trn=5, metrics=tst_metric)\n",
"preds,targs = learn.validate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#hide\n",
"#Check get_preds works without any training\n",
"learn = synth_learner(n_trn=5, metrics=tst_metric)\n",
"preds,targs = learn.validate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> ProgressCallback.before_fit
()\n",
"\n",
"Setup the master bar over the epochs"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(ProgressCallback.before_fit)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> ProgressCallback.before_epoch
()\n",
"\n",
"Update the master bar"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(ProgressCallback.before_epoch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> ProgressCallback.before_train
()\n",
"\n",
"Launch a progress bar over the training dataloader"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(ProgressCallback.before_train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> ProgressCallback.before_validate
()\n",
"\n",
"Launch a progress bar over the validation dataloader"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(ProgressCallback.before_validate)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> ProgressCallback.after_batch
()\n",
"\n",
"Update the current progress bar"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(ProgressCallback.after_batch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> ProgressCallback.after_train
()\n",
"\n",
"Close the progress bar over the training dataloader"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(ProgressCallback.after_train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> ProgressCallback.after_validate
()\n",
"\n",
"Close the progress bar over the validation dataloader"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(ProgressCallback.after_validate)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> ProgressCallback.after_fit
()\n",
"\n",
"Close the master bar"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(ProgressCallback.after_fit)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ShowGraphCallback -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"class ShowGraphCallback(Callback):\n",
" \"Update a graph of training and validation loss\"\n",
" order,run_valid=65,False\n",
"\n",
" def before_fit(self):\n",
" self.run = not hasattr(self.learn, 'lr_finder') and not hasattr(self, \"gather_preds\")\n",
" if not(self.run): return\n",
" self.nb_batches = []\n",
" assert hasattr(self.learn, 'progress')\n",
"\n",
" def after_train(self): self.nb_batches.append(self.train_iter)\n",
"\n",
" def after_epoch(self):\n",
" \"Plot validation loss in the pbar graph\"\n",
" if not self.nb_batches: return\n",
" rec = self.learn.recorder\n",
" iters = range_of(rec.losses)\n",
" val_losses = [v[1] for v in rec.values]\n",
" x_bounds = (0, (self.n_epoch - len(self.nb_batches)) * self.nb_batches[0] + len(rec.losses))\n",
" y_bounds = (0, max((max(Tensor(rec.losses)), max(Tensor(val_losses)))))\n",
" self.progress.mbar.update_graph([(iters, rec.losses), (self.nb_batches, val_losses)], x_bounds, y_bounds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 11.676161 | \n",
" 8.622957 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 1 | \n",
" 10.183996 | \n",
" 6.069672 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 2 | \n",
" 8.519948 | \n",
" 3.890609 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 3 | \n",
" 6.959560 | \n",
" 2.382288 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 4 | \n",
" 5.621220 | \n",
" 1.414858 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAsOElEQVR4nO3dd3zV1f3H8dfJHmQvQkIIm0AmCUOWDEU2WBCp2loHWLe2trVa21prf3ZZt9ZVBzgQRIYMBZmyAwECYZOQPQkJZCfn98c3yJAAWfeb3Pt5Ph73kdzxvffD9wFvj+ee7+corTVCCCGsi53ZBQghhGh5Eu5CCGGFJNyFEMIKSbgLIYQVknAXQggr5GDJD/P399fh4eEXPVZaUU1qYdlFj9krhaODHU72drg42hHo4YJSFixUCCHakMTExAKtdUBjjrFouIeHh7Nz584f7lfV1DHupQ2EKPjnjBiyisvJLC4n85TxM+NUGYdzzzAsthP/mRmLnZ0kvBDC9iil0hp7jEXD/VJzt6ZxvOAs7/8igfguPsR38fnRa15fe5R/rjqEfwdn/jAxAiVDeCGEuCrTwr24rIqX1xxheE9/RvUObPB1D4zsTn5pJe9tOkGghzP3Xd/dglUKIUT7ZFq4v7T6CKUV1Tx9ldG4Uoo/TupL/plK/m/FQfw7ODM9PtSClQohRPtjSrgfyz/D3K1p3DogjD4dPa/6ejs7xYszYyguq+K3C/fi28HpiqN9IYT1qK6uJiMjg4qKCrNLaXUuLi6Ehobi6OjY7PcyJdz/9nUKLo72/OrGXtd8jLODPW/dEc+st7fywNxdfDJ7EHFhP56jF0JYl4yMDDw8PAgPD7fq79y01hQWFpKRkUHXrl2b/X4WX+e+6UgBaw7m8eCoHgR4ODfqWA8XRz64ayABHs7c/cEOjuWfaaUqhRBtRUVFBX5+flYd7GBMQfv5+bXY/6FYPNz/+vUBQn1cuWtoeJOOD/Bw5qO7B2Jvp7j1v1vYfKygZQsUQrQ51h7s57Tkn9Oi4V50toqDOaX8fnwELo72TX6fcH93PpszGC9XR+54dxtvrjuGtC4WQojzLBruuSUVDAj3YUJUx2a/V49ADxY/NIzxkcH8feVB7vs4kZKK6haoUgghzisuLuaNN95o9HETJkyguLi45Qu6RlcNd6XU+0qpPKVU8gWP/VMpdVAptVcptUgp5X0tH1ZTp/nDxL4t9r8eHZwdeO22OJ6Z1JfvDuYx5dVNpGSXtMh7CyEENBzutbW1Vzxu+fLleHt7t1JVV3ctI/cPgHGXPPYtEKm1jgYOA7+/lg/zdnMkprN3Y+q7KqUU9wzryqdzBlNWVcvNb3zPl7syWvQzhBC268knn+TYsWPExsYyYMAARo0axW233UZUVBQA06ZNIz4+nn79+vH222//cFx4eDgFBQWkpqYSERHB7Nmz6devH2PHjqW8vLzV677qUkit9QalVPglj31zwd2twIxr+bAQb9dGFdcYA8J9WfbIMB7+ZDe/mr+HTUcL+NOkfni5NX+9qBCibXh26X4OZLXs/5337eTJnyb3a/D5F154geTkZJKSkli3bh0TJ04kOTn5h+WK77//Pr6+vpSXlzNgwACmT5+On5/fRe9x5MgRPv30U9555x1mzpzJwoULueOOO1r0z3GplphzvxtY0dCTSqk5SqmdSqmdhQWtu7Il0MOFefcO4uHRPViclMWN/1nP6gO5rfqZQgjbMnDgwIvWob/yyivExMQwePBg0tPTOXLkyI+O6dq1K7GxsQDEx8eTmpra6nU26yImpdTTQA0wr6HXaK3fBt4GSEhIaPUlLQ72dvx6bG/G9u3Ibxbs4d6PdvKTuBD+OLkv3m5O1/QeWmuO5Z9hR+opdqQWsSe9mD7Bnjw6pie9gjxa+U8ghGjIlUbYluLu7v7D7+vWrWP16tVs2bIFNzc3Ro4cedl16s7O56/psbe3bxvTMg1RSt0JTALG6Da4DjEq1IslDw3j9bVHeX3tUTYeLeCv0yK5qd/FK3XKqmrIPFVORnE5h3JK2ZlaRGLaKU6VGStv/NydiAr1Yt3BPJbvy2ZiVDCPjulJTwl5IWyCh4cHpaWll33u9OnT+Pj44ObmxsGDB9m6dauFq2tYk8JdKTUO+B1wvda67GqvN4uTgx2P39iLsf2CeOKLvdz3cSKjegfg5GD3Q9/4cyF+Tjd/d26ICGJAuC8J4T509XdHKcWps1W8s/E4H25O5et92UyO7sQjY3rSI7CDSX86IYQl+Pn5MXToUCIjI3F1dSUoKOiH58aNG8dbb71FdHQ0vXv3ZvDgwSZWejF1tUG3UupTYCTgD+QCf8JYHeMMFNa/bKvW+pdX+7CEhAR94WYdllRdW8cba4/x8dY0fNwcCfFxJcTb9fxPb1fC/d3x73DllghFF4R8RXUtE6KC6dvJEz93J3zcnPDr4ISvuzO+bk54ujrYzJV1QrSWlJQUIiIizC7DYi7351VKJWqtExrzPlcN95ZkZri3tMIzlbyz8QSfbEujpKLmsq/p5OXCP2bEMKynv4WrE8J6SLg3LdxN3YmpPfPr4MyT4/vw5Pg+lFfVUlRWRdGZKgrPVlJ0toqis1V8tiOdO97bxuzhXXnipt44OzS95YIQQjSGhHsLcHWyJ8TJ9Ufr+G8f1IW/LU/hnY0n2HS0kFdmxcoXsUIIi7B4V0hb4upkz3PTInnvzgTySiqY9OomPt6SKk3OhBCtTsLdAsZEBLHiseEM7ubHM4v3c8+HO8kvrTS7LCGEFZNwt5BADxc+uGsAf57cl01HC7jhxfV8su0kdXUyihdCtDwJdwtSSvGLoV1Z/sgw+nT04KlF+5jx1uYW75UhhDBPhw7GtS9ZWVnMmHH5tlsjR46ktVcOSriboEegB5/NGcyLM2NIKyxj8mub+OuyA5ypvPySSiFE+9OpUycWLFhg2udLuJtEKcVP+oey5tfXc+uAzrz3/Qlu+Pd6VuzLli9chWhDfve7313Uz/3Pf/4zzz77LGPGjKF///5ERUWxePHiHx2XmppKZGQkAOXl5cyaNYvo6GhuvfXWtt1bRrQMbzcn/nZzFDPiQ3l6UTL3z9tFmK8bE6ODmRQdTN9gT7nKVYhzVjwJOfta9j07RsH4Fxp8etasWTz22GM88MADAMyfP5+VK1fy+OOP4+npSUFBAYMHD2bKlCkN/lt98803cXNzY+/evezdu5f+/fu37J/hMiTc24j+YT4sfWgoi5OyWLwni7c3HOfNdcfoFuDOpOhOTI4OljXyQpggLi6OvLw8srKyyM/Px8fHh+DgYB5//HE2bNiAnZ0dmZmZ5Obm0rHj5bcQ3bBhA4888ggA0dHRREdHt3rdEu5tiIO9HdPjQ5keH0rR2SpWJuewdE8Wr313hFfWHKFPRw+mxHZiSkwnQn3czC5XCMu7wgi7Nc2YMYMFCxaQk5PDrFmzmDdvHvn5+SQmJuLo6Eh4ePhlW/1eyNL/By7h3kb5ujtx26AwbhsURl5pBSv25bBkTxb/WHmIf6w8xMBwX6bEdmJiVDA+7tfWp14I0TSzZs1i9uzZFBQUsH79eubPn09gYCCOjo6sXbuWtLS0Kx4/YsQI5s2bx6hRo0hOTmbv3r2tXrOEezsQ6OHCnUPCuXNIOOlFZSxOyuSrpCz+8FUyzy7dz/W9ApgWF8INEUG4OEr/GiFaWr9+/SgtLSUkJITg4GBuv/12Jk+eTEJCArGxsfTp0+eKx99///3cddddREdHExsby8CBA1u9ZukK2U5prdmfVcLipEyW7Mkit6SSDs4OjI/syM1xIQzq5oe9nXwRK9o/6QopXSFtilKKyBAvIkO8eHJ8BNuOF7JodyYrknP4IjGDjp4uTI3rxNSYEEJ86hua6XM/jF+cHOxwc5K/AkJYI/mXbQXs7RRDevgzpIc/z02LZHVKLl/tzuS9jSf47/rjDR5np2BErwBmxIfKlI4QVkbC3cq4ONozKboTk6I7UXS2itUpuZTWbyZy4SSNUpBTUsGSpCwe+mQ3ni4OTI7pxPT4UOI6e8vaetGmaK1t4u9kS06Ty5y7jaut02w5VsiCxHRW7s+horqObgHuzIgPZXr/UII8XcwuUdi4EydO4OHhgZ+fn1UHvNaawsJCSktL6dq160XPyTZ7ollKK6pZvi+bBYkZ7Eg9hZ2Ckb0DmZnQmdF9AnFykG4VwvKqq6vJyMi46jpya+Di4kJoaCiOjo4XPS7hLlrMiYKzfLEznQWJGeSVVuLn7sTNcSHMHNCZXnKlrBAWJeEuWlxNbR0bjxTw+Y50VqfkUlOnGRDuwz3DunJj346y3FIIC5BwF62q8EwlX+7K5MMtqWScKifUx5VfDAln5oDOeLo4Xv0NhBBNIuEuLKK2TvPtgRze35TK9tQi3J3suSWhM78YEk64v7vZ5QlhdSTchcXtyzjN/74/wdK9WVTXakK8XYnp7EVUiDcxoV5EhnrJqF6IZpJwF6bJK6lgyZ4sktKL2ZtxmpNFZT88183fneE9/fnV2N54uUrQC9FYrdJ+QCn1PjAJyNNaR9Y/5gt8DoQDqcBMrfWpxhYsrEegpwv3Du/2w/1TZ6vYl3mavRnFJKWfZu62k6zan8s/ZkQzoleAiZUKYRuuZeHyB8C4Sx57Elijte4JrKm/L8QPfNydGNErgIdG9+TdOxP48v4hdHBx4Ofvb+fpRfs4K/vFCtGqrhruWusNQNElD08FPqz//UNgWsuWJaxNTGdvlj08jDkjuvHJ9pOMf3kj244Xml2WEFarqZccBmmtswHqfwY29EKl1Byl1E6l1M78/PwmfpywBi6O9jw1IYL5910HwKx3tvLcsgOUVckoXoiWdk1fqCqlwoFlF8y5F2utvS94/pTW2udq7yNfqIpzzlbW8MKKg3y8NQ1nBzuu6+7H9b0CGNk7kK6ynFKIi7TaapnLhPshYKTWOlspFQys01r3vtr7SLiLSyWmFbF0TzbrD+dzouAsAF383Li+VwCjegcytIe/9LQRNs+Sm3UsAe4EXqj/ubiJ7yNsXHwXX+K7+AKQVniW9YfzWXcon/k70/loSxq+9T1tbkkIpU9HT5OrFaL9uOrIXSn1KTAS8AdygT8BXwHzgTDgJHCL1vrSL11/REbu4lpVVNey+VgBCxIz+PZALtW1muhQL25J6MyUmE6yXl7YFLmISVilorNVfLU7k/k70zmYU4qzgx0TooKZM6IbEcEymhfWT8JdWLVzm4J/viOdRbszOVNZw+g+gdw/sjsDwn3NLk+IViPhLmzG6bJqPtqSyv82p1J0toqB4b7cP6o7I3sFWPVuPcI2SbgLm1NWVcPnO9J5Z8Nxsk5XEBHsyS3xofTt5ElEsKfMzQurIOEubFZVTR2LkzL574bjHM0788PjId6uPwR9v06eXN8rABdHexMrFaLxLLkUUog2xcnBjlsSOjMjPpT80kr2Z5eQkl3CgSzj55qUXOq0EfZPju/DpOhgmb4RVk1G7sImlFfVsu1EIX9feYiU7BISuvjwzKS+xHT2Nrs0Ia6qKSN3ufRP2ARXJ3tG9g5k2cPD+Pv0KFILy5j6+vf86vMksk+Xm12eEC1ORu7CJp2prOGNtUd5d9MJ7BTMHt6NW+I7E+bnZnZpQvyIfKEqRCOlF5XxwsqDfL03G4DIEE/GRwYzPrIj3QI6mFydEAYJdyGaKONUGSuTc/h6Xza7TxYD0KejBxOigpkWGyIjemEqCXchWkBWcTkrk3NYkZzNzrRT2CnFzITOPDqmJx29XMwuT9ggCXchWlj26XL+u/4487alYacUdw4J5/7ru+Pj7mR2acKGSLgL0UrSi8r4z+rDLNqdibuTA7OHd+Oe4V3p4CyXiojWJ+EuRCs7nFvKv1Yd4psDufi5OzEjIZSb+nUkNtQbOzu5KEq0Dgl3ISxk98lTvLLmCBuPFFBTpwnydObGvkGM6xfMoG6+ONrLJSSi5Ui4C2Fhp8uq+e5QLquSc1l/OJ/y6lq8XB25ISKIXwwJJyrUy+wShRWQcG8vMhLBtyu4SQ9ya1JeVcvGI/ms2p/LN/tzKK2sYXhPf+4f2Z3ruvlJLxvRZBLu7UFtNbwcC7VVMOGf0HcqyD96q1NSUc28rSd5b9MJCs5UEtPZmwdGdufGiCCZmxeNJuHeXmTvhSUPQfYe6DMJJvwLPIPNrkq0gorqWhbuyuC/649zsqiM7gHu/PL67kyLC5F5eXHNJNzbk9oa2Po6rP0b2DvD2Oeg/89lFG+lamrrWJ6cw5vrjpGSXUKItyuzh3fl1gFhuDpJf3lxZRLu7VHhMVjyCKRtgq4jYPLL4NvN7KpEK9Fas+5wPm+sPcqO1FP4uTtx97Cu/Oy6Lni6yK5R4vIk3NurujrY9SF8+0djTn700zD4AbCTEZ01236iiDfWHWXdoXw8nB342XVduGtoVwI8nM0uTbQxEu7tXUkWLPsVHF4BnfrD1NcgqJ/ZVYlWlpx5mjfXH2P5vmwc7eyYFteJe4d3o1eQh9mliTZCwt0aaA37v4Tlv4WKYhj2KxjxBDjIaM7anSg4y/ubTvBFYjoV1XWM6BXA7OFdGdbDX5ZR2jgJd2tSVgQrfw97PwP/3jDlVQgbZHZVwgJOna3ik+0n+WBzKvmllfQO8uCe4V2ZEtNJNve2URYPd6XU48C9gAb2AXdprSsaer2EexMcWQ3LHoPTGTDoPhj9DDjLJhK2oLKmlqV7snl343EO5pTi7ebIjP6h3D64C1393c0uT1iQRcNdKRUCbAL6aq3LlVLzgeVa6w8aOkbCvYkqS2HNX2D7O+DVGSa/BD3GmF2VsBCtNVuOFzJv60lW7c+hpk4zrIc/dwwO44aIIBxkvbzVMyPctwIxQAnwFfCK1vqbho6RcG+mk1thycNQcBhiboObnpcWBjYmr6SCz3ek8+n2k2SdriDI05mfDgzjtoFhBHrKRiLWyoxpmUeB54Fy4But9e2Xec0cYA5AWFhYfFpaWpM/TwDVFbDhn/D9S+DqU9/CYJpc/GRjamrrWHson7lb01h/OB8HO8WEqGDuHBJO/zBv+QLWylh65O4DLARuBYqBL4AFWuu5DR0jI/cWlLMPFj8E2UnQeyJM/Le0MLBRqQVn+WhLGl/sTKe0soaoEC/uHBLOpOhg+QLWSlg63G8Bxmmt76m//3NgsNb6gYaOkXBvYbU1sPUNWPu8tDAQnK2s4cvdmXy0OZUjeWfwdXfipwM7c/ugLnTydjW7PNEMlg73QcD7wACMaZkPgJ1a61cbOkbCvZUUHoOlj0LqRggfbrQw8OtudlXCJFprNh8r5IPNqaxJyUUpxdi+Qfz8unAGd/OVKZt2yIw592cxpmVqgN3AvVrryoZeL+HeirQ2Whh884zRwmDUU0YLA3vZ49OWpReVMXdbGp/vSKe4rJreQR78fEgXbo4Lwc1J/m60F3IRkzBaGHz9azi0HDrFwZTXoGOk2VUJk1VU17IkKYsPNqdyILsEDxcHpsWGcEtCKFEhXjKab+Mk3IVBa9i/CFb8FspPwbDHYcRvpIWBQGtNYtopPtqSxsr9OVTV1NE7yINbEkK5OS4Evw7yd6QtknAXFysrglVPwZ5Pwb+XMYqXFgai3unyapbuyeKLxAz2pBfjYKcY3SeQ6fGhjOgZIH3m2xAJd3F5R1fD0seMFgYD58CYP0oLA3GRw7mlfLEznUW7Myk4U4Wzgx3XdfdjdJ9ARvUOpLOvm9kl2jQJd9GwyjPw3XOw7b/gFQqTXoKeN5hdlWhjqmvr2Ha8iO8O5vHdwVxSC8sA6BXUgVF9AhkfGUxMqMzRW5qEu7i6k9uM/VsLDkP0LBj3f9LCQDToeP4ZvjuYx9pDeWw/UUR1rSYi2JPbBnZmalyI7B5lIRLu4trUVMKGf8GmF8HF22hh0O9mufhJXFFJRTXL9mTzyfY0kjNLcHG0Y3J0J346KIy4ztLyoDVJuIvGyUk2RvFZu6H3hPoWBp3Mrkq0A/syTvPJ9jQWJ2VRVlVLn44eTI7pxNi+QfQI7CBB38Ik3EXj1dbAtjfhu+fB3hFu/Av0vxPspI2suLozlTUsScri8x0n2ZNxGoBwPzduiAjixr5BxHfxkZbELUDCXTRd0XFY8oi0MBBNln26nNUpeaw+kMuWY4VU1dbh4+bI6D5BzIgPldYHzSDhLppHa9j9Maz6A9RW1rcweFBaGIhGO1NZw4bD+Xx7IJfVKbmUVtQQ7ufGrQPCmB4fQqCH9J5vDAl30TJKsmH5E3BwGQTHwtTXoGOU2VWJdqq8qpYVydl8tiOd7SeKsLdTjOkTyE8HhjGiVwD2djKavxoJd9FytIYDi42QLz8FQx8zWhg4yohLNN2x/DPM35HOgsQMCs9W0dHThRnxocxM6EyYn1wo1RAJd9Hyyorgmz9A0rz6FgavQthgs6sS7VxVTR1rUnL5fGc6Gw7nU6fhum5+3DqgM+MiO8omI5eQcBet5+gaWPYYFKfDwNn1LQw8zK5KWIHs0+Us2JnB/MR00ovK8XRxYGpsCFNjOxEX5iPTNki4i9ZWeQa++ytsews8Q2DyS9DzRrOrElairk6z9UQh83eksyI5h8qaOvw7OP2wrHJoD3+bHdFLuAvLSN8OSx6G/IPSwkC0itKKatYdyuebA7msO5hHaWUNbk72jOgZwNh+Rth72FDrAwl3YTk1lbDx38bNxRsm/AP6/URaGIgWV1VTx9bjhXxzIIdvD+SSW1KJs4MdN/YN4ua4EEb0CsDRyi+UknAXlpe7HxY/BFm7oNd4o4WBV4jZVQkrVVen2Z1ezOKkTJbuyeJUWTW+7k5Mig5mWlyI1fa4kXAX5qirha1vGvPx9o5w47PQ/xfSwkC0quraOjYczmfR7ky+PZBLZU0d4X5u3Bxn7CplTUsrJdyFuYpOwNJH4MQG6DIMprwiLQyERZRWVLMiOYevdmey5XghWsOAcB9ujgtlYlQwXm7te35ewl2YT2vYPRdWPW20MBj5e7juIWlhICwmq7icr5Iy+XJXJkfzzuDkYMcNEYFMiQlhZO+AdrniRsJdtB0XtTCIMfZvDY42uyphQ7TWJGeW8OXuDJYkZVF4tgp3J3vGRAQxISq4XQW9hLtoew4shq+fgLJCGPYYjPittDAQFldTW8fW40V8vS+Llck5nCqrbldBL+Eu2qayIvjmGUiaC349jRYGXa4zuyphoy4X9B7ODozt15HJMcEM7eHf5pZWSriLtu3Yd7D0USg+CQNmww1/khYGwlQ1tXVsOV7IkqQsVu7PobSiBl93JyZEdWRKTAgJXXywawPtDywe7kopb+BdIBLQwN1a6y0NvV7CXVB11lgyufVNo4XBpP9Ar7FmVyUElTW1rDuUz9I9WaxOyaWiuo5OXi7c3D+E6f1D6RbQwbTazAj3D4GNWut3lVJOgJvWurih10u4ix+k7zD2b80/CH2nwg1/Bt9uZlclBABnK2tYnZLLl7sy2XjE6FrZP8yb6fGhTIruhJerZZdWWjTclVKewB6gm77GN5FwFxepqYTvX4ZNL0FtldFtcsRvpE+NaFNySyr4ancmCxIzOFK/tHJsfesDSzUzs3S4xwJvAweAGCAReFRrffaS180B5gCEhYXFp6WlNenzhBUrzYG1fzO2+HPygBG/hoH3yaoa0aZordmXeZqFiRks3pNFcf2Km5G9A7kpsiOjege0WjMzS4d7ArAVGKq13qaUehko0Vo/09AxMnIXV5SXAt/+EY58A15hRs/4yOnSxkC0OVU1xhexq/bn8M3+XArOVOJkb8eQHn7c1K8jYyICW3SfWEuHe0dgq9Y6vP7+cOBJrfXEho6RcBfX5Ph6Y/ennL3GHq5j/wpdh5tdlRCXVVun2X3yFKv257Byfw7pReUAxIR6MSYiiDERgfQN9mxWQzMzvlDdCNyrtT6klPoz4K61/k1Dr5dwF9esrg72zYc1z0FJhtFx8sZnIaC32ZUJ0SCtNQdzSlmTksvqlDz2ZBSjNQR7uTC6TyBjIgIZ0r3x8/RmhHssxlJIJ+A4cJfW+lRDr5dwF41WXW7s/LTxRWMZZfydRr+aDoFmVybEVeWXVrL2UB7fpeSx8Ug+Z6tqcXW0Z2gPf8ZEBDK6TyBBnlefvpGLmIT1OlsI6/8OO98DBxcY+ihc9yA4uZtdmRDXpLKmlm3Hi34Y1WcWG9M3USFejIkIZEyfIKJCvS57rIS7sH6Fx2D1nyBlKXgEw6inIPZ2sGu7fUGEuJTWmsO5Z1idkst3B/PYdfIUvQI9WPX4iMu+XsJd2I6TW40vXTN2QGA/GPsX6HGD2VUJ0SSFZyrJPl1BZEjLjdxljZlon8IGwz3fwi0fQnUZzJ0OH02DnH1mVyZEo/l1cG4w2JtKwl20X0pBv2nw4HYY9wJkJ8Fbw2HR/XA60+zqhDCVhLto/xycYPD98EgSDHkYkhfCq/1hzV+gosTs6oQwhYS7sB6u3jD2OXh4J0RMgY3/hlfiYPs7UFttdnVCWJSEu7A+3mEw/R2YvRYCI4zt/t4YDCnLjD1ehbABEu7CeoX0hzuXwk8/B2UHn98O/xsPGbJiS1g/CXdh3ZSC3uPg/i3GxiCFx+DdMfDFXVB0wuzqhGg1Eu7CNtg7QMLd8MguY5PuwyvhtQGw8iljj1chrIyEu7Atzh4w+ml4eBfEzIJtb8IrsbD5VWPzECGshIS7sE2ewTD1NfjlJggdYFzt+loC7FtgdKQUop2TcBe2Lagf3LEQfrYInL1g4T3GnHzq92ZXJkSzSLgLAdB9NNy3Hqa9CWdy4YMJ8OltUHDE7MqEaBIJdyHOsbOH2Nvg4URji78TG+D1QbDsV3Am3+zqhGgUCXchLuXoCsN/DY/sNlbYJH5gfOm64Z9QVWZ2dUJcEwl3IRrSIQAm/gse3AbdRsJ3f4VX42H3XKirNbs6Ia5Iwl2Iq/HvCbPmwV0rjFU2ix+E/46Ao2vMrkyIBkm4C3GtugyBe9fAjPehshTm/gQ+vhlyks2uTIgfkXAXojGUgsjp8NAOuOlvkLkL3hoGXz0IJVlmVyfEDyTchWgKB2djg+5Hk4yf++bDK/1hzXPGqF4Ik0m4C9Ecrj5w0/PGSL7PBNj4L6OH/I53pYe8MJWEuxAtwSfcmIuf/R3494Kvfw1vXAcHv5Ye8sIUEu5CtKSQePjF1zDrU+P+Z7fBBxMhM9HcuoTNkXAXoqUpZUzRPLAFJv4b8g/BO6NhwT1wKs3s6oSNaHa4K6XslVK7lVLLWqIgIayGvSMMuNe40nX4E8YUzWsJsOppKD9ldnXCyrXEyP1RIKUF3kcI6+TiCWOeMXrWRM2ELa/Dy7HGT+khL1pJs8JdKRUKTATebZlyhLBiXiEw7XX45UZjf9dVTxm7QSUvlC9dRYtr7sj9JeC3QIO7Gyil5iildiqldubnS2c9IegYZfSPv2MhOHWABXcbPeTTNptdmbAiTQ53pdQkIE9rfcVlAFrrt7XWCVrrhICAgKZ+nBDWp8cNxih+6uvG1a3/Gw8fTII9n0v3SdFszRm5DwWmKKVSgc+A0UqpuS1SlRC2ws4e4u4w9nS94Vk4nQ6L5sC/e8PSRyFjp0zZiCZRugX+4iilRgJPaK0nXel1CQkJeufOnc3+PCGsVl0dnNxstBXe/xXUlENAH+M/ANG3QodAsysUJlBKJWqtExpzjKxzF6ItsbOD8GFw81vwxGGY/DI4exgbeL8YYWz9d3C5tDYQV9UiI/drJSN3IZoo/5Axmt/zGZzNA/dAiJlljOgDeptdnWhlTRm5S7gL0Z7UVsORb42gP7wSdC2EDjBCvt9PjDX1wupIuAthS87kwd7PYdfHUHAIHFyh3zQj6LsMNdogCKsg4S6ELdLaaEy2+2PYtxCqSsGnK8TdDjG3GRdPiXZNwl0IW1dVBilLjaBP3Qgo6D7aGM33mWhsMiLanaaEu0NrFSOEMIGTG8TcatyKTkDSJ8ZtwV3GxiJRM42gD442u1LRymTkLoS1q6uF4+sgaR6kLIPaSqMFQtzPIOoWcPM1u0JxFTItI4S4srIio1HZ7o8hew/YOxnTNXF3QLdRxhWzos2RcBdCXLucfbB7nrHiprwIPEMg5qfGF7G+3cyuTlxAwl0I0Xg1lXBohbF2/tga0HXQZZgxmu87BZzcza7Q5km4CyGapyQL9nxqBH3RcXDygMibjfn50AGydt4kEu5CiJahNZzcUt/AbBFUl4F/r/oGZrPAI8jsCm2KhLsQouVVlhoBv3sepG8FZQ+9boLY242f9o5mV2j1ZJ27EKLlOXtA/58bt4Ij9Q3MPoVDy8E9wGhFHHcHBEaYXam4gIzchRCNV1sDR1dD0lzjy9i6GghJMFbaRE4HFy+zK7QqMi0jhLC8M/mwb77RwCw/xWhg1ndKfQOzYUaPetEsEu5CCPNoDVm7jGmbfQuh8jR4dzHm5mNvA+/OZlfYbkm4CyHahupyo9XB7o/hxHpAQbeR9Q3MJoGji9kVtivyhaoQom1wdIXoW4zbqbTzDcwW3mPMx0fNNObng2Nl7XwrkZG7EMIy6uogdYMxbXNgidHALCjSGM1HzQR3P7MrbLNkWkYI0T6Un6pvYDbPmKe3c4Q+EyD2DqP/vL1MKlxIpmWEEO2Dqw8MuNe45e6vb2D2GRxYDB7BRgOziEnQMUaCvolk5C6EaBtqqoxNv3fPhaPfGg3MnDpA50EQPtRYVtkpDhyczK7U4mTkLoRovxycjPXxfacYm3+nboS0zZD6Paz5S/1rXKHzACPouwwxmpnJypvLknAXQrQ9HQKNK10jpxv3zxYYjcxSv4e0TbDu/wBtbDYSklA/sh9ijPKlRTEg0zJCiPao/BSc3GYEfer3xq5SuhbsHIypmy5DjVvYYHDxNLvaZrPoahmlVGfgI6AjUAe8rbV++UrHSLgLIVpFZSmkb6sf2X8PmbugrhqUHXSMhvD6aZyw69rlnrGWDvdgIFhrvUsp5QEkAtO01gcaOkbCXQhhEVVlkLHDCPrU743faysBBUH9jKA/N7rvEGB2tVdl0S9UtdbZQHb976VKqRQgBGgw3IUQwiKc3KDb9cYNjK0EMxPPz9nvngvb3zae8+9thH34MCPsPYPNq7sFtcicu1IqHNgARGqtSy55bg4wByAsLCw+LS2t2Z8nhBDNUlsNWUnn5+xPboWqUuM53271I/thxhe13mGmlgomXaGqlOoArAee11p/eaXXyrSMEKJNqq2B3H31I/vNxnRORbHxnFdnY0QfXj+N49vN4v1wLB7uSilHYBmwSmv94tVeL+EuhGgX6uog74AR8ufm7csKjOc8gs/P2YcPM/aWbeWwt+icu1JKAe8BKdcS7EII0W7Y2UHHSOM26D6jV33B4fNBn/a90RsHwM3/4jn7wL5tYoOS5lzENBT4GbBPKZVU/9hTWuvlza5KCCHaEqUgoLdxS7jbCPui4+encFK/h5QlxmtdvC9YjTPEWIppQn+c5qyW2QRII2YhhO1RCvy6G7f+PzMeKz5Z3y5hkxH4h+rHuU4exsVU5+bsO8WBvWOrlyjtB4QQoiV4hxm3mFnG/ZLsi+fsV39rPO7oBp0Hnu+PExLfKv1xJNyFEKI1eAZD1AzjBsZG4ic3n5+zX/s8Rn8cZ6MB2rn+OKEDjXX6zSThLoQQltAhAPpONW4AZUXG+vpzo/sN/zTaHNs5Qkj/82vtwwY16eMk3IUQwgxuvsbuU30mGPcrSur749TP2W9+FTb9B5R9k95ewl0IIdoCF0/oeaNxA6g6C+nbjaDnmUa/nfmLMYUQQvyYkzt0HwWj/9CkwyXchRDCCkm4CyGEFZJwF0IIKyThLoQQVkjCXQghrJCEuxBCWCEJdyGEsEIS7kIIYYUk3IUQwgpJuAshhBWScBdCCCsk4S6EEFZIwl0IIayQhLsQQlghCXchhLBCEu5CCGGFJNyFEMIKSbgLIYQVknAXQggr1KxwV0qNU0odUkodVUo92VJFCSGEaJ4mh7tSyh54HRgP9AV+qpTq21KFCSGEaLrmjNwHAke11se11lXAZ8DUlilLCCFEczg049gQIP2C+xnAoEtfpJSaA8ypv1uplEpuxmdaE3+gwOwi2gg5F+fJuThPzsV5vRt7QHPCXV3mMf2jB7R+G3gbQCm1U2ud0IzPtBpyLs6Tc3GenIvz5Fycp5Ta2dhjmjMtkwF0vuB+KJDVjPcTQgjRQpoT7juAnkqprkopJ2AWsKRlyhJCCNEcTZ6W0VrXKKUeAlYB9sD7Wuv9Vzns7aZ+nhWSc3GenIvz5FycJ+fivEafC6X1j6bJhRBCtHNyhaoQQlghCXchhLBCFgl3W29ToJR6XymVd+Eaf6WUr1LqW6XUkfqfPmbWaAlKqc5KqbVKqRSl1H6l1KP1j9viuXBRSm1XSu2pPxfP1j9uc+fiHKWUvVJqt1JqWf19mzwXSqlUpdQ+pVTSuSWQTTkXrR7u0qYAgA+AcZc89iSwRmvdE1hTf9/a1QC/1lpHAIOBB+v/LtjiuagERmutY4BYYJxSajC2eS7OeRRIueC+LZ+LUVrr2AvW+Tf6XFhi5G7zbQq01huAoksengp8WP/7h8A0S9ZkBq11ttZ6V/3vpRj/kEOwzXOhtdZn6u861t80NnguAJRSocBE4N0LHrbJc9GARp8LS4T75doUhFjgc9u6IK11NhihBwSaXI9FKaXCgThgGzZ6LuqnIZKAPOBbrbXNngvgJeC3QN0Fj9nqudDAN0qpxPr2LdCEc9Gc9gPX6praFAjboZTqACwEHtNalyh1ub8i1k9rXQvEKqW8gUVKqUiTSzKFUmoSkKe1TlRKjTS5nLZgqNY6SykVCHyrlDrYlDexxMhd2hRcXq5SKhig/meeyfVYhFLKESPY52mtv6x/2CbPxTla62JgHcb3MrZ4LoYCU5RSqRjTtqOVUnOxzXOB1jqr/mcesAhjarvR58IS4S5tCi5vCXBn/e93AotNrMUilDFEfw9I0Vq/eMFTtnguAupH7CilXIEbgIPY4LnQWv9eax2qtQ7HyIfvtNZ3YIPnQinlrpTyOPc7MBZIpgnnwiJXqCqlJmDMqZ1rU/B8q39oG6KU+hQYidHCNBf4E/AVMB8IA04Ct2itL/3S1aoopYYBG4F9nJ9bfQpj3t3WzkU0xhdj9hiDrPla678opfywsXNxofppmSe01pNs8VwopbphjNbBmDb/RGv9fFPOhbQfEEIIKyRXqAohhBWScBdCCCsk4S6EEFZIwl0IIayQhLsQQlghCXchhLBCEu5CCGGF/h/Scy0tdR8K8QAAAABJRU5ErkJggg==\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#slow\n",
"learn = synth_learner(cbs=ShowGraphCallback())\n",
"learn.fit(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(tensor([1.9139]), tensor([1.9139]), tensor([1.9139]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.predict(torch.tensor([[0.1]]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CSVLogger -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"class CSVLogger(Callback):\n",
" \"Log the results displayed in `learn.path/fname`\"\n",
" order=60\n",
" def __init__(self, fname='history.csv', append=False):\n",
" self.fname,self.append = Path(fname),append\n",
"\n",
" def read_log(self):\n",
" \"Convenience method to quickly access the log.\"\n",
" return pd.read_csv(self.path/self.fname)\n",
"\n",
" def before_fit(self):\n",
" \"Prepare file with metric names.\"\n",
" if hasattr(self, \"gather_preds\"): return\n",
" self.path.parent.mkdir(parents=True, exist_ok=True)\n",
" self.file = (self.path/self.fname).open('a' if self.append else 'w')\n",
" self.file.write(','.join(self.recorder.metric_names) + '\\n')\n",
" self.old_logger,self.learn.logger = self.logger,self._write_line\n",
"\n",
" def _write_line(self, log):\n",
" \"Write a line with `log` and call the old logger.\"\n",
" self.file.write(','.join([str(t) for t in log]) + '\\n')\n",
" self.file.flush()\n",
" os.fsync(self.file.fileno())\n",
" self.old_logger(log)\n",
"\n",
" def after_fit(self):\n",
" \"Close the file and clean up.\"\n",
" if hasattr(self, \"gather_preds\"): return\n",
" self.file.close()\n",
" self.learn.logger = self.old_logger"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The results are appended to an existing file if `append`, or they overwrite it otherwise."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 10.500990 | \n",
" 8.331024 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 1 | \n",
" 9.115092 | \n",
" 5.783391 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 2 | \n",
" 7.573916 | \n",
" 3.695323 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 3 | \n",
" 6.161108 | \n",
" 2.222861 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 4 | \n",
" 4.948495 | \n",
" 1.308835 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = synth_learner(cbs=CSVLogger())\n",
"learn.fit(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> CSVLogger.read_log
()\n",
"\n",
"Convenience method to quickly access the log."
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(CSVLogger.read_log)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = learn.csv_logger.read_log()\n",
"test_eq(df.columns.values, learn.recorder.metric_names)\n",
"for i,v in enumerate(learn.recorder.values):\n",
" test_close(df.iloc[i][:3], [i] + v)\n",
"os.remove(learn.path/learn.csv_logger.fname)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> CSVLogger.before_fit
()\n",
"\n",
"Prepare file with metric names."
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(CSVLogger.before_fit)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> CSVLogger.after_fit
()\n",
"\n",
"Close the file and clean up."
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(CSVLogger.after_fit)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converted 00_torch_core.ipynb.\n",
"Converted 01_layers.ipynb.\n",
"Converted 01a_losses.ipynb.\n",
"Converted 02_data.load.ipynb.\n",
"Converted 03_data.core.ipynb.\n",
"Converted 04_data.external.ipynb.\n",
"Converted 05_data.transforms.ipynb.\n",
"Converted 06_data.block.ipynb.\n",
"Converted 07_vision.core.ipynb.\n",
"Converted 08_vision.data.ipynb.\n",
"Converted 09_vision.augment.ipynb.\n",
"Converted 09b_vision.utils.ipynb.\n",
"Converted 09c_vision.widgets.ipynb.\n",
"Converted 10_tutorial.pets.ipynb.\n",
"Converted 10b_tutorial.albumentations.ipynb.\n",
"Converted 11_vision.models.xresnet.ipynb.\n",
"Converted 12_optimizer.ipynb.\n",
"Converted 13_callback.core.ipynb.\n",
"Converted 13a_learner.ipynb.\n",
"Converted 13b_metrics.ipynb.\n",
"Converted 14_callback.schedule.ipynb.\n",
"Converted 14a_callback.data.ipynb.\n",
"Converted 15_callback.hook.ipynb.\n",
"Converted 15a_vision.models.unet.ipynb.\n",
"Converted 16_callback.progress.ipynb.\n",
"Converted 17_callback.tracker.ipynb.\n",
"Converted 18_callback.fp16.ipynb.\n",
"Converted 18a_callback.training.ipynb.\n",
"Converted 18b_callback.preds.ipynb.\n",
"Converted 19_callback.mixup.ipynb.\n",
"Converted 20_interpret.ipynb.\n",
"Converted 20a_distributed.ipynb.\n",
"Converted 21_vision.learner.ipynb.\n",
"Converted 22_tutorial.imagenette.ipynb.\n",
"Converted 23_tutorial.vision.ipynb.\n",
"Converted 24_tutorial.siamese.ipynb.\n",
"Converted 24_vision.gan.ipynb.\n",
"Converted 30_text.core.ipynb.\n",
"Converted 31_text.data.ipynb.\n",
"Converted 32_text.models.awdlstm.ipynb.\n",
"Converted 33_text.models.core.ipynb.\n",
"Converted 34_callback.rnn.ipynb.\n",
"Converted 35_tutorial.wikitext.ipynb.\n",
"Converted 36_text.models.qrnn.ipynb.\n",
"Converted 37_text.learner.ipynb.\n",
"Converted 38_tutorial.text.ipynb.\n",
"Converted 39_tutorial.transformers.ipynb.\n",
"Converted 40_tabular.core.ipynb.\n",
"Converted 41_tabular.data.ipynb.\n",
"Converted 42_tabular.model.ipynb.\n",
"Converted 43_tabular.learner.ipynb.\n",
"Converted 44_tutorial.tabular.ipynb.\n",
"Converted 45_collab.ipynb.\n",
"Converted 46_tutorial.collab.ipynb.\n",
"Converted 50_tutorial.datablock.ipynb.\n",
"Converted 60_medical.imaging.ipynb.\n",
"Converted 61_tutorial.medical_imaging.ipynb.\n",
"Converted 65_medical.text.ipynb.\n",
"Converted 70_callback.wandb.ipynb.\n",
"Converted 71_callback.tensorboard.ipynb.\n",
"Converted 72_callback.neptune.ipynb.\n",
"Converted 73_callback.captum.ipynb.\n",
"Converted 97_test_utils.ipynb.\n",
"Converted 99_pytorch_doc.ipynb.\n",
"Converted dev-setup.ipynb.\n",
"Converted index.ipynb.\n",
"Converted quick_start.ipynb.\n",
"Converted tutorial.ipynb.\n"
]
}
],
"source": [
"#hide\n",
"from nbdev.export import notebook2script\n",
"notebook2script()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}