{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Serializing a Trial\n", "\n", "This guide explains the two different ways to save and reload your results from a Trial.\n", "\n", "**Note**: The easiest way to use this tutorial is as a colab notebook, which allows you to dive in with no setup.\n", "\n", "## Install Torchbearer\n", "\n", "First we install torchbearer if needed." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.4.0.dev\n" ] } ], "source": [ "try:\n", "    import torchbearer\n", "except ImportError:\n", "    !pip install -q torchbearer\n", "    import torchbearer\n", "\n", "print(torchbearer.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting up a Mock Example\n", "\n", "Let's assume we have a basic binary classification task where we have 100-dimensional samples as input and a binary label as output.\n", "Let's also assume that we would like to solve this problem with a 2-layer neural network.\n", "Finally, we also want to keep track of the sum of hidden outputs for some arbitrary reason. Therefore we use the state functionality of Torchbearer.\n", "\n", "We create a state key for the mock sum we want to track using state." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "MOCK = torchbearer.state_key('mock')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is our basic 2-layer neural network." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "class BasicModel(nn.Module):\n", "    def __init__(self):\n", "        super(BasicModel, self).__init__()\n", "        self.linear1 = nn.Linear(100, 25)\n", "        self.linear2 = nn.Linear(25, 1)\n", "\n", "    def forward(self, x, state):\n", "        x = self.linear1(x)\n", "        # The following step showcases a useless but simple example of a forward method that uses state\n", "        state[MOCK] = torch.sum(x)\n", "        x = self.linear2(x)\n", "        return torch.sigmoid(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create a random training dataset and put it in a DataLoader." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import TensorDataset, DataLoader\n", "\n", "n_sample = 100\n", "X = torch.rand(n_sample, 100)\n", "y = torch.randint(0, 2, [n_sample, 1]).float()\n", "traingen = DataLoader(TensorDataset(X, y))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's say we would like to save the model every time we get a better training loss. Torchbearer's [`Best` checkpoint callback](https://torchbearer.readthedocs.io/en/latest/code/callbacks.html?highlight=best#torchbearer.callbacks.checkpointers.Best) is perfect for this job. We then run the model for 3 epochs." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "480fae2126e84730aa57f45fbc2f7b8c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='0/3(t)'), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "08c5cbe35a0d4222adefba87afaaed24", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='1/3(t)'), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6b88f056ae4c4de783fc3066c1dc458c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='2/3(t)'), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "import torch.optim as optim\n", "import torch.nn.functional as F\n", "\n", "from torchbearer import Trial\n", "\n", "model = BasicModel()\n", "# Create a checkpointer that tracks the training loss and saves model.pt whenever the loss improves\n", "checkpointer = torchbearer.callbacks.checkpointers.Best(filepath='model.pt', monitor='loss')\n", "optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n", "torchbearer_trial = Trial(model, optimizer=optimizer, criterion=F.binary_cross_entropy, metrics=['loss'],\n", "                          callbacks=[checkpointer])\n", "torchbearer_trial.with_train_generator(traingen)\n", "_ = torchbearer_trial.run(epochs=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reloading the Trial for More Epochs\n", "\n", "Given we recreate the exact same Trial 
structure, we can easily resume our run from the last checkpoint. The following code block shows how it's done. Remember that the ``epochs`` parameter we pass to ``Trial.run`` is cumulative. In other words, the following run brings the total training to 6 epochs." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cea3e6b6670d487ba01e4be5b24f81fd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='3/6(t)'), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8cfec1e1359a4d7cb4f40416c19788d8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='4/6(t)'), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e8f7bff0ddd749068b7fc7e1c619dfea", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='5/6(t)'), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "state_dict = torch.load('model.pt')\n", "model = BasicModel()\n", "# Rebuild the optimizer so that it tracks the parameters of the new model\n", "optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n", "trial_reloaded = Trial(model, optimizer=optimizer, criterion=F.binary_cross_entropy, metrics=['loss'],\n", "                       callbacks=[checkpointer])\n", "trial_reloaded.load_state_dict(state_dict)\n", "trial_reloaded.with_train_generator(traingen)\n", "_ = trial_reloaded.run(epochs=6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Trying to Reload to a PyTorch Module\n", "\n", "We try to load the ``state_dict`` into a regular PyTorch Module, as described in 
PyTorch's own documentation [here](https://pytorch.org/docs/stable/notes/serialization.html)." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "'StateKey' object has no attribute 'startswith'\n" ] } ], "source": [ "model = BasicModel()\n", "try:\n", "    model.load_state_dict(state_dict)\n", "except AttributeError as e:\n", "    print(\"\\n\")\n", "    print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This gives an error. The reason is that the `state_dict` has Trial-related attributes that are unknown to a native PyTorch model. This is why we have the `save_model_params_only`\n", "option for our checkpointers. We try again with that option." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f5327c79a4274526b36988cc2eabe0d6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='0/3(t)'), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e9ec5a7202164ffba10a765c42555d80", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='1/3(t)'), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "63ad0eb527ea49ea9f1befc019158126", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='2/3(t)'), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "model = BasicModel()\n", "checkpointer = 
torchbearer.callbacks.checkpointers.Best(filepath='model.pt', monitor='loss', save_model_params_only=True)\n", "optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n", "torchbearer_trial = Trial(model, optimizer=optimizer, criterion=F.binary_cross_entropy, metrics=['loss'],\n", "                          callbacks=[checkpointer])\n", "torchbearer_trial.with_train_generator(traingen)\n", "torchbearer_trial.run(epochs=3)\n", "\n", "# Try once again to load the saved parameters into a plain PyTorch module\n", "state_dict = torch.load('model.pt')\n", "model = BasicModel()\n", "_ = model.load_state_dict(state_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "No errors this time, but we still have to test. Here we create a test sample and run it through the model." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "forward() missing 1 required positional argument: 'state'\n" ] } ], "source": [ "X_test = torch.rand(5, 100)\n", "try:\n", "    model(X_test)\n", "except TypeError as e:\n", "    print(\"\\n\")\n", "    print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we get a different error, stating that we should also be passing ``state`` as an argument to the module's forward method. This should not be a surprise, as we defined the ``state`` parameter in the forward method of ``BasicModel`` as a required argument.\n", "\n", "## Robust Signature for Module\n", "\n", "We define the model with a better signature this time, so it gracefully handles the problem above." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "class BetterSignatureModel(nn.Module):\n", "    def __init__(self):\n", "        super(BetterSignatureModel, self).__init__()\n", "        self.linear1 = nn.Linear(100, 25)\n", "        self.linear2 = nn.Linear(25, 1)\n", "\n", "    def forward(self, x, state=None):\n", "        x = self.linear1(x)\n", "        # An optional state argument lets the model run both inside a Trial and as a plain PyTorch module\n", "        if state is not None:\n", "            state[MOCK] = torch.sum(x)\n", "        x = self.linear2(x)\n", "        return torch.sigmoid(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we wrap it up once again to test the new definition of the model." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "79d831aee9984726ae207996b4b5627d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='0/3(t)'), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bbe9d3dceecc4c96a4f3b908d1b7b981", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='1/3(t)'), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bfd8bdbf1ffc4ceaa2326e2448013de1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='2/3(t)'), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/plain": [ "tensor([[0.4943],\n", "        [0.5023],\n", "        [0.5058],\n", "        [0.4995],\n", "        [0.4983]], grad_fn=<SigmoidBackward>)" 
] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = BetterSignatureModel()\n", "checkpointer = torchbearer.callbacks.checkpointers.Best(filepath='model.pt', monitor='loss', save_model_params_only=True)\n", "optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n", "torchbearer_trial = Trial(model, optimizer=optimizer, criterion=F.binary_cross_entropy, metrics=['loss'],\n", "                          callbacks=[checkpointer])\n", "torchbearer_trial.with_train_generator(traingen)\n", "torchbearer_trial.run(epochs=3)\n", "\n", "# This time, the forward function should work without the need for a state argument\n", "state_dict = torch.load('model.pt')\n", "model = BetterSignatureModel()\n", "model.load_state_dict(state_dict)\n", "X_test = torch.rand(5, 100)\n", "model(X_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }