{ "cells": [ { "cell_type": "markdown", "metadata": { "deletable": false, "editable": false, "nbgrader": { "checksum": "2e433cbdd5b1bc32ffca46551a708e45", "grade": false, "grade_id": "cell-c290b2da5fe2edf3", "locked": true, "schema_version": 1, "solution": false } }, "source": [ "# Part 2: Loading a saved model" ] }, { "cell_type": "markdown", "metadata": { "deletable": false, "editable": false, "nbgrader": { "checksum": "f548e96caa143d15f7b0ec97bd4e149f", "grade": false, "grade_id": "cell-ba8019f876600bdf", "locked": true, "schema_version": 1, "solution": false } }, "source": [ "__Before starting, we recommend you enable GPU acceleration if you're running on Colab. You'll also need to upload the `bettercnn.weights` file you downloaded previously: run the following block and use its upload button to select the file.__" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "deletable": false, "editable": false, "nbgrader": { "checksum": "c0d2caf75989226e03a8ded7438278a5", "grade": false, "grade_id": "cell-7baa302f182176c7", "locked": true, "schema_version": 1, "solution": false } }, "outputs": [], "source": [ "# Execute this code block to install dependencies when running on Colab\n", "try:\n", "    import torch\n", "except ImportError:\n", "    !pip install -q torch torchvision\n", "\n", "try:\n", "    import torchbearer\n", "except ImportError:\n", "    !pip install torchbearer\n", "\n", "try:\n", "    from google.colab import files\n", "    uploaded = files.upload()\n", "except ImportError:\n", "    print(\"Not running on colab.
Ignoring.\")\n", "\n", "!wget http://comp6248.ecs.soton.ac.uk/labs/lab5/0.PNG\n", "!wget http://comp6248.ecs.soton.ac.uk/labs/lab5/1.PNG\n", "!wget http://comp6248.ecs.soton.ac.uk/labs/lab5/2.PNG\n", "!wget http://comp6248.ecs.soton.ac.uk/labs/lab5/3.PNG\n", "!wget http://comp6248.ecs.soton.ac.uk/labs/lab5/4.PNG\n", "!wget http://comp6248.ecs.soton.ac.uk/labs/lab5/5.PNG\n", "!wget http://comp6248.ecs.soton.ac.uk/labs/lab5/6.PNG\n", "!wget http://comp6248.ecs.soton.ac.uk/labs/lab5/7.PNG\n", "!wget http://comp6248.ecs.soton.ac.uk/labs/lab5/8.PNG\n", "!wget http://comp6248.ecs.soton.ac.uk/labs/lab5/9.PNG" ] }, { "cell_type": "markdown", "metadata": { "deletable": false, "editable": false, "nbgrader": { "checksum": "ce20db2649381e63c13307bcd496ab41", "grade": false, "grade_id": "cell-05dc06c4f046cee9", "locked": true, "schema_version": 1, "solution": false } }, "source": [ "## Reading models and propagating input\n", "\n", "At this point, we know how to train a model and how to save the resultant weights. Let's assume we're in the business of building a real system for handwritten character recognition; we need to be able to read in a previously trained model and forward propagate an image from outside the MNIST dataset through it in order to generate a prediction. Let's build some code to do just that. 
First we need to load the weights we saved in the previous part of the lab. The saved file contains only the model's parameters (its `state_dict`), not its code, so you'll need to copy-paste the `BetterCNN` `forward` method implementation from the previous workbook into the block below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "deletable": false, "nbgrader": { "checksum": "83517d0b5dd9912b14e753371891acda", "grade": false, "grade_id": "cell-25d7ce8447ab7c15", "locked": false, "schema_version": 1, "solution": true } }, "outputs": [], "source": [ "%matplotlib inline\n", "# automatically reload external modules if they change\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "import matplotlib.pyplot as plt\n", "from torch import nn\n", "\n", "# Model Definition\n", "class BetterCNN(nn.Module):\n", "\n", "    def __init__(self):\n", "        super(BetterCNN, self).__init__()\n", "        self.conv1 = nn.Conv2d(1, 30, (5, 5), padding=0)\n", "        self.conv2 = nn.Conv2d(30, 15, (3, 3), padding=0)\n", "        self.fc1 = nn.Linear(15 * 5**2, 128)\n", "        self.fc2 = nn.Linear(128, 50)\n", "        self.fc3 = nn.Linear(50, 10)\n", "\n", "    def forward(self, x):\n", "        # YOUR CODE HERE\n", "        raise NotImplementedError()\n", "\n", "# build the model and load the saved parameters\n", "model = BetterCNN()\n", "# map_location='cpu' lets weights saved on a GPU load on a CPU-only runtime\n", "model.load_state_dict(torch.load('bettercnn.weights', map_location='cpu'))\n", "\n", "# put model in eval mode\n", "model = model.eval()" ] }, { "cell_type": "markdown", "metadata": { "deletable": false, "editable": false, "nbgrader": { "checksum": "97971f78e75437a324758a0adf66779f", "grade": false, "grade_id": "cell-05d9eb94a8c5425c", "locked": true, "schema_version": 1, "solution": false } }, "source": [ "We've provided a set of images you can try.
Let's load the one corresponding to a '1', convert it to a tensor, and display it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "deletable": false, "editable": false, "nbgrader": { "checksum": "c711239daefae20c86dd0d9036bdfacd", "grade": false, "grade_id": "cell-4bd7da9513437358", "locked": true, "schema_version": 1, "solution": false } }, "outputs": [], "source": [ "from PIL import Image\n", "import torchvision\n", "\n", "transform = torchvision.transforms.ToTensor()\n", "im = transform(Image.open(\"1.PNG\"))\n", "\n", "plt.imshow(im[0], cmap=plt.get_cmap('gray'))" ] }, { "cell_type": "markdown", "metadata": { "deletable": false, "editable": false, "nbgrader": { "checksum": "56476fa5813e1c34650810d8eed151ef", "grade": false, "grade_id": "cell-6cf292098a212188", "locked": true, "schema_version": 1, "solution": false } }, "source": [ "Now we'll use the model to make a prediction. The model expects its input to have a batch dimension, so we use `unsqueeze(0)` to prepend one to the image. Recall that the model outputs one logit per class; the index of the largest logit is the predicted class."
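] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The two steps described above can also be tried in isolation. The following cell is a minimal, self-contained sketch: `dummy_image` and `dummy_logits` are hypothetical stand-ins (random pixels and made-up scores, not `1.PNG` or the output of the trained `BetterCNN`), used only to show the shape change produced by `unsqueeze(0)` and how `max(1)` picks out the predicted class. The `softmax` call is an optional extra, included to show that converting logits to probabilities does not change which class wins." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "\n", "# Hypothetical stand-in for one single-channel 28x28 image (random pixels, not a real digit)\n", "dummy_image = torch.rand(1, 28, 28)\n", "dummy_batch = dummy_image.unsqueeze(0)  # prepend a batch dimension of size 1\n", "print(dummy_image.shape, '->', dummy_batch.shape)  # torch.Size([1, 28, 28]) -> torch.Size([1, 1, 28, 28])\n", "\n", "# Hypothetical logits for a batch of one image over the 10 digit classes\n", "dummy_logits = torch.tensor([[0.1, 2.5, -1.0, 0.3, 0.0, -0.5, 1.2, 0.7, -2.0, 0.4]])\n", "max_logit, predicted = dummy_logits.max(1)  # max over dim 1 returns (values, indices)\n", "probs = F.softmax(dummy_logits, dim=1)  # optional: turn logits into probabilities\n", "\n", "# softmax is monotonic, so the index of the largest value (the prediction) is unchanged\n", "assert predicted.item() == probs.argmax(1).item()\n", "print('predicted class:', predicted.item())  # predicted class: 1"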
] }, { "cell_type": "code", "execution_count": null, "metadata": { "deletable": false, "editable": false, "nbgrader": { "checksum": "ff7a048101ce27cbe780653cd8468eca", "grade": false, "grade_id": "cell-92042c8f7b09a26d", "locked": true, "schema_version": 1, "solution": false } }, "outputs": [], "source": [ "batch = im.unsqueeze(0)\n", "# gradients are not needed for inference\n", "with torch.no_grad():\n", "    predictions = model(batch)\n", "\n", "print(\"logits:\", predictions)\n", "\n", "_, predicted_class = predictions.max(1)\n", "\n", "print(\"predicted class:\", predicted_class.item())" ] }, { "cell_type": "markdown", "metadata": { "deletable": false, "editable": false, "nbgrader": { "checksum": "e83d2f7a42effd9370d399d819d592e3", "grade": false, "grade_id": "cell-dbe9d30ed68054cf", "locked": true, "schema_version": 1, "solution": false } }, "source": [ "We've provided images `0.PNG` through to `9.PNG` for you to play with. Use the following code block to classify each image and print the results." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "deletable": false, "nbgrader": { "checksum": "2c942b974467a9a456a2e9e4389b6f7c", "grade": false, "grade_id": "cell-27a634204f34e601", "locked": false, "schema_version": 1, "solution": true } }, "outputs": [], "source": [ "# YOUR CODE HERE\n", "raise NotImplementedError()" ] }, { "cell_type": "markdown", "metadata": { "deletable": false, "editable": false, "nbgrader": { "checksum": "7126e6798f657baa53e1d4e360614e02", "grade": false, "grade_id": "cell-b1f2a02a37c3f405", "locked": true, "schema_version": 1, "solution": false } }, "source": [ "__Answer the following question (enter your answer in the box below):__\n", "\n", "__1.__ How many images were misclassified? Which images?"
] }, { "cell_type": "markdown", "metadata": { "deletable": false, "nbgrader": { "checksum": "2bfced6d006c0b957ba58544a03b12f2", "grade": true, "grade_id": "cell-092c153f41f2dd1d", "locked": false, "points": 2, "schema_version": 1, "solution": true } }, "source": [ "YOUR ANSWER HERE" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 2 }