{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using SGD on MNIST"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Background"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ... about machine learning (a reminder from lesson 1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The good news is that modern machine learning can be distilled down to a couple of key techniques that are of very wide applicability. Recent studies have shown that the vast majority of datasets can be best modeled with just two methods:\n",
"\n",
"1. Ensembles of decision trees (i.e. Random Forests and Gradient Boosting Machines), mainly for structured data (such as you might find in a database table at most companies). We looked at random forests in depth as we analyzed the Blue Book for Bulldozers dataset.\n",
"\n",
"2. Multi-layered neural networks learnt with SGD (i.e. shallow and/or deep learning), mainly for unstructured data (such as audio, vision, and natural language)\n",
"\n",
"In this lesson, we will start on the 2nd approach (a neural network with SGD) by analyzing the MNIST dataset. You may be surprised to learn that **logistic regression is actually an example of a simple neural net**!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### About The Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this lesson, we will be working with MNIST, a classic data set of hand-written digits. Solutions to this problem are used by banks to automatically recognize the amounts on checks, and by the postal service to automatically recognize zip codes on mail."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A matrix can represent an image: each entry in the grid corresponds to one pixel.\n",
"\n",
"(Source: [Adam Geitgey](https://medium.com/@ageitgey/machine-learning-is-fun-part-3-deep-learning-and-convolutional-neural-networks-f40359318721))"
]
},
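{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of this idea (the tiny 5x5 array `tiny_img` below is made up for illustration and is not part of MNIST), each entry of a NumPy matrix can be read as the brightness of one pixel:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical 5x5 grayscale image: 0.0 is black, 1.0 is white.\n",
"tiny_img = np.zeros((5, 5), dtype=np.float32)\n",
"tiny_img[1:4, 2] = 1.0   # a vertical white stroke in the middle column\n",
"\n",
"print(tiny_img.shape)    # (rows, columns)\n",
"print(tiny_img)\n",
"```\n",
"\n",
"Displaying `tiny_img` with matplotlib's `imshow` and a gray colormap should show a short vertical white stroke on a black background."
]
},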
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports and data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will be using the fastai library, which is still in pre-alpha. If you are accessing this course notebook, you probably already have it downloaded, as it is in the same GitHub repo as the course materials.\n",
"\n",
"We use [symbolic links](https://kb.iu.edu/d/abbe) (often called *symlinks*) to make it possible to import these from your current directory. For instance, I ran:\n",
"\n",
"    ln -s ../../fastai\n",
"\n",
"in the terminal, within the directory I'm working in, `home/fastai/courses/ml1`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai.imports import *\n",
"from fastai.torch_imports import *\n",
"from fastai.io import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = 'data/mnist/'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's download, unzip, and format the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.makedirs(path, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"URL='http://deeplearning.net/data/mnist/'\n",
"FILENAME='mnist.pkl.gz'\n",
"\n",
"def load_mnist(filename):\n",
"    # Use a context manager so the gzip file is closed once the pickle is read.\n",
"    with gzip.open(filename, 'rb') as f:\n",
"        return pickle.load(f, encoding='latin-1')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"get_data(URL+FILENAME, path+FILENAME)\n",
"((x, y), (x_valid, y_valid), _) = load_mnist(path+FILENAME)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(numpy.ndarray, (50000, 784), numpy.ndarray, (50000,))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(x), x.shape, type(y), y.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Normalize"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Many machine learning algorithms behave better when the data is *normalized*, that is, when the mean is 0 and the standard deviation is 1. To normalize the data, we subtract the mean of our training set and then divide by its standard deviation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mean = x.mean()\n",
"std = x.std()\n",
"\n",
"x=(x-mean)/std\n",
"mean, std, x.mean(), x.std()"
]
},
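{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that for consistency (with the parameters we learn when training), we normalize the validation set with the mean and standard deviation of the *training* set, not with its own statistics. This is why the validation mean and standard deviation below are close to, but not exactly, 0 and 1."
]
},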
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(-0.0058509219, 0.99243325)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_valid = (x_valid-mean)/std\n",
"x_valid.mean(), x_valid.std()"
]
},
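{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same pattern can be sketched on synthetic data. The arrays `train` and `valid` below are made-up stand-ins for the MNIST arrays, so this is only an illustration of reusing the training statistics, not a reproduction of the numbers above:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.RandomState(0)\n",
"# Synthetic stand-ins for the training and validation sets (assumed data).\n",
"train = rng.uniform(0, 1, size=(1000, 4)).astype(np.float32)\n",
"valid = rng.uniform(0, 1, size=(200, 4)).astype(np.float32)\n",
"\n",
"# The statistics are computed on the training set only.\n",
"train_mean, train_std = train.mean(), train.std()\n",
"\n",
"train_norm = (train - train_mean) / train_std\n",
"valid_norm = (valid - train_mean) / train_std   # reuse the training statistics\n",
"\n",
"print(train_norm.mean(), train_norm.std())   # very close to 0 and 1\n",
"print(valid_norm.mean(), valid_norm.std())   # near, but not exactly, 0 and 1\n",
"```\n",
"\n",
"Normalizing the validation set with its own mean and standard deviation would put it on a slightly different scale from the one the model was trained on."
]
},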
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Look at the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In any sort of data science work, it's important to look at your data, to make sure you understand the format, how it's stored, what type of values it holds, etc. To make it easier to work with, let's reshape it into 2d images from the flattened 1d format."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Helper methods"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def show(img, title=None):\n",
"    plt.imshow(img, cmap=\"gray\")\n",
"    if title is not None: plt.title(title)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plots(ims, figsize=(12,6), rows=2, titles=None):\n",
"    f = plt.figure(figsize=figsize)\n",
"    # Round up so every image gets a subplot even if len(ims) is not a multiple of rows.\n",
"    cols = (len(ims) + rows - 1) // rows\n",
"    for i in range(len(ims)):\n",
"        sp = f.add_subplot(rows, cols, i+1)\n",
"        sp.axis('Off')\n",
"        if titles is not None: sp.set_title(titles[i], fontsize=16)\n",
"        plt.imshow(ims[i], cmap='gray')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Plots "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(10000, 784)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_valid.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(10000, 28, 28)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_imgs = np.reshape(x_valid, (-1,28,28)); x_imgs.shape"
]
},
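{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the `-1` in the reshape above does, here is a small sketch on a made-up array (`flat` is hypothetical and only mimics the shape of flattened MNIST images):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical batch of 3 flattened 28x28 images (784 values each).\n",
"flat = np.arange(3 * 784, dtype=np.float32).reshape(3, 784)\n",
"\n",
"# -1 asks NumPy to infer that dimension from the total number of elements.\n",
"imgs = np.reshape(flat, (-1, 28, 28))\n",
"print(imgs.shape)\n",
"\n",
"# Row-major order: pixel k of a flat image lands at row k // 28, column k % 28.\n",
"k = 100\n",
"assert imgs[0, k // 28, k % 28] == flat[0, k]\n",
"```\n",
"\n",
"Because the values are only rearranged, reshaping back with `imgs.reshape(-1, 784)` recovers the flattened layout."
]
},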
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADl1JREFUeJzt3X3IXHV6xvHr8mUxMaLRVE1iNLtPhL5htAapKEVd3Nqt\nEFdwMWBJoyVrWaGrVSpBUBTBlu5qK1SJGMzirltN3FVWRcXa+gZifGGNGzer4saYJ09Qi4mou43e\n/eM5WR7jzG8mM2fmTHJ/P/AwM+eeM+dmyJVzzvzOzM8RIQD57Nd0AwCaQfiBpAg/kBThB5Ii/EBS\nhB9IivADSRF+tGT7btvjtrfb3mj775ruCfUyF/mgFdt/IumNiPit7T+U9N+S/joiXmy2M9SFPT9a\niojXIuK3ux5Wf2MNtoSaEX60Zfs/bH8s6XVJ45Iebrgl1IjDfhTZ3l/SqZLOkPTPEfF/zXaEurDn\nR1FEfBYRz0g6RtLfN90P6kP40a0DxDn/PoXw40tsH2n7QtszbO9v+y8lLZH0X033hvpwzo8vsf0H\nktZIWqjJHcRvJP17RNzRaGOoFeEHkuKwH0iK8ANJEX4gKcIPJHXAMDdmm08XgQGLCHfzvL72/LbP\nsf0r22/Yvrqf1wIwXD0P9VXXfG+UdLakzZJekLQkIn5ZWIc9PzBgw9jzn6LJ73u/FRG/k/QTSYv7\neD0AQ9RP+OdKemfK483Vsi+wvdz2Otvr+tgWgJr184Ffq0OLLx3WR8RKSSslDvuBUdLPnn+zpHlT\nHh8jaUt/7QAYln7C/4Kk421/1fZXJF0o6cF62gIwaD0f9kfETtuXSXpU0v6SVkXEa7V1BmCghvqt\nPs75gcEbykU+APZehB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4\ngaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kNdQputGbhQsXFuuXX35529rY2Fhx3enTpxfrK1as\nKNYPPfTQYv2RRx5pW9uxY0dxXQwWe34gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIpZekfAjBkzivVN\nmzYV64cddlid7dTq3XffbVsrXZ8gSWvWrKm7nRS6naW3r4t8bL8taYekzyTtjIhF/bwegOGp4wq/\nMyPivRpeB8AQcc4PJNVv+EPSY7ZftL281RNsL7e9zva6PrcFoEb9HvafFhFbbB8p6XHbr0fEU1Of\nEBErJa2U+MAPGCV97fkjYkt1u03STyWdUkdTAAav5/DbPtj2IbvuS/qGpPV1NQZgsHoe57f9NU3u\n7aXJ04cfR8SNHdbhsL+FQw45pFh/+OGHi/X333+/be3ll18urnvSSScV68cdd1yxPm/evGJ92rRp\nbWsTExPFdU899dRivdP6WQ18nD8i3pJU/pUJACOLoT4gKcIPJEX4gaQIP5AU4QeS4iu96MusWbOK\n9auuuqqnmiQtW7asWF+9enWxnlW3Q33s+YGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKaboRl/ee6/8\n263PPvts21qncf5OXzdmnL8/7PmBpAg/kBThB5Ii/EBShB9IivADSRF+ICnG+dGXmTNnFusrVqzo\n+bXnzJnT87rojD0/kBThB5Ii/EBShB9IivADSRF+ICnCDyTF7/ajaOHC8kTM9913X7G+YMGCtrWN\nGzcW1z377LOL9XfeeadYz6q23+23vcr2Ntvrpyw73Pbjtn9d3Zav9AAwcro57L9L0jm7Lbta0hMR\ncbykJ6rHAPYiHcMfEU9J+mC3xYsl7foNpdWSzqu5LwAD1uu1/UdFxLgkRcS47SPbPdH2cknLe9wO\ngAEZ+Bd7ImKlpJUSH/gBo6TXob4J27MlqbrdVl9LAIah1/A/KGlpdX+ppAfqaQfAsHQc57d9j6Qz\nJM2SNCHpWkk/k3SvpGMlbZJ0QUTs/qFgq9fi
sH/ELF26tFi//vrri/V58+YV65988knb2rnnnltc\n98knnyzW0Vq34/wdz/kjYkmb0tf3qCMAI4XLe4GkCD+QFOEHkiL8QFKEH0iKn+7eB8yYMaNt7cor\nryyue8011xTr++1X3j988EF5hPf0009vW3v99deL62Kw2PMDSRF+ICnCDyRF+IGkCD+QFOEHkiL8\nQFKM8+8D7rrrrra1888/v6/XXrNmTbF+yy23FOuM5Y8u9vxAUoQfSIrwA0kRfiApwg8kRfiBpAg/\nkBTj/PuAsbGxgb32bbfdVqw/99xzA9s2Bos9P5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kxTj/PuCx\nxx5rW1u4cOHAXlvqfB3ATTfd1La2ZcuWnnpCPTru+W2vsr3N9vopy66z/a7tV6q/bw62TQB16+aw\n/y5J57RYfnNEnFj9PVxvWwAGrWP4I+IpSeU5mQDsdfr5wO8y27+oTgtmtnuS7eW219le18e2ANSs\n1/DfJmlM0omSxiV9v90TI2JlRCyKiEU9bgvAAPQU/oiYiIjPIuJzSXdIOqXetgAMWk/htz17ysNv\nSVrf7rkARpMjovwE+x5JZ0iaJWlC0rXV4xMlhaS3JX0nIsY7bswubww9mTZtWtva3XffXVz35JNP\nLtaPPfbYnnraZevWrW1ry5YtK6776KOP9rXtrCLC3Tyv40U+EbGkxeI797gjACOFy3uBpAg/kBTh\nB5Ii/EBShB9IquNQX60bY6hv6A466KBi/YADygM+27dvr7OdL/j000+L9SuuuKJYv/322+tsZ5/R\n7VAfe34gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIpxfhSdcMIJxfrNN99crJ955pk9b3vTpk3F+vz5\n83t+7X0Z4/wAigg/kBThB5Ii/EBShB9IivADSRF+ICnG+UfA9OnTi/WPP/54SJ3suZkz287UJkla\ntWpV29rixYv72vbcuXOL9fHxjr8mv09inB9AEeEHkiL8QFKEH0iK8ANJEX4gKcIPJNVxll7b8yT9\nUNLRkj6XtDIi/s324ZL+U9J8TU7T/e2I+N/Btbr3GhsbK9afeeaZYv2hhx4q1tevX9+21mms+5JL\nLinWDzzwwGK901j7ggULivWSN998s1jPOo5fl272/Dsl/WNE/JGkP5f0Xdt/LOlqSU9ExPGSnqge\nA9hLdAx/RIxHxEvV/R2SNkiaK2mxpNXV01ZLOm9QTQKo3x6d89ueL+kkSc9LOioixqXJ/yAkHVl3\ncwAGp+M5/y62Z0haK+l7EbHd7uryYdleLml5b+0BGJSu9vy2D9Rk8H8UEfdXiydsz67qsyVta7Vu\nRKyMiEURsaiOhgHUo2P4PbmLv1PShoj4wZTSg5KWVveXSnqg/vYADEo3h/2nSfobSa/afqVatkLS\nTZLutX2JpE2SLhhMi3u/Cy4ovzVHH310sX7xxRfX2c4e6XR6189Xwj/66KNi/dJLL+35tdFZx/BH\nxDOS2v0L+Hq97QAYFq7wA5Ii/EBShB9IivADSRF+ICnCDyTV9eW96N0RRxzRdAsDs3bt2mL9hhtu\naFvbtq3lRaG/t3Xr1p56QnfY8wNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUkzRPQSdfv76rLPOKtYv\nuuiiYn3OnDltax9++GFx3U5uvfXWYv3pp58u1nfu3NnX9rHnmKIbQBHhB5Ii/EBShB9IivADSRF+\nICnCDyTFOD+wj2GcH0AR4QeSIvxAUoQfSIrwA0kRfiApwg8k1TH8tufZftL2Btuv2f6Havl1tt+1\n/Ur1983BtwugLh0v8rE9W9LsiHjJ9iGSXpR0nqRvS/ooIv61641xkQ8wcN1e5NNxxp6IGJc0Xt3f\nYXuDpLn9tQegaXt0zm97vqSTJD1fLbrM9i9sr7I9s806y22vs72ur04B1Krra/ttz5D0P5JujIj7\nbR8l6T1J
IekGTZ4aXNzhNTjsBwas28P+rsJv+0BJP5f0aET8oEV9vqSfR8Sfdngdwg8MWG1f7LFt\nSXdK2jA1+NUHgbt8S9L6PW0SQHO6+bT/dElPS3pV0ufV4hWSlkg6UZOH/W9L+k714WDptdjzAwNW\n62F/XQg/MHh8nx9AEeEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kR\nfiCpjj/gWbP3JP1myuNZ1bJRNKq9jWpfEr31qs7ejuv2iUP9Pv+XNm6vi4hFjTVQMKq9jWpfEr31\nqqneOOwHkiL8QFJNh39lw9svGdXeRrUvid561UhvjZ7zA2hO03t+AA0h/EBSjYTf9jm2f2X7DdtX\nN9FDO7bftv1qNe14o/MLVnMgbrO9fsqyw20/bvvX1W3LORIb6m0kpm0vTCvf6Hs3atPdD/2c3/b+\nkjZKOlvSZkkvSFoSEb8caiNt2H5b0qKIaPyCENt/IekjST/cNRWa7X+R9EFE3FT9xzkzIv5pRHq7\nTns4bfuAems3rfzfqsH3rs7p7uvQxJ7/FElvRMRbEfE7ST+RtLiBPkZeRDwl6YPdFi+WtLq6v1qT\n/3iGrk1vIyEixiPiper+Dkm7ppVv9L0r9NWIJsI/V9I7Ux5vVoNvQAsh6THbL9pe3nQzLRy1a1q0\n6vbIhvvZXcdp24dpt2nlR+a962W6+7o1Ef5WUwmN0njjaRHxZ5L+StJ3q8NbdOc2SWOanMNxXNL3\nm2ymmlZ+raTvRcT2JnuZqkVfjbxvTYR/s6R5Ux4fI2lLA320FBFbqtttkn6qydOUUTKxa4bk6nZb\nw/38XkRMRMRnEfG5pDvU4HtXTSu/VtKPIuL+anHj712rvpp635oI/wuSjrf9VdtfkXShpAcb6ONL\nbB9cfRAj2wdL+oZGb+rxByUtre4vlfRAg718wahM295uWnk1/N6N2nT3jVzhVw1l3CJpf0mrIuLG\noTfRgu2vaXJvL01+3fnHTfZm+x5JZ2jyK58Tkq6V9DNJ90o6VtImSRdExNA/eGvT2xnaw2nbB9Rb\nu2nln1eD712d093X0g+X9wI5cYUfkBThB5Ii/EBShB9IivADSRF+ICnCDyT1/x2VQ9c6BSuMAAAA\nAElFTkSuQmCC\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show(x_imgs[0], y_valid[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(10000,)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_valid.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's the digit 3! And that's stored in the y value:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_valid[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can look at part of an image:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-0.42452, -0.42452, -0.42452, -0.42452, 0.17294],\n",
" [-0.42452, -0.42452, -0.42452, 0.78312, 2.43567],\n",
" [-0.42452, -0.27197, 1.20261, 2.77889, 2.80432],\n",
" [-0.42452, 1.76194, 2.80432, 2.80432, 1.73651],\n",
" [-0.42452, 2.20685, 2.80432, 2.80432, 0.40176]], dtype=float32)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_imgs[0,10:15,10:15]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAD8CAYAAABaQGkdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACPxJREFUeJzt3UGIVYUex/Hfz3lKgg9aPBfhyDOi4knwFEQCd9LCMgpc\nKRQtgtm8wCCMAjfRPtrUYqhIKIqoFhE9QiiJoFepWeSbCokeSYE+MqpNMvV/i7kL6TneM95z5tzz\n4/uBgbl6uP6K+c65987ljKtKADKt6XsAgO4QOBCMwIFgBA4EI3AgGIEDwQgcCEbgQDACB4L9qYs7\ntc3b4zA469ev73tCYxcvXtTi4qLHHddJ4MAQ3XTTTX1PaOyrr75qdBwP0YFgBA4EI3AgGIEDwQgc\nCEbgQDACB4IROBCMwIFgBA4EI3AgGIEDwQgcCEbgQDACB4IROBCsUeC299j+0vYZ2490PQpAO8YG\nbntG0lOSbpe0VdIB21u7HgZgck3O4Dslnamqr6vqoqSXJd3d7SwAbWgS+CZJ315y++zozwBMuSYX\nXbzclRv/76qptuckzU28CEBrmgR+VtLmS27PSvrujwdV1bykeYnLJgPToslD9I8l3Wj7etvrJO2X\n9Ea3swC0YewZvKoWbT8g6W1JM5Keq6rTnS8DMLFGv/igqt6S9FbHWwC0jHeyAcEIHAhG4EAwAgeC\nETgQjMCBYAQOBCNwIBiBA8EIHAhG4EAwAgeCETgQjMCBYAQOBCNwIBiBA8EaXdEFuFr33Xdf3xMa\ne/zxx/ue0NjevXsbHccZHAhG4EAwAgeCETgQjMCBYAQOBCNwIBiBA8EIHAhG4EAwAgeCETgQjMCB\nYAQOBCNwIBiBA8EIHAg2NnDbz9k+Z/vz1RgEoD1NzuDPS9rT8Q4AHRgbeFW9J+mHVdgCoGU8BweC\ntXZVVdtzkubauj8Ak2st8KqalzQvSbarrfsFcPV4iA4Ea/JjspckfSDpZttnbd/f/SwAbRj7EL2q\nDqzGEADt4yE6EIzAgWAEDgQjcCAYgQPBCBwIRuBAMAIHghE4EIzAgWAEDgQjcCAYgQPBCBwIRuBA\nMAIHgrV2Tbah2rBhQ98TVuTQoUN9T1iRw4cP9z2hsTVrhnO+W7duXaPjhvNfBGDFCBwIRuBAMAIH\nghE4EIzAgWAEDgQjcCAYgQPBCBwIRuBAMAIHghE4EIzAgWAEDgQjcCAYgQPBxgZue7Ptd20v2D5t\n++BqDAMwuSaXbFqU9FBVnbT9Z0knbB+tqn93vA3AhMaewavq+6o6Ofr8Z0kLkjZ1PQzA5Fb0HNz2\nFknbJX3YxRgA7Wp8VVXbGyS9JunBqvrpMn8/J2muxW0AJtQocNtrtRT3i1X1+uWOqap5SfOj46u1\nhQCuWpNX0S3pWUkLVfVE95MAtKXJc/Bdku6VtNv2qdHHHR3vAtCCsQ/Rq+p9SV6FLQBaxjvZgGAE\nDgQjcCAYgQPBCBwIRuBAMAIHghE4EIzAgWAEDgQjcCAYgQPBCBwIRuBAMAIHghE4EIzAgWCNr6qa\n6siRI31PWJF9+/b1PSHWq6++2veExi5cuNDoOM7gQDACB4IROBCMwIFgBA4EI3AgGIEDwQgcCEbg\nQDACB4IROBCMwIFgBA4EI3AgGIEDwQgcCEbgQLCxgdu+xvZHtj+1fdr2Y6sxDMDkmlyy6VdJu6vq\nF9trJb1v+59V9a+OtwGY0NjAq6ok/TK6uXb0UV2OAtCORs/Bbc/YPiXpnKSjVfVht7MAtKFR4FX1\nW1VtkzQraaftW/54jO0528dtH297JICrs6JX0avqR0nHJO25zN/NV9WOqtrR0jYAE2ryKvpG29eO\nPl8v6TZJX3Q9DMDkmryKfp2kI7ZntPQN4ZWqerPbWQDa0ORV9M8kbV+FLQBaxjvZgGAEDgQjcCAY\ngQPBCBwIRuBAMAIHghE4EIzAgWAEDgQjcCAY
gQPBCBwIRuBAMAIHghE4EKzJFV2i3XDDDX1PwJR4\n+umn+57Q2Pnz5xsdxxkcCEbgQDACB4IROBCMwIFgBA4EI3AgGIEDwQgcCEbgQDACB4IROBCMwIFg\nBA4EI3AgGIEDwQgcCNY4cNsztj+x/WaXgwC0ZyVn8IOSFroaAqB9jQK3PStpr6Rnup0DoE1Nz+BP\nSnpY0u8dbgHQsrGB275T0rmqOjHmuDnbx20fb20dgIk0OYPvknSX7W8kvSxpt+0X/nhQVc1X1Y6q\n2tHyRgBXaWzgVfVoVc1W1RZJ+yW9U1X3dL4MwMT4OTgQbEW/2aSqjkk61skSAK3jDA4EI3AgGIED\nwQgcCEbgQDACB4IROBCMwIFgBA4EI3AgGIEDwQgcCEbgQDACB4IROBCMwIFgBA4Ec1W1f6f2eUn/\naflu/yLpvy3fZ5eGtHdIW6Vh7e1q61+rauO4gzoJvAu2jw/piq1D2jukrdKw9va9lYfoQDACB4IN\nKfD5vges0JD2DmmrNKy9vW4dzHNwACs3pDM4gBUaROC299j+0vYZ24/0vedKbD9n+5ztz/veMo7t\nzbbftb1g+7Ttg31vWo7ta2x/ZPvT0dbH+t7UhO0Z25/YfrOPf3/qA7c9I+kpSbdL2irpgO2t/a66\noucl7el7REOLkh6qqr9JulXSP6b4/+2vknZX1d8lbZO0x/atPW9q4qCkhb7+8akPXNJOSWeq6uuq\nuqil33B6d8+bllVV70n6oe8dTVTV91V1cvT5z1r6QtzU76rLqyW/jG6uHX1M9QtItmcl7ZX0TF8b\nhhD4JknfXnL7rKb0i3DIbG+RtF3Sh/0uWd7o4e4pSeckHa2qqd068qSkhyX93teAIQTuy/zZVH/n\nHhrbGyS9JunBqvqp7z3LqarfqmqbpFlJO23f0vem5di+U9K5qjrR544hBH5W0uZLbs9K+q6nLXFs\nr9VS3C9W1et972miqn7U0m+5nebXOnZJusv2N1p6Wrnb9gurPWIIgX8s6Ubb19teJ2m/pDd63hTB\ntiU9K2mhqp7oe8+V2N5o+9rR5+sl3Sbpi35XLa+qHq2q2araoqWv2Xeq6p7V3jH1gVfVoqQHJL2t\npReBXqmq0/2uWp7tlyR9IOlm22dt39/3pivYJeleLZ1dTo0+7uh71DKuk/Su7c+09E3/aFX18qOn\nIeGdbECwqT+DA7h6BA4EI3AgGIEDwQgcCEbgQDACB4IROBDsf/DAx4Xphc/GAAAAAElFTkSuQmCC\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show(x_imgs[0,10:15,10:15])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsMAAAF0CAYAAADGqzQSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xu4lnPa//Hzq1raktCWSkU9QiV7a2ZsSqWO0siMyNCO\nEkY2ox0ifjFjVERmJrIrtLfJblI26UGm6FdKZSZFO0raqrVa398f93qOp1/nebeu1X2v+1r3+r5f\nx9ERH9fmxLWudXat+7y+znsvAAAAQIgOi7sAAAAAIC40wwAAAAgWzTAAAACCRTMMAACAYNEMAwAA\nIFg0wwAAAAgWzXDMnHPtnHNznHMbnHN7nHPfOecmO+dOjrs24GCcc+c75951zm1yzm1zzi10zvWK\nuy4gCufcpc65D51zOwqv38+dcxfFXRdwMM65C51z85xzu51zW5xzLzjnasVdV7ajGY5fDRH5l4jc\nJCKXiMhgEWkuIp845xrEWRiQjHPuNBGZLSIVRKSviFwuIgtE5GnnXP84awOK4py7QURelcS9t6uI\nXCEiU0Skcpx1AQfjnPuViLwrIlslcc/9o4j8WkTec84dHmdt2c6x6Ebp45xrKiLLReQO7/1f464H\nOJBz7v+IyB0iUsN7v2O//BMR8d77c2MrDjgI51xDEVkmIoO996PjrQaIzjk3W0Qaikgz731+YXam\niHwmIgO890/GWF5W48lw6bS58Pe8WKsAksuRxPW5+4B8q3BfQenWS0QKROSpuAsBiukcEfnn/zTC\nIiLe+wWS6Bm6xlZVGcA3rVLCOVfOOZfjnDtRRP4mIhtE5OWYywKSebbw98ecc3Wdc9Wdc31F5GIR\nGRVfWUCRciXxk7crnXPfOOfynXOrnHMD4i4MKMI+Edlr5HtE5JQM11Km8DGJUsI597mItC7821Ui\n0tl7vyzGkoCDKvzx3AwRqVcY5YlIf+/90/FVBRycc265iNSVRAMxRES+kcRnhvuJyK3e+zExlgck\n5Zz7TBIfQzt7v6yBiPxHRPK893xu+BDRDJcSzrn/EpEjRKSRJD6LWUtEcr33q+OsC7AU/gTjPUl8\n9vJxSXxcoouI9BeR67z3E2MsD0jKObdCRE4Ukcu999P3y98SkVYiUsfzjRGlkHPuahF5UUQeFJHH\nJDGA/3cROU8SzXClGMvLajTDpZBzrrqIrBaRl733/WIuB1Ccc1NE5HRJDHLk7ZdPFJF2IlLTe18Q\nV31AMs65/5bEZy+P8N5v3y8fKCKPikg97/26uOoDDsY5N0ISD8wqiogXkVdEpIqInOK9bxRnbdmM\nzwyXQt77rZL4qESTuGsBkjhVRL7cvxEu9JmIHC0iNTNfEhDJ0iS5K/ydP8Sh1PLe3y0ix4jIaZL4\nKUZ3SfykY16shWU5muFSqPAF2s0k8Vk2oDTaICItnXM5B+Rni8gvIrIl8yUBkcwo/L3dAXk7EfnO\ne78hw/UAxeK93+m9/7/e+43OufaS6Bd4O0oKysddQOicczNEZKGILBaRbSJykogMFJF8EeEdwyit\nxkpikYLXnXNPSuIzw51FpLuIjPLeWxPPQGnwpojMFZG/OeeOEZF/i0g3SSx61DPOwoCDcc61EpEO\nkugZRBJvRrlTRP7svZ8fW2FlAJ8Zjplz7i4R+Z2INJbEu1vXisj7IjKS4TmUZs65DiJylyRWTKwo\niZ9k/F1E/ua93xdnbcDBOOeOEJGRkmiCj5LEq9Ye8t5PirUw4CCcc80l8erVU0TkcCkcYPbeT4i1\nsDKAZhgAAADB4jPDAAAACBbNMAAAAIJFMwwAAIBg0QwDAAAgWBl9tZpzjmk9pMx774reKr24dpEO\nmb52uW6RDtxzka2iXrs8GQYAAECwaIYBAAAQLJphAAAABItmGAAAAMGiGQYAAECwaIYBAAAQLJph\nAAAABItmGAAAAMGiGQYAAECwMroCHQAApUnV
qlVV1rt3b5V16dLF3L9z584q27FjR+qFAcgYngwD\nAAAgWDTDAAAACBbNMAAAAIJFMwwAAIBg0QwDAAAgWLxNAgAQrGuvvVZlo0aNirx/8+bNVfbpp5+m\nVBOAzOLJMAAAAIJFMwwAAIBg0QwDAAAgWDTDAAAACBYDdClo0aKFygYOHGhu27hxY5VVrlxZZUOG\nDFHZkUceqbK33nrLPM/27dvNHABCd91116ls9OjRKsvLy1PZI488Yh5z4cKFKdcFIF48GQYAAECw\naIYBAAAQLJphAAAABItmGAAAAMFy3vvMncy5zJ0szapWraqyNWvWqKx69eqZKEe+//57M7cG+KZO\nnVrS5WSU995l+pzZfO1arOu0a9eu5ratWrVSWW5ursqsr5EtW7aorHbt2uZ5NmzYoLJnn31WZf/4\nxz9Utm/fPvOYpU2mr92ydt0WR+fOnVU2Y8YMle3atUtl99xzj8qKsypdWcM9F9kq6rXLk2EAAAAE\ni2YYAAAAwaIZBgAAQLBohgEAABAsBugiqlatmsrefPNNlW3evNncf9GiRSqzBpMaNGigsuOPP15l\nlSpVMs+zceNGlZ177rmRtssWDHMUz3HHHaeymTNnqsy6HpPZtm2byqxrvEKFCiqzvpZERGrWrKmy\nWrVqqeyqq65S2Ycffqiy9evXm+eJEwN06ZeTk2PmEyZMUFn37t1VNmfOHJW1adMm9cLKEO65yFYM\n0AEAAABFoBkGAABAsGiGAQAAECyaYQAAAASLAboscMwxx6jszjvvNLe18p49e6rsueeeS72wmDDM\nUTwLFy5UWYsWLVQ2e/Zsc//bb79dZT/++KPKrBXkiuPYY49V2VtvvaWypk2bqmzQoEEqe+KJJ1Kq\npyQwQJd+Q4cONfMRI0ao7MUXX1RZr169VJafn596YWUI99zU1alTR2U33nijua2V5+XlqcxaBffB\nBx9UmfU9QERk7dq1Zl6WMEAHAAAAFIFmGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABIu3SWSpzp07\nm7m1zO5jjz2msltvvTXtNWUKk83JWRPL33//vcomT56ssquvvto85r59+1Iv7BBNnDhRZVdeeaXK\nWrdurbIvvviiRGpKBW+TSM0ZZ5yhsnnz5pnbrl69WmXNmzdXWZzXd7bgnls8jRo1Utm4ceNU1rZt\n20yUI3v27DHz888/X2XJ3jyRrXibBAAAAFAEmmEAAAAEi2YYAAAAwaIZBgAAQLDKx10AinbUUUep\nbMiQIZH3r1u3bjrLQSnWsmVLlTmn5wfWrVunsrgHic455xyVde/eXWVz585VmfXvXRoH6BDdYYfp\nZzXWsts5OTnm/q+//rrK4r7GUfbUq1dPZUuWLFFZ+fK63Ro1apR5zMcffzzSeZo1a6ayv/zlLyqr\nXr26eR5rkNq6D//444/m/mUJT4YBAAAQLJphAAAABItmGAAAAMGiGQYAAECwWIGulGnRooXKpkyZ\norImTZqY+69YsUJl1io3a9euPYTqSgdWQyqegoIClW3atEllZ511lrn/mjVr0lpPtWrVzHz+/Pkq\nW7lypcqslfKsFZ+WLl16CNWVLFagiy7qaorJ3HLLLSobO3ZsSjWFintucmPGjFFZv379VNa3b1+V\nPf/882mvZ8CAASobPXq0uW25cuVUtnz5cpVZQ3Xbtm07hOoyjxXoAAAAgCLQDAMAACBYNMMAAAAI\nFs0wAAAAgsUAXYyuvfZald1///0qO/7441W2e/du85idOnVSmbViVzZjmKN4hg8frrK7775bZV9/\n/bW5f7t27VSWygDmu+++a+a/+c1vVNa6dWuVWas7ZQsG6KLr2bOnyp5++mmVzZ4929y/Q4cOKmMF\nukPDPVfkiCOOMHNryHfChAkqs1ZPzJRk9/YTTzwx0v7WSnm33357SjVlCgN0AAAAQBFohgEAABAs\nmmEAAAAE
i2YYAAAAwaIZBgAAQLB4m0SaVa1a1czvuOMOlQ0bNkxlhx2m/3yyZcsWleXm5prnsZZS\nLGuYbC6eihUrquy5555TWbdu3cz9V61apbILLrhAZevXr1fZk08+qbLrr7/ePM+dd96pMmuKOZvx\nNglb+fLlVbZs2TKVNWjQQGUnnHCCecziLN2Mg+Oem3y5+k8++URlbdu2Vdl7772X9pqi6tq1q5lP\nnz5dZVZPuHXrVpVZb6LYvHnzIVRXsnibBAAAAFAEmmEAAAAEi2YYAAAAwaIZBgAAQLD01AJS8uyz\nz5r5b3/720j7T506VWWjR49WWQiDckiPX375RWV9+vRRWc2aNc39rWWSP/jgA5VNmTJFZT169FDZ\ntGnTzPOUtWE5RGcNbzZu3Fhl/fv3V1ncg3Lt27dXWefOnVX29ttvq8xamtz6ekX8WrVqFXnbRYsW\nlWAlxffmm2+auTUcbX3dWdfkzp07Uy+sFOHJMAAAAIJFMwwAAIBg0QwDAAAgWDTDAAAACBYDdGlm\nffi8OMaNG6ey+fPnp3RM4EDbt29XWZcuXcxthw8frrJbb71VZYMGDYp07scffzzSdghH/fr1I22X\nk5NTwpUkd91115m5tcqitepjv379VGat7DVz5kzzPL169SqiQpSkefPmmXlBQYHK/vnPf6qsU6dO\nKrNW7SwJTZs2NXPrOm3Xrp3KKleurLKyNujJk2EAAAAEi2YYAAAAwaIZBgAAQLBohgEAABAsBujS\nzFpRSESkRYsWh7y/NVT30EMPmfuvW7cu0nmAA23bts3M77nnHpW1bdtWZSeffHKk87Rp08bMkw2o\noOxr0qRJpO0ytfJm9erVVfboo4+a21pDSPn5+Sqzhqpyc3NVZq3aKMIAXdyWLl1q5m+88YbKrGHk\nZcuWqcxalVDEXqVzzpw5KqtXr57KrGE5axVbEZE6deqozLp2X331VXP/soQnwwAAAAgWzTAAAACC\nRTMMAACAYNEMAwAAIFjOe5+5kzmXuZPFpFKlSmb+4osvqqx169Yqi7oS04YNG8y8Z8+eKnvnnXci\nHTNbeO9dps8ZwrWbTIcOHVQ2Y8YMlVWoUCHS8fbu3WvmN954o8omTJgQ6ZjZItPXbrZct7NmzVJZ\nq1atVFa3bt1MlGOusJhsgM66t48ZM0Zla9asUZk1QHXqqaea54lz9T3uuclZ3/NHjhypsltuuSWl\n82zZskVlNWrUSOmYliuuuEJl1kBftoh67fJkGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABItmGAAA\nAMFiOeY02717t5lfffXVKitfXv/nT7Yk7oFq165t5taU/2233aayp556KtJ5gAsvvFBl1ltounbt\nqjJrAtpavlTEXnb8xx9/VNnrr79u7o/sdfbZZ6ss2VtHSpt169ap7LjjjlPZ3//+d5WdfvrpKitr\nb/8p66zv+dbbSCZPnqwyqy9IplatWpG2y8vLU5n19SUicsIJJ6hs165dkWsqS3gyDAAAgGDRDAMA\nACBYNMMAAAAIFs0wAAAAgsVyzKXMaaedprJRo0apzBpqSsZaBrRhw4bFqqs0YWnQkmFdeyIiCxYs\nUJk17GYNjVis5T5FRJ5++mmVOaf/Vzdv3lxl1jVeGrEcs80aLuvUqZPKSmI5Zusas67lv/71rymd\nx/pe++STT6psyJAh5v7bt29P6fyp4J6b3V544QUztwb42rdvr7J333037TVlCssxAwAAAEWgGQYA\nAECwaIYBAAAQLJphAAAABIsV6CKqXLmyykpipZbFixerrFu3bip75plnzP27dOmisvr166usTp06\nKlu/fn2UElFGVatWzcytlRKnTp16yOeZMmWKmTdo0EBlDz/8sMpat26tsmwZoEN01atXV5k1CPTi\niy+a+1vX7ZVXXqmyGjVqqKxDhw5RShQRkZ07d6ps3rx5Kvvzn/+ssrlz50
Y+D5AJjRs3jruEWPBk\nGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABIsBOoP1AXJrIGLWrFkqW7JkiXlMazitd+/eKqtQoYLK\n6tWrp7ImTZqY57F88803kepB2Fq2bGnmGzZsUJn19ZCqsWPHqqxv374qGzBggMpmzJiR9nqQOYsW\nLVJZnz59VGatmGVlqdq2bZvKkg1+PvDAAyr79ttv014TcKh27NgRdwmlHk+GAQAAECyaYQAAAASL\nZhgAAADBohkGAABAsBigM1xxxRUqq127tsp69eqV9nM751TmvY+8v/VB+X79+qVUE8JgrVQoIvLZ\nZ59l5Px79+5V2U8//aSyX/3qVyqzVhHbsmVLegpDiZs0aZLKrJU3V65cqbJy5cqZx0yWH2jixIkq\nW716tcqsQWQgG3z44YdmfsMNN6isZs2aJV1OqcSTYQAAAASLZhgAAADBohkGAABAsGiGAQAAECya\nYQAAAASLt0kYjj766LhL+P9MmzZNZSNGjDC33bRpk8qs5XSBAyV7a0lubq7KrrzySpXNmTNHZVWr\nVlVZTk6OeZ5mzZqp7Mwzz1TZE088oTLeHJHdfv75Z5VdfPHFMVQClD2HHWY/97TeXmX1ECHgyTAA\nAACCRTMMAACAYNEMAwAAIFg0wwAAAAgWA3SGIUOGqGz27Nkq69Gjh8rq1q1rHtMaELE8/vjjKvvo\no49Ulp+fH+l4QFTLli0zc2upY2v53M2bN6usOAN01jDHxx9/rLLhw4eb+wMAtIKCAjNPNjQdIp4M\nAwAAIFg0wwAAAAgWzTAAAACCRTMMAACAYDFAZ8jLy1PZO++8EykDstXbb79t5mPHjlWZtSpdy5Yt\nUzr/0KFDVfbMM8+ojNXmAKBkXHLJJSobN25cDJVkFk+GAQAAECyaYQAAAASLZhgAAADBohkGAABA\nsBigAyAiIhs3bjTzP/7xjxmuBACQLjt27Ii8bfnyYbaFPBkGAABAsGiGAQAAECyaYQAAAASLZhgA\nAADBohkGAABAsJz3PnMncy5zJ0OZ5b13mT4n1y7SIdPXLtct0oF7bnarXr26mVtL2+/evVtlVapU\nSXtNmRL12uXJMAAAAIJFMwwAAIBg0QwDAAAgWDTDAAAACBYDdMg6DHMgWzFAh2zEPRfZigE6AAAA\noAg0wwAAAAgWzTAAAACCRTMMAACAYGV0gA4AAAAoTXgyDAAAgGDRDAMAACBYNMMAAAAIFs0wAAAA\ngkUzDAAAgGDRDAMAACBYNMMAAAAIFs0wAAAAgkUzDAAAgGDRDAMAACBYNMMAAAAIFs0wAAAAgkUz\nDAAAgGDRDAMAACBYNMOlhHPuUufch865Hc65bc65z51zF8VdFxCVc+5t55x3zj0Qdy1AMs65Cwqv\n0wN/bY27NuBgnHPtnHNznHMbnHN7nHPfOecmO+dOjru2bFc+7gIg4py7QUTGFv4aIYk/pLQUkcpx\n1gVE5ZzrLiIt4q4DKIZbRGTBfn+fH1chQEQ1RORfIvKkiPwgIvVFZJCIfOKcO9V7/22cxWUzmuGY\nOecaishoEbnTez96v3/0TiwFAcXknKsuIqNEZKCITIq5HCCqZd77T+IuAojKe/+SiLy0f+ac+0xE\nlotINxH5axx1lQV8TCJ+vUSkQESeirsQ4BD9WUSWFt6oAQCZs7nw97xYq8hyNMPxy5XEn+qudM59\n45zLd86tcs4NiLswoCjOuVwR+YOI3Bh3LUAxTXTO7XPObXbOTXLO1Y+7ICAK51w551yOc+5EEfmb\niGwQkZdjLiur8TGJ+NUt/PUXERkiIt+IyBUiMtY5V957PybO4oBknHMVJHEjfsR7/3Xc9QAR/SyJ\nHyd/ICLbRKSVJO69/+2ca+W93xRncUAEn4pI68K/XiUiF3HdpsZ57+OuIWjOuRUicqKIXO69n75f\n/pYkbtJ1PP+TUAo554ZJ4mM+zb33uw
szLyIPeu+HxVocUAzOudNF5DMReYhrF6Wdc+6/ROQIEWkk\nIneISC0RyfXer46zrmzGxyTi9z+f9/nnAfm7krjA62S2HKBohT9SHioid4vI4c656oWDdLLf35eL\nr0IgOu/9QhFZISJnxl0LUBTv/TLv/aeFcxoXi0hVSbxVAoeIZjh+S5PkrvD3gkwVAhRDIxGpKCIv\nishP+/0SSTyp+ElETo2nNOCQOBHhp3DIKt77rZL4qESTuGvJZjTD8ZtR+Hu7A/J2IvKd935DhusB\novhCRC40fokkGuQLJXGDBko959wZInKSJD6LCWQN51wtEWkmiXkjHCIG6OL3pojMFZG/OeeOEZF/\nS+J9gZeISM84CwOSKXwa8f6BuXNORORb7736Z0Bp4JybKCL/EZGFIrJVErMZg0XkexF5PMbSgINy\nzs2QxHW7WBLDnydJ4v3u+cI7hlNCMxwz7713zl0mIiNF5D4ROUoSr1q72nvPAgYAkF5LRKS7iNws\niVU+N4jIdBG513v/Y5yFAUX4RER+JyK3i0iOiKyVxEOJkQzPpYa3SQAAACBYfGYYAAAAwaIZBgAA\nQLBohgEAABAsmmEAAAAEK6NvkyhcqhVIiffeFb1VenHtIh0yfe1y3SIduOciW0W9dnkyDAAAgGDR\nDAMAACBYNMMAAAAIFs0wAAAAgkUzDAAAgGDRDAMAACBYNMMAAAAIFs0wAAAAgkUzDAAAgGDRDAMA\nACBYNMMAAAAIFs0wAAAAgkUzDAAAgGDRDAMAACBYNMMAAAAIVvm4CwhFhw4dVDZw4ECVtW3bVmXe\ne5WtXLnSPM/kyZNVNm7cOJWtW7fO3B8AACAkPBkGAABAsGiGAQAAECyaYQAAAASLZhgAAADBctZw\nVomdzLnMnSwm/fv3N/NRo0apLCcnp6TLERGRuXPnqqxHjx4qW79+fSbKSZn33mX6nCFcuyh5mb52\nuW6RDtxzi+e5555T2TXXXKOyWbNmmftPmzZNZfPnz1fZ2rVrI9Wzd+9eM9+3b1+k/bNZ1GuXJ8MA\nAAAIFs0wAAAAgkUzDAAAgGDRDAMAACBYrECXgo4dO6rskUceMbe1huUWLVqkskGDBqls6dKlkWvq\n3bu3yu677z6VDR48WGW33HJL5PMgu1WpUkVlQ4YMMbcdNmyYyqzB2xEjRqisRYsWKuvcuXOUEgEg\nKy1fvlxlBQUFKrN6iIPlh2rChAlmfsMNN6gsPz8/refOFjwZBgAAQLBohgEAABAsmmEAAAAEi2YY\nAAAAwWIFuog6deqkspdeekll1mCSiMjMmTNVZq1Wt3HjxkOo7n85pxdbsYbqLrnkEpX97ne/S+nc\nmcJqSKmrX7++yr799ltz29atW6ts4cKFKrMG6G6++WaVNW3a1DxPqtd+NmAFOuyvVq1aKmvSpIm5\nbcWKFVXWvXt3lU2cOFFlyVYg+/jjj4sqUUS456aD1UO0a9cu8v5nnnmmyqz7eKVKlVR25JFHmse8\n+OKLVWatWJvNWIEOAAAAKALNMAAAAIJFMwwAAIBg0QwDAAAgWKxAZyhfXv9nsVZxs4blFi9ebB7T\nWunlhx9+OITqDs4aiBw/frzKZsyYkfZzI3s0bNgw7cfMy8tTmTW4cfLJJ5v7hzBAhzCccsopKvv9\n73+vsl69eqmsTp065jGjDrv37Nkz0nYiIuXKlYu8LVLzxhtvRMpS1aFDB5XNmjXL3PbSSy9VWVkb\noIuKJ8MAAAAIFs0wAAAAgkUzDAAAgGDRDAMAACBYNMMAAAAIFm+TMPTt21dlrVq1UtmePXtUdt11\n15nHLIk3R6Ri8+bNcZeAGJ177rlpP+arr76qMustLGeccYa5f6hTzMgOLVu2NPOBAweqrE2bNiqr\nXbt22muybN++XWVz5szJyLmRWTVq1FDZvffeq7L8/Hxz/2RvmQgRT4YBAAAQLJphAAAABItmGAAA\nAM
GiGQYAAECwGKAz3HzzzZG269evn8q++OKLdJcDpMRacvXyyy9XWUFBgbl/suELoLispe5FRCpW\nrKiyHTt2lHQ5ImIPdE6YMEFljRs3Nvc//PDD016T5auvvlLZsGHDVGYNR8+bN69EakJqqlWrZua5\nubkqy8nJUdnQoUNVZl3Pzz//vHme999/v4gKw8GTYQAAAASLZhgAAADBohkGAABAsGiGAQAAECwG\n6FLw3XffxV0CUKRatWqp7Mwzz1TZf/7zH3P/xYsXRzpPXl6eyvbt26eyJk2aRDoeyh5rdSwRkcsu\nu0xl06ZNU9nw4cMjn+u0005T2V133aUya5i0QoUKKnPOmefx3keuKQrr31tE5A9/+IPKdu/endZz\nIz2qVq2qspEjR6rMuvZEUlut8NNPP1XZQw89dMjHCwVPhgEAABAsmmEAAAAEi2YYAAAAwaIZBgAA\nQLCCHqCzBixERE488USVbd++XWVff/112msC4rJy5cqU9l+1apXK1q5dq7KWLVumdB5khyOOOEJl\n11xzjblt/fr1Vda8eXOVWYNJTZs2NY/ZsWPHokoslmQDdBZrFbgXXnhBZdOnT1cZq8Vlv/PPP19l\nAwYMyMi5ra+RZKuL4n/xZBgAAADBohkGAABAsGiGAQAAECyaYQAAAASLZhgAAADBCvptEuXL2//6\n5cqVU9muXbtUxnLMyAYXXXRRpO1GjRqV0nmsryfra6lOnTrm/tbbB7Zt25ZSTYhPjRo1VFalShVz\n26hLGg8cOFBlJbFM8oIFC1T2yiuvmNu++eabKtuxY4fKvv/++0OuB9klNzc3pf03bdqksnHjxqns\nsMP088y7775bZdZS0CIiffr0UdlPP/0UpcQyhyfDAAAACBbNMAAAAIJFMwwAAIBg0QwDAAAgWEEP\n0MXt6KOPVlmnTp1Udvvtt0c+5urVq1XWsGFDlW3YsEFlU6dOVdmECRPM8+Tl5UWuCfE677zzVLZx\n40aVffTRRymdxxoynTVrlsr69etn7n/kkUeqjAG67GXdi3744QdzW2vYLlNGjBihsscee0xlW7Zs\nyUQ5KAPuu+8+lf3rX/9S2c6dO839P/jgA5Xt3btXZdbw6JQpU1T23nvvmecZP368ynr37q2yrVu3\nmvuXJTwZBgAAQLBohgEAABAsmmEAAAAEi2YYAAAAwWKALiJrwOOMM85Q2eeff27u36RJE5XNnj1b\nZfXr11fZ7t27Vfbll1+a57GGVqysZ8+eKmvTpo3K2rVrZ57n8ssvN3PEy1rh69JLL1WZNYyRbJgj\nFSEMXiAisjLlAAAIkklEQVS6ZIM8TZs2PeRjfvjhh2Y+bdo0lU2aNEll1opbBQUFh1wPkJ+fr7KZ\nM2em/TzWKotLlixRWd++fc39Z8yYobK5c+eqbOzYsYdQXXbhyTAAAACCRTMMAACAYNEMAwAAIFg0\nwwAAAAhW0AN0yVYU+vnnn1VmrY5lZY0aNTKPOWfOHJUdd9xxKrMGTAYMGKCyFStWmOeJ6rXXXlOZ\n9WH6Zs2apXQeZFblypVV1qBBA5WtXbs2E+WYX0vJWF9PmaoTmTF48GAzt1betIaJLRdccEEqJQFl\nnvX9XkTk5ZdfVpn1NfrKK6+oLNlqktmKJ8MAAAAIFs0wAAAAgkUzDAAAgGDRDAMAACBYQQ/QWSuz\niYisX79eZdZwz1VXXaWyk08+2TymNSxnrUDXtWtXlZXEymDWucePH6+ySy65JO3nRvxycnJU1rp1\na3PbX375RWXW8GmlSpVUZq2QlMy4ceNUdtFFF6ksLy8v8jFRuuzYscPMrUGeHj16qKxevXoq27Bh\ng3nMKVOmqOzee+9VWbJBaqCsGzNmjMq6d++usuuvv15lDz74YInUFBeeDAMAACBYNMMAAAAIFs0w\nAAAAgkUzDAAAgGDRDAMAACBYrjjT3imfzLnMnSwFI0eOVNldd92V
0jGtNzXceuutKtu1a1dK50nF\npEmTVNa+fXtz25YtW6pszZo1aa/J4r13GTnRfrLl2j322GNVtmnTppSOmZ+frzLrrQDWGyqs5aGL\nw3q7ysyZM1M6Zpwyfe1my3Vrsaban3rqKZVVq1bN3N/63jZ//nyVde7cWWU//fRTlBKDwT23bKpY\nsaLKPv74Y5UtXrxYZT179iyRmtIt6rXLk2EAAAAEi2YYAAAAwaIZBgAAQLBohgEAABAsBugM1atX\nV9kXX3yhsvr160c+5m233aay0aNHF6+wEmYta5psOOX0009X2ddff532miwMcyRXrlw5lY0YMUJl\ngwcPzkQ5xfL555+r7JxzzlHZvn37MlFOiWCALjXWPdcaThYRufjiiyMd86uvvlLZFVdcobLly5dH\nOl5ZxD3XXgpcxB7q7Natm8r27NmT9ppKwrBhw1R2ww03qOzUU09V2datW0ukplQwQAcAAAAUgWYY\nAAAAwaIZBgAAQLBohgEAABAsBugi6tixo8pefvlllVWpUsXcf+fOnSp74403VPbggw+qbMmSJVFK\nLJYOHTqo7LXXXlPZihUrzP2bN2+e9pqiYpijeKyhupo1a6os2bVrXSvW0JGVWUMW77zzjnkea3Ww\n888/39w2WzFAl37WkKWIvVKhtUKjZcGCBSq76aabzG2twc+yhnuuSMOGDc383//+t8peeOEFlf3p\nT39S2caNG1OuK92sAbr7779fZY0aNVLZ6tWrS6KklDBABwAAABSBZhgAAADBohkGAABAsGiGAQAA\nECwG6FLQrl07lT388MPmtqeddlqkY+7evVtlffr0UdmaNWvM/a0PsOfm5qpszJgxKrNW3nvppZfM\n8/Ts2dPMM4FhjuzRunVrlSUbOGKALv1Cvm4vu+wylU2bNu2Qj2fdh0VEJkyYcMjHzBbcc0Xq1q1r\n5tbKq9Yw8sqVK1XWr18/85gfffSRyvLz84sqsdi6du2qskceeURlOTk5KjvllFNU9vPPP6ensDRi\ngA4AAAAoAs0wAAAAgkUzDAAAgGDRDAMAACBYDNClWbIVjnr16qUya0Wao446Ku01WawP41ur3913\n332ZKKdYGObIHsccc4zKli9fbm67b98+lZ100kkqK41DGlExQJd+/fv3N/Mnnngired59tlnzdy6\nt5c13HOT69atm8omT56c0jGtlemsXu3VV19VWZcuXSKfp0aNGiqzhuUeeOABld1zzz2RzxMnBugA\nAACAItAMAwAAIFg0wwAAAAgWzTAAAACCRTMMAACAYPE2iRhZk5zWZLQ1rdqiRYvI51m7dq3Knnrq\nKZWNHDky8jHjxGRzdrOWXRYROffcc1VmLYG6fv36tNeUKbxNIjprufvBgwer7Ne//rW5f7q/t910\n001mPm7cuLSepzTinptcuXLlVNa+fXuVDRo0SGWpLjfvnP7fkup1P378eJUNHTpUZT/88ENK58kU\n3iYBAAAAFIFmGAAAAMGiGQYAAECwaIYBAAAQLAbokHUY5shuAwcONPNHH31UZZdddpnKrCVIs0Xo\nA3QdOnQw8+uvv15l1hCStVSsNUQkEn2QaMSIESpbuHChyl577bVIxyuLuOem7rDD9LPHs846y9zW\nGpo/77zzVHbOOeeobO/evSqbMmWKeZ4xY8aozLr2CwoKzP2zAQN0AAAAQBFohgEAABAsmmEAAAAE\ni2YYAAAAwWKADlmHYY7sdvbZZ5v5J598orL3339fZRdeeGG6S8qYkAbo+vTpo7Jkq1xaq3Fatm7d\nqrJ58+aZ23755Zcqmz59usoWL16ssmweGCoJ3HORrRigAwAAAIpAMwwAAIBg0QwDAAAgWDTDAAAA\nCBYDdMg6DHMgW4U0QGetmNWxY0dz21mzZkU65qZNm1S2atWq4hWGYuOei2zFAB0AAABQBJphAAAA\nBItmGAAAAMGiGQYAAECwaIYB
AAAQLN4mgazDZDOyVUhvk0DZwT0X2Yq3SQAAAABFoBkGAABAsGiG\nAQAAECyaYQAAAASLZhgAAADBohkGAABAsGiGAQAAECyaYQAAAASLZhgAAADByugKdAAAAEBpwpNh\nAAAABItmGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABItmGAAAAMGiGQYAAECwaIYBAAAQLJphAAAA\nBItmGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABItmGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABItm\nGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABItmGAAAAMGiGQYAAECw/h+g8MKF0kXt3QAAAABJRU5E\nrkJggg==\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plots(x_imgs[:8], titles=y_valid[:8])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Neural Networks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will take a deep look at *logistic regression* and how we can program it ourselves. We are going to treat it as a specific example of a shallow neural net."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**What is a neural network?**\n",
"\n",
"A *neural network* is an *infinitely flexible function*, consisting of *layers*. A *layer* is a linear function such as matrix multiplication followed by a non-linear function (the *activation*).\n",
"\n",
"One of the tricky parts of neural networks is just keeping track of all the vocabulary! "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Functions, parameters, and training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A **function** takes inputs and returns outputs. For instance, $f(x) = 3x + 5$ is an example of a function. If we input $2$, the output is $3\\times 2 + 5 = 11$, or if we input $-1$, the output is $3\\times -1 + 5 = 2$\n",
"\n",
"Functions have **parameters**. The above function $f$ is $ax + b$, with parameters a and b set to $a=3$ and $b=5$.\n",
"\n",
"Machine learning is often about learning the best values for those parameters. For instance, suppose we have the data points on the chart below. What values should we choose for $a$ and $b$?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the above gif from fast.ai's deep learning course, [intro to SGD notebook](https://github.com/fastai/courses/blob/master/deeplearning1/nbs/sgd-intro.ipynb)), an algorithm called stochastic gradient descent is being used to learn the best parameters to fit the line to the data (note: in the gif, the algorithm is stopping before the absolute best parameters are found). This process is called **training** or **fitting**.\n",
"\n",
"Most datasets will not be well-represented by a line. We could use a more complicated function, such as $g(x) = ax^2 + bx + c + \\sin d$. Now we have 4 parameters to learn: $a$, $b$, $c$, and $d$. This function is more flexible than $f(x) = ax + b$ and will be able to accurately model more datasets.\n",
"\n",
"Neural networks take this to an extreme, and are infinitely flexible. They often have thousands, or even hundreds of thousands of parameters. However the core idea is the same as above. The neural network is a function, and we will learn the best parameters for modeling our data."
]
},
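{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make *learning the best parameters* concrete, here is a sketch of one gradient descent update for the line $f(x) = ax + b$. The mean squared error loss $L$, the learning rate $\\eta$, and the data points $(x_i, y_i)$ are assumptions introduced for illustration; the gif above may use a different loss or step size.\n",
"\n",
"$$L(a, b) = \\frac{1}{n} \\sum_{i=1}^{n} (a x_i + b - y_i)^2$$\n",
"\n",
"$$\\frac{\\partial L}{\\partial a} = \\frac{2}{n} \\sum_{i=1}^{n} (a x_i + b - y_i)\\, x_i, \\qquad \\frac{\\partial L}{\\partial b} = \\frac{2}{n} \\sum_{i=1}^{n} (a x_i + b - y_i)$$\n",
"\n",
"$$a \\leftarrow a - \\eta \\frac{\\partial L}{\\partial a}, \\qquad b \\leftarrow b - \\eta \\frac{\\partial L}{\\partial b}$$\n",
"\n",
"Repeating this update moves $a$ and $b$ toward values that reduce the loss. In *stochastic* gradient descent, each update uses a small random subset of the data points rather than all $n$ of them."
]
},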
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### PyTorch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will be using the open source [deep learning library, fastai](https://github.com/fastai/fastai), which provides high level abstractions and best practices on top of PyTorch. This is the highest level, simplest way to get started with deep learning. Please note that fastai requires Python 3 to function. It is currently in pre-alpha, so items may move around and more documentation will be added in the future.\n",
"\n",
"The fastai deep learning library uses [PyTorch](http://pytorch.org/), a Python framework for dynamic neural networks with GPU acceleration, which was released by Facebook's AI team.\n",
"\n",
"PyTorch has two overlapping, yet distinct, purposes. As described in the [PyTorch documentation](http://pytorch.org/tutorials/beginner/blitz/tensor_tutorial.html):\n",
"\n",
"\n",
"\n",
"The neural network functionality of PyTorch is built on top of the Numpy-like functionality for fast matrix computations on a GPU. Although the neural network purpose receives way more attention, both are very useful. We'll implement a neural net from scratch today using PyTorch.\n",
"\n",
"**Further learning**: If you are curious to learn what *dynamic* neural networks are, you may want to watch [this talk](https://www.youtube.com/watch?v=Z15cBAuY7Sc) by Soumith Chintala, Facebook AI researcher and core PyTorch contributor.\n",
"\n",
"If you want to learn more PyTorch, you can try this [introductory tutorial](http://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html) or this [tutorial to learn by examples](http://pytorch.org/tutorials/beginner/pytorch_with_examples.html)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### About GPUs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Graphical processing units (GPUs) allow for matrix computations to be done with much greater speed, as long as you have a library such as PyTorch that takes advantage of them. Advances in GPU technology in the last 10-20 years have been a key part of why neural networks are proving so much more powerful now than they did a few decades ago. \n",
"\n",
"You may own a computer that has a GPU which can be used. For the many people that either don't have a GPU (or have a GPU which can't be easily accessed by Python), there are a few differnt options:\n",
"\n",
"- **Don't use a GPU**: For the sake of this tutorial, you don't have to use a GPU, although some computations will be slower.\n",
"- **Use crestle, through your browser**: [Crestle](https://www.crestle.com/) is a service that gives you an already set up cloud service with all the popular scientific and deep learning frameworks already pre-installed and configured to run on a GPU in the cloud. It is easily accessed through your browser. New users get 10 hours and 1 GB of storage for free. After this, GPU usage is 34 cents per hour. I recommend this option to those who are new to AWS or new to using the console.\n",
"- **Set up an AWS instance through your console**: You can create an AWS instance with a GPU by following the steps in this [fast.ai setup lesson](http://course.fast.ai/lessons/aws.html).] AWS charges 90 cents per hour for this."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Neural Net for Logistic Regression in PyTorch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai.metrics import *\n",
"from fastai.model import *\n",
"from fastai.dataset import *\n",
"\n",
"import torch.nn as nn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will begin with the highest level abstraction: using a neural net defined by PyTorch's Sequential class. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net = nn.Sequential(\n",
" nn.Linear(28*28, 100),\n",
" nn.ReLU(),\n",
" nn.Linear(100, 100),\n",
" nn.ReLU(),\n",
" nn.Linear(100, 10),\n",
" nn.LogSoftmax()\n",
").cuda()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each input is a vector of size `28*28` pixels and our output is of size `10` (since there are 10 digits: 0, 1, ..., 9). \n",
"\n",
"We use the output of the final layer to generate our predictions. Often for classification problems (like MNIST digit classification), the final layer has the same number of outputs as there are classes. In that case, this is 10: one for each digit from 0 to 9. These can be converted to comparative probabilities. For instance, it may be determined that a particular hand-written image is 80% likely to be a 4, 18% likely to be a 9, and 2% likely to be a 3."
]
},
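{
"cell_type": "markdown",
"metadata": {},
"source": [
"The conversion from the 10 outputs to probabilities can be written out. As a sketch, call the 10 values produced by the last linear layer $z_0, \\dots, z_9$ (names introduced here for illustration). The softmax function turns them into probabilities $p_k$, and a log-softmax layer such as `nn.LogSoftmax` returns the logarithm of those probabilities:\n",
"\n",
"$$p_k = \\frac{e^{z_k}}{\\sum_{j=0}^{9} e^{z_j}}, \\qquad \\log p_k = z_k - \\log \\sum_{j=0}^{9} e^{z_j}$$\n",
"\n",
"Each $p_k$ is between 0 and 1 and the ten values sum to 1. Because the exponential is increasing, the largest output $z_k$ always corresponds to the most likely digit."
]
},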
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"md = ImageClassifierData.from_arrays(path, (x,y), (x_valid, y_valid))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss=nn.NLLLoss()\n",
"metrics=[accuracy]\n",
"# opt=optim.SGD(net.parameters(), 1e-1, momentum=0.9)\n",
"opt=optim.SGD(net.parameters(), 1e-1, momentum=0.9, weight_decay=1e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loss functions and metrics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In machine learning the **loss** function or cost function is representing the price paid for inaccuracy of predictions.\n",
"\n",
"The loss associated with one example in binary classification is given by:\n",
"`-(y * log(p) + (1-y) * log (1-p))`\n",
"where `y` is the true label of `x` and `p` is the probability predicted by our model that the label is 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def binary_loss(y, p):\n",
" return np.mean(-(y * np.log(p) + (1-y)*np.log(1-p)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.164252033486018"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"acts = np.array([1, 0, 0, 1])\n",
"preds = np.array([0.9, 0.1, 0.2, 0.8])\n",
"binary_loss(acts, preds)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that in our toy example above our accuracy is 100% and our loss is 0.16. Compare that to a loss of 0.03 that we are getting while predicting cats and dogs. Exercise: play with `preds` to get a lower loss for this example. \n",
"\n",
"**Example:** Here is an example on how to compute the loss for one example of binary classification problem. Suppose for an image x with label 1 and your model gives it a prediction of 0.9. For this case the loss should be small because our model is predicting a label $1$ with high probability.\n",
"\n",
"`loss = -log(0.9) = 0.10`\n",
"\n",
"Now suppose x has label 0 but our model is predicting 0.9. In this case our loss is should be much larger.\n",
"\n",
"`loss = -log(1-0.9) = 2.30`\n",
"\n",
"- Exercise: look at the other cases and convince yourself that this make sense.\n",
"- Exercise: how would you rewrite `binary_loss` using `if` instead of `*` and `+`?\n",
"\n",
"Why not just maximize accuracy? The binary classification loss is an easier function to optimize.\n",
"\n",
"For multi-class classification, we use *negative log liklihood* (also known as *categorical cross entropy*) which is exactly the same thing, but summed up over all classes."
]
},
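{
"cell_type": "markdown",
"metadata": {},
"source": [
"Written out, the multi-class loss for one example looks as follows. This is a sketch using notation introduced here: $K$ is the number of classes, $p_k$ is the predicted probability of class $k$, and $y_k$ is 1 for the true class and 0 otherwise.\n",
"\n",
"$$\\text{loss} = -\\sum_{k=1}^{K} y_k \\log p_k$$\n",
"\n",
"Since only the true class has $y_k = 1$, this is just $-\\log$ of the probability the model assigned to the correct class. With $K = 2$, $y_1 = y$, $y_2 = 1 - y$, $p_1 = p$ and $p_2 = 1 - p$, it reduces to the binary loss above:\n",
"\n",
"$$-\\big(y \\log p + (1 - y) \\log(1 - p)\\big)$$\n",
"\n",
"`nn.NLLLoss` expects log-probabilities as its input, which is why the network above ends in `nn.LogSoftmax`."
]
},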
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fitting the model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Fitting* is the process by which the neural net learns the best parameters for the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2869810cce3b487aa6ad0a0ebca20426",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"A Jupyter Widget"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.26978 0.25134 0.93501] \n",
"[ 1. 0.26613 0.22294 0.93909] \n",
"[ 2. 0.21344 0.21292 0.94865] \n",
"[ 3. 0.19153 0.20118 0.9582 ] \n",
"[ 4. 0.20079 0.34337 0.9377 ] \n",
"\n"
]
}
],
"source": [
"fit(net, md, n_epochs=5, crit=loss, opt=opt, metrics=metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"set_lrs(opt, 1e-2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "564018b309a9494fba10c8e89873cd0b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"A Jupyter Widget"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.07224 0.13539 0.97024] \n",
"[ 1. 0.05924 0.13015 0.97114] \n",
"[ 2. 0.04723 0.1316 0.97124] \n",
"\n"
]
}
],
"source": [
"fit(net, md, n_epochs=3, crit=loss, opt=opt, metrics=metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ea801e9b44b74358ab227a9c485a69e8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"A Jupyter Widget"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.27253 0.21465 0.93939] \n",
"[ 1. 0.21963 0.23439 0.93481] \n",
"[ 2. 0.23288 0.18333 0.94705] \n",
"[ 3. 0.19822 0.1902 0.94636] \n",
"[ 4. 0.205 0.27594 0.92227] \n",
"\n"
]
}
],
"source": [
"fit(net, md, n_epochs=5, crit=loss, opt=opt, metrics=metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"set_lrs(opt, 1e-2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6e3074b642a64df0894778909923b943",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"A Jupyter Widget"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.07269 0.09255 0.97373] \n",
"[ 1. 0.06316 0.08361 0.97572] \n",
"[ 2. 0.04525 0.08077 0.97681] \n",
"\n"
]
}
],
"source": [
"fit(net, md, n_epochs=3, crit=loss, opt=opt, metrics=metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"([78400, 100, 10000, 100, 1000, 10], 89610)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t = [o.numel() for o in net.parameters()]\n",
"t, sum(t)"
]
},
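{
"cell_type": "markdown",
"metadata": {},
"source": [
"The counts above can be checked by hand. Each `nn.Linear(n_in, n_out)` layer has a weight matrix with $n_{in} \\times n_{out}$ entries plus a bias vector with $n_{out}$ entries (the names $n_{in}$ and $n_{out}$ are used here only for illustration):\n",
"\n",
"$$784 \\times 100 + 100 = 78500$$\n",
"\n",
"$$100 \\times 100 + 100 = 10100$$\n",
"\n",
"$$100 \\times 10 + 10 = 1010$$\n",
"\n",
"$$78500 + 10100 + 1010 = 89610$$\n",
"\n",
"The `nn.ReLU` and `nn.LogSoftmax` layers have no parameters, so they do not appear in the list."
]
},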
{
"cell_type": "markdown",
"metadata": {},
"source": [
"GPUs are great at handling lots of data at once (otherwise don't get performance benefit). We break the data up into **batches**, and that specifies how many samples from our dataset we want to send to the GPU at a time. The fastai library defaults to a batch size of 64. On each iteration of the training loop, the error on 1 batch of data will be calculated, and the optimizer will update the parameters based on that.\n",
"\n",
"An **epoch** is completed once each data sample has been used once in the training loop.\n",
"\n",
"Now that we have the parameters for our model, we can make predictions on our validation set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"preds = predict(net, md.val_dl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(10000, 10)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preds.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Question**: Why does our output have length 10 (for each image)?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([3, 8, 6, 9, 6])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preds.argmax(axis=1)[:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"preds = preds.argmax(1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check how accurate this approach is on our validation set. You may want to compare this against other implementations of logistic regression, such as the one in sklearn. In our testing, this simple pytorch version is faster and more accurate for this problem!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.91820000000000002"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(preds == y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see how some of our predictions look!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsMAAAF0CAYAAADGqzQSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xu4lmP6//HzUi1tSWhLRZn6iTayt2bGptRSR2lkRmRo\nRwkjxmiHiF/MGJUiYyK7QhtlU3Ypm/RDpugbpTKToh0lbdVarev7x9P3+PbrPJ/WvXqe9dzrWdf7\ndRyODp+57/s6jXs9ne6e876c914AAACAEB0WdwEAAABAXGiGAQAAECyaYQAAAASLZhgAAADBohkG\nAABAsGiGAQAAECya4Zg559o55+Y459Y753Y7575zzk12zp0cd23AwTjnznPOve2c2+ic2+qcW+ic\n6xl3XUAUzrlLnHMfOOe277t/P3POXRh3XcDBOOcucM7Nc87tcs5tds4955yrFXdd2Y5mOH41RORf\nInKjiFwsIoNEpJmIfOycaxBnYUAyzrnmIjJbRCqISB8RuUxEFojIk865fnHWBhTFOXe9iLwiic/e\nLiJyuYhMEZHKcdYFHIxz7tci8raIbJHEZ+6fROQ3IvKuc+7wOGvLdo5NN0of51wTEVkmIn/23v89\n7nqAAznn/q+I/FlEanjvt++Xfywi3nt/TmzFAQfhnGsoIktFZJD3flS81QDROedmi0hDEWnqvS/Y\nl50hIp+KSH/v/WMxlpfVeDJcOm3a92t+rFUAyeVI4v7cdUC+RfhcQenWU0QKReTxuAsBiulsEXnn\nfxphERHv/QJJ9AxdYquqDOA3rVLCOVfOOZfjnDtJRP4hIutF5MWYywKSeXrfr4845+o656o75/qI\nyEUiMjK+soAi5UriT96ucM5945wrcM6tdM71j7swoAh7RWSPke8WkVMyXEuZwtckSgnn3Gci0nrf\n364UkU7e+6UxlgQc1L4/npsuIvX2Rfki0s97/2R8VQEH55xbJiJ1JdFADBaRbyTxneG+InKL9350\njOUBSTnnPpXE19DO2i9rICL/EZF87z3fGz5ENMOlhHPu/4jIESJyoiS+i1lLRHK996virAuw7PsT\njHcl8d3LMZL4ukRnEeknItd67yfGWB6QlHNuuYicJCKXee9f3i9/Q0RaiUgdz2+MKIWcc1eJyPMi\ncr+IPCKJAfwnRORcSTTDlWIsL6vRDJdCzrnqIrJKRF703veNuRxAcc5NEZHTJDHIkb9fPlFE2olI\nTe99YVz1Ack45/6fJL57eYT3ftt++QAReVhE6nnv18ZVH3AwzrnhknhgVlFEvIi8JCJVROQU7/2J\ncdaWzfjOcCnkvd8iia9KNI67FiCJU0Xki/0b4X0+FZGjRaRm5ksCIvkySe72/cp/xKHU8t7fKSLH\niEhzSfwpRjdJ/EnHvFgLy3I0w6XQvhdoN5XEd9mA0mi9iLR0zuUckJ8lIr+IyObMlwREMn3fr+0O\nyNuJyHfe+/UZrgcoFu/9Du/9f3nvNzjn2kuiX+DtKCkoH3cBoXPOTReRhSKyWES2isivRGSAiBSI\nCO8YRmk1VhKbFLzmnHtMEt8Z7iQi3URkpPfemngGSoNZIjJXRP7hnDtGRP4tIl0lselRjzgLAw7G\nOddKRPIk0TOIJN6McruI/NV7Pz+2wsoAvjMcM+fcHSLyexFpJIl3t64RkfdEZATDcyjNnHN5InKH\nJHZMrCiJP8l4QkT+4b3fG2dtwME4544QkRGSaIKPksSr1h7w3k+KtTDgIJxzzSTx6tVTRORw2TfA\n7L2fEGthZQDNMAAAAILFd4YBAAAQLJphAAAABItmGAAAAMGiGQYAAECwMvpqNecc03pImffeFX1U\nenHvIh0yfe9y3yId+MxFtop67/JkGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABItmGAAAAMGiGQYA\nAECwaIYBAAAQLJphAAAABItmGAAAAMHK6A50
AACUJlWrVlVZr169VNa5c2fz/E6dOqls+/btqRcG\nIGN4MgwAAIBg0QwDAAAgWDTDAAAACBbNMAAAAIJFMwwAAIBg8TYJAECwrrnmGpWNHDky8vnNmjVT\n2SeffJJSTQAyiyfDAAAACBbNMAAAAIJFMwwAAIBg0QwDAAAgWAzQpaBFixYqGzBggHlso0aNVFa5\ncmWVDR48WGVHHnmkyt544w1znW3btpk5AITu2muvVdmoUaNUlp+fr7KHHnrIvObChQtTrgtAvHgy\nDAAAgGDRDAMAACBYNMMAAAAIFs0wAAAAguW895lbzLnMLZZmVatWVdnq1atVVr169UyUI99//72Z\nWwN8U6dOLelyMsp77zK9ZjbfuxbrPu3SpYt5bKtWrVSWm5urMutnZPPmzSqrXbu2uc769etV9vTT\nT6vsn//8p8r27t1rXrO0yfS9W9bu2+Lo1KmTyqZPn66ynTt3quyuu+5SWXF2pStr+MxFtop67/Jk\nGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABIsBuoiqVaumslmzZqls06ZN5vmLFi1SmTWY1KBBA5Ud\nf/zxKqtUqZK5zoYNG1R2zjnnRDouWzDMUTzHHXecymbMmKEy635MZuvWrSqz7vEKFSqozPpZEhGp\nWbOmymrVqqWyK6+8UmUffPCBytatW2euEycG6NIvJyfHzCdMmKCybt26qWzOnDkqa9OmTeqFlSF8\n5iJbMUAHAAAAFIFmGAAAAMGiGQYAAECwaIYBAAAQLAbossAxxxyjsttvv9081sp79OihsmeeeSb1\nwmLCMEfxLFy4UGUtWrRQ2ezZs83zb7vtNpX9+OOPKrN2kCuOY489VmVvvPGGypo0aaKygQMHquzR\nRx9NqZ6SwABd+g0ZMsTMhw8frrLnn39eZT179lRZQUFB6oWVIXzmpq5OnToqu+GGG8xjrTw/P19l\n1i64999/v8qs3wNERNasWWPmZQkDdAAAAEARaIYBAAAQLJphAAAABItmGAAAAMGiGQYAAECweJtE\nlurUqZOZW9vsPvLIIyq75ZZb0l5TpjDZnJw1sfz999+rbPLkySq76qqrzGvu3bs39cIO0cSJE1V2\nxRVXqKx169Yq+/zzz0ukplTwNonUnH766SqbN2+eeeyqVatU1qxZM5XFeX9nCz5zi+fEE09U2bhx\n41TWtm3bTJQju3fvNvPzzjtPZcnePJGteJsEAAAAUASaYQAAAASLZhgAAADBohkGAABAsMrHXQCK\ndtRRR6ls8ODBkc+vW7duOstBKdayZUuVOafnB9auXauyuAeJzj77bJV169ZNZXPnzlWZ9c9dGgfo\nEN1hh+lnNda22zk5Oeb5r732msrivsdR9tSrV09lS5YsUVn58rrdGjlypHnNMWPGRFqnadOmKvvb\n3/6msurVq5vrWIPU1ufwjz/+aJ5flvBkGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABIsd6EqZFi1a\nqGzKlCkqa9y4sXn+8uXLVWbtcrNmzZpDqK50YDek4iksLFTZxo0bVXbmmWea569evTqt9VSrVs3M\n58+fr7IVK1aozNopz9rx6csvvzyE6koWO9BFF3U3xWRuvvlmlY0dOzalmkLFZ25yo0ePVlnfvn1V\n1qdPH5U9++yzaa+nf//+Khs1apR5bLly5VS2bNkylVlDdVu3bj2E6jKPHegAAACAItAMAwAAIFg0\nwwAAAAgWzTAAAACCxQBdjK655hqV3XvvvSo7/vjjVbZr1y7zmh07dlSZtWNXNmOYo3iGDRumsjvv\nvFNlX3/9tXl+u3btVJbKAObbb79t5r/97W9V1rp1a5VZuztlCwboouvRo4fKnnzySZXNnj3bPD8v\nL09l7EB3aPjMFTniiCPM3BrynTBhgsqs3RMzJdln+0knnRTpfGunvNtuuy2lmjKFAToAAACgCDTD\nAAAACBbN
MAAAAIJFMwwAAIBg0QwDAAAgWLxNIs2qVq1q5n/+859VNnToUJUddpj+75PNmzerLDc3\n11zH2kqxrGGyuXgqVqyosmeeeUZlXbt2Nc9fuXKlys4//3yVrVu3TmWPPfaYyq677jpzndtvv11l\n1hRzNuNtErby5curbOnSpSpr0KCByk444QTzmsXZuhkHx2du8u3qP/74Y5W1bdtWZe+++27aa4qq\nS5cuZv7yyy+rzOoJt2zZojLrTRSbNm06hOpKFm+TAAAAAIpAMwwAAIBg0QwDAAAgWDTDAAAACJae\nWkBKnn76aTP/3e9+F+n8qVOnqmzUqFEqC2FQDunxyy+/qKx3794qq1mzpnm+tU3y+++/r7IpU6ao\nrHv37iqbNm2auU5ZG5ZDdNbwZqNGjVTWr18/lcU9KNe+fXuVderUSWVvvvmmyqytya2fV8SvVatW\nkY9dtGhRCVZSfLNmzTJzazja+rmz7skdO3akXlgpwpNhAAAABItmGAAAAMGiGQYAAECwaIYBAAAQ\nLAbo0sz68nlxjBs3TmXz589P6ZrAgbZt26ayzp07m8cOGzZMZbfccovKBg4cGGntMWPGRDoO4ahf\nv36k43Jyckq4kuSuvfZaM7d2WbR2fezbt6/KrJ29ZsyYYa7Ts2fPIipESZo3b56ZFxYWquydd95R\nWceOHVVm7dpZEpo0aWLm1n3arl07lVWuXFllZW3QkyfDAAAACBbNMAAAAIJFMwwAAIBg0QwDAAAg\nWAzQpZm1o5CISIsWLQ75fGuo7oEHHjDPX7t2baR1gANt3brVzO+66y6VtW3bVmUnn3xypHXatGlj\n5skGVFD2NW7cONJxmdp5s3r16ip7+OGHzWOtIaSCggKVWUNVubm5KrN2bRRhgC5uX375pZm//vrr\nKrOGkZcuXaoya1dCEXuXzjlz5qisXr16KrOG5axdbEVE6tSpozLr3n3llVfM88sSngwDAAAgWDTD\nAAAACBbNMAAAAIJFMwwAAIBgOe995hZzLnOLxaRSpUpm/vzzz6usdevWKou6E9P69evNvEePHip7\n6623Il0zW3jvXabXDOHeTSYvL09l06dPV1mFChUiXW/Pnj1mfsMNN6hswoQJka6ZLTJ972bLfTtz\n5kyVtWrVSmV169bNRDnmDovJBuisz/bRo0erbPXq1SqzBqhOPfVUc504d9/jMzc56/f8ESNGqOzm\nm29OaZ3NmzerrEaNGild03L55ZerzBroyxZR712eDAMAACBYNMMAAAAIFs0wAAAAgkUzDAAAgGDR\nDAMAACBYbMecZrt27TLzq666SmXly+v/+5NtiXug2rVrm7k15X/rrbeq7PHHH4+0DnDBBReozHoL\nTZcuXVRmTUBb25eK2NuO//jjjyp77bXXzPORvc466yyVJXvrSGmzdu1alR133HEqe+KJJ1R22mmn\nqaysvf2nrLN+z7feRjJ58mSVWX1BMrVq1Yp0XH5+vsqsny8RkRNOOEFlO3fujFxTWcKTYQAAAASL\nZhgAAADBohkGAABAsGiGAQAAECy2Yy5lmjdvrrKRI0eqzBpqSsbaBrRhw4bFqqs0YWvQkmHdeyIi\nCxYsUJk17GYNjVis7T5FRJ588kmVOaf/VTdr1kxl1j1eGrEds80aLuvYsaPKSmI7Zuses+7lv//9\n7ymtY/1e+9hjj6ls8ODB5vnbtm1Laf1U8Jmb3Z577jkztwb42rdvr7K333477TVlCtsxAwAAAEWg\nGQYAAECwaIYBAAAQLJphAAAABIsd6CKqXLmyykpip5bFixerrGvXrip76qmnzPM7d+6ssvr166us\nTp06Klu3bl2UElFGVatWzcytnRKnTp16yOtMmTLFzBs0aKCyBx98UGWtW7dWWbYM0CG66tWrq8wa\nBHr++efN86379oorrlBZjRo1VJaXlxelRBER2bFjh8rmzZunsr/+9a8qmz
t3buR1gExo1KhR3CXE\ngifDAAAACBbNMAAAAIJFMwwAAIBg0QwDAAAgWAzQGawvkFsDETNnzlTZkiVLzGtaw2m9evVSWYUK\nFVRWr149lTVu3Nhcx/LNN99Eqgdha9mypZmvX79eZdbPQ6rGjh2rsj59+qisf//+Kps+fXra60Hm\nLFq0SGW9e/dWmbVjlpWlauvWrSpLNvh53333qezbb79Ne03Aodq+fXvcJZR6PBkGAABAsGiGAQAA\nECyaYQAAAASLZhgAAADBYoDOcPnll6usdu3aKuvZs2fa13bOqcx7H/l864vyffv2TakmhMHaqVBE\n5NNPP83I+nv27FHZTz/9pLJf//rXKrN2Edu8eXN6CkOJmzRpksqsnTdXrFihsnLlypnXTJYfaOLE\niSpbtWqVyqxBZCAbfPDBB2Z+/fXXq6xmzZolXU6pxJNhAAAABItmGAAAAMGiGQYAAECwaIYBAAAQ\nLJphAAAABIu3SRiOPvrouEv4/0ybNk1lw4cPN4/duHGjyqztdIEDJXtrSW5ursquuOIKlc2ZM0dl\nVatWVVlOTo65TtOmTVV2xhlnqOzRRx9VGW+OyG4///yzyi666KIYKgHKnsMOs597Wm+vsnqIEPBk\nGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABIsBOsPgwYNVNnv2bJV1795dZXXr1jWvaQ2IWMaMGaOy\nDz/8UGUFBQWRrgdEtXTpUjO3tjq2ts/dtGmTyoozQGcNc3z00UcqGzZsmHk+AEArLCw082RD0yHi\nyTAAAACCRTMMAACAYNEMAwAAIFg0wwAAAAgWA3SG/Px8lb311luRMiBbvfnmm2Y+duxYlVm70rVs\n2TKl9YcMGaKyp556SmXsNgcAJePiiy9W2bhx42KoJLN4MgwAAIBg0QwDAAAgWDTDAAAACBbNMAAA\nAILFAB0AERHZsGGDmf/pT3/KcCUAgHTZvn175GPLlw+zLeTJMAAAAIJFMwwAAIBg0QwDAAAgWDTD\nAAAACBbNMAAAAILlvPeZW8y5zC2GMst77zK9Jvcu0iHT9y73LdKBz9zsVr16dTO3trbftWuXyqpU\nqZL2mjIl6r3Lk2EAAAAEi2YYAAAAwaIZBgAAQLBohgEAABAsBuiQdRjmQLZigA7ZiM9cZCsG6AAA\nAIAi0AwDAAAgWDTDAAAACBbNMAAAAIKV0QE6AAAAoDThyTAAAACCRTMMAACAYNEMAwAAIFg0wwAA\nAAgWzTAAAACCRTMMAACAYNEMAwAAIFg0wwAAAAgWzTAAAACCRTMMAACAYNEMAwAAIFg0wwAAAAgW\nzTAAAACCRTMMAACAYNEMlxLOuUuccx8457Y757Y65z5zzl0Yd11AMs65C5xz85xzu5xzm51zzznn\nasVdF3AwzrnznXPe+GtL3LUBB8O9W3LKx10ARJxz14vI2H1/DZfEf6S0FJHKcdYFJOOc+7WIvC0i\nb4nIZSJytIjcJyLvOudae+93x1kfEMHNIrJgv78viKsQoJi4d9OMZjhmzrmGIjJKRG733o/a7396\nK5aCgGjuFpFvReRS732BiIhzbpmIfCoivUTksRhrA6JY6r3/OO4igEPAvZtmfE0ifj1FpFBEHo+7\nEKAYzhaRd/6nERYR8d4vEJFNItIltqoAACgmmuH45YrIMhG5wjn3jXOuwDm30jnXP+7CgIPYKyJ7\njHy3iJyS4VqAQzHRObfXObfJOTfJOVc/7oKAiLh308x57+OuIWj7/mi5riSaiMEi8o2IXC4ifUXk\nFu/96BjLA0zOuU9FxHvvz9ovayAi/xGRfO/94bEVBxyEc66ViFwlIu+LyFYRaSWJz958EWnlvd8Y\nY3lAUty7JYdmOGbOueUicpKIXOa9f3m//A1J3Oh1PP+SUMo4564SkedF5H4ReUREaojIEyJyriSa\n4UoxlgcUi3PuNEl83/0B7/3QuOsBou
LeTQ++JhG/Tft+feeA/G0RqSUidTJbDlA07/1ESbw94jYR\n2SAiX4nI9yIyS0TWxVgaUGze+4UislxEzoi7FqA4uHfTg2Y4fl8myd2+XwszVQhQHN77O0XkGBFp\nLok/wegmiT/lmBdrYcChcSLCn8IhG3HvpohmOH7T9/3a7oC8nYh8571fn+F6gMi89zu89//lvd/g\nnGsvIk2FN6MgyzjnTheRX4nIJ3HXAhQH92568J7h+M0Skbki8g/n3DEi8m8R6SoiF4tIjzgLA5LZ\nN8iRJyIL90W5InK7iPzVez8/tsKAIjjnJkpi0HOhiGyRxGzGIEl8zWdMjKUBB8W9W3IYoCsFnHNH\niMgISTTBR0niVWsPeO8nxVoYkIRzrpmI/EMSr1E7XESWisgY7/2EWAsDiuCcGyQi3USkgSR2+Vwv\nIm+IyN3ee77vjlKLe7fk0AwDAAAgWHxnGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABCujr1ZzzjGt\nh5R5713RR6UX9y7SIdP3Lvct0oHPXGSrqPcuT4YBAAAQLJphAAAABItmGAAAAMGiGQYAAECwaIYB\nAAAQLJphAAAABItmGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABItmGAAAAMGiGQYAAECwaIYBAAAQ\nLJphAAAABItmGAAAAMEqH3cBocjLy1PZgAEDVNa2bVuVee9VtmLFCnOdyZMnq2zcuHEqW7t2rXk+\nAABASHgyDAAAgGDRDAMAACBYNMMAAAAIFs0wAAAAguWs4awSW8y5zC0Wk379+pn5yJEjVZaTk1PS\n5YiIyNy5c1XWvXt3la1bty4T5aTMe+8yvWYI9y5KXqbvXe5bpAOfucXzzDPPqOzqq69W2cyZM83z\np02bprL58+erbM2aNZHq2bNnj5nv3bs30vnZLOq9y5NhAAAABItmGAAAAMGiGQYAAECwaIYBAAAQ\nLHagS0GHDh1U9tBDD5nHWsNyixYtUtnAgQNV9uWXX0auqVevXiq75557VDZo0CCV3XzzzZHXQXar\nUqWKygYPHmweO3ToUJVZg7fDhw9XWYsWLVTWqVOnKCUCQFZatmyZygoLC1Vm9RAHyw/VhAkTzPz6\n669XWUFBQVrXzhY8GQYAAECwaIYBAAAQLJphAAAABItmGAAAAMFiB7qIOnbsqLIXXnhBZdZgkojI\njBkzVGbtVrdhw4ZDqO5/Oac3W7GG6i6++GKV/f73v09p7UxhN6TU1a9fX2XffvuteWzr1q1VtnDh\nQpVZA3Q33XSTypo0aWKuk+q9nw3YgQ77q1WrlsoaN25sHluxYkWVdevWTWUTJ05UWbIdyD766KOi\nShQRPnPTweoh2rVrF/n8M844Q2XW53ilSpVUduSRR5rXvOiii1Rm7VibzdiBDgAAACgCzTAAAACC\nRTMMAACAYNEMAwAAIFjsQGcoX17/32Lt4mYNyy1evNi8prXTyw8//HAI1R2cNRA5fvx4lU2fPj3t\nayN7NGzYMO3XzM/PV5k1uHHyySeb54cwQIcwnHLKKSr7wx/+oLKePXuqrE6dOuY1ow679+jRI9Jx\nIiLlypWLfCxS8/rrr0fKUpWXl6eymTNnmsdecsklKitrA3RR8WQYAAAAwaIZBgAAQLBohgEAABAs\nmmEAAAAEi2YYAAAAweJtEoY+ffqorFWrVirbvXu3yq699lrzmiXx5ohUbNq0Ke4SEKNzzjkn7dd8\n5ZVXVGa9heX00083zw91ihnZoWXLlmY+YMAAlbVp00ZltWvXTntNlm3btqlszpw5GVkbmVWjRg2V\n3X333SorKCgwz0/2lokQ8WQYAAAAwaIZBgAAQLBohgEAABAsmmEAAAAEiwE6w0033RTpuL59+6rs\n888/T3c5QEqsLVcvu+wylRUWFprnJxu+AIrL2upeRKRixYoq2759e0mXIyL2QOeECRNU1qhRI/P8\nww
8/PO01Wb766iuVDR06VGXWcPS8efNKpCakplq1amaem5urspycHJUNGTJEZdb9/Oyzz5rrvPfe\ne0VUGA6eDAMAACBYNMMAAAAIFs0wAAAAgkUzDAAAgGAxQJeC7777Lu4SgCLVqlVLZWeccYbK/vOf\n/5jnL168ONI6+fn5Ktu7d6/KGjduHOl6KHus3bFERC699FKVTZs2TWXDhg2LvFbz5s1Vdscdd6jM\nGiatUKGCypxz5jre+8g1RWH9c4uI/PGPf1TZrl270ro20qNq1aoqGzFihMqse08ktd0KP/nkE5U9\n8MADh3y9UPBkGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABCvoATprwEJE5KSTTlLZtm3bVPb111+n\nvSYgLitWrEjp/JUrV6pszZo1KmvZsmVK6yA7HHHEESq7+uqrzWPr16+vsmbNmqnMGkxq0qSJec0O\nHToUVWKxJBugs1i7wD333HMqe/nll1XGbnHZ77zzzlNZ//79M7K29TOSbHdR/C+eDAMAACBYNMMA\nAAAIFs0wAAAAgkUzDAAAgGDRDAMAACBYQb9Nonx5+x+/XLlyKtu5c6fK2I4Z2eDCCy+MdNzIkSNT\nWsf6ebJ+lurUqWOeb719YOvWrSnVhPjUqFFDZVWqVDGPjbql8YABA1RWEtskL1iwQGUvvfSSeeys\nWbNUtn37dpV9//33h1wPsktubm5K52/cuFFl48aNU9lhh+nnmXfeeafKrK2gRUR69+6tsp9++ilK\niWUOT4YBAAAQLJphAAAABItmGAAAAMGiGQYAAECwgh6gi9vRRx+tso4dO6rstttui3zNVatWqaxh\nw4YqW79+vcqmTp2qsgkTJpjr5OfnR64J8Tr33HNVtmHDBpV9+OGHKa1jDZnOnDlTZX379jXPP/LI\nI1XGAF32sj6LfvjhB/NYa9guU4YPH66yRx55RGWbN2/ORDkoA+655x6V/etf/1LZjh07zPPff/99\nle3Zs0dl1vDolClTVPbuu++a64wfP15lvXr1UtmWLVvM88sSngwDAAAgWDTDAAAACBbNMAAAAIJF\nMwwAAIBgMUAXkTXgcfrpp6vss88+M89v3LixymbPnq2y+vXrq2zXrl0q++KLL8x1rKEVK+vRo4fK\n2rRpo7J27dqZ61x22WVmjnhZO3xdcsklKrOGMZINc6QihMELRJdskKdJkyaHfM0PPvjAzKdNm6ay\nSZMmqczacauwsPCQ6wEKCgpUNmPGjLSvY+2yuGTJEpX16dPHPH/69Okqmzt3rsrGjh17CNVlF54M\nAwAAIFhE084QAAAIQElEQVQ0wwAAAAgWzTAAAACCRTMMAACAYAU9QJdsR6Gff/5ZZdbuWFZ24okn\nmtecM2eOyo477jiVWQMm/fv3V9ny5cvNdaJ69dVXVWZ9mb5p06YprYPMqly5ssoaNGigsjVr1mSi\nHPNnKRnr5ylTdSIzBg0aZObWzpvWMLHl/PPPT6UkoMyzfr8XEXnxxRdVZv2MvvTSSypLtptktuLJ\nMAAAAIJFMwwAAIBg0QwDAAAgWDTDAAAACFbQA3TWzmwiIuvWrVOZNdxz5ZVXquzkk082r2kNy1k7\n0HXp0kVlJbEzmLX2+PHjVXbxxRenfW3ELycnR2WtW7c2j/3ll19UZg2fVqpUSWXWDknJjBs3TmUX\nXnihyvLz8yNfE6XL9u3bzdwa5OnevbvK6tWrp7L169eb15wyZYrK7r77bpUlG6QGyrrRo0errFu3\nbiq77rrrVHb//feXSE1x4ckwAAAAgkUzDAAAgGDRDAMAACBYNMMAAAAIFs0wAAAAguWKM+2d8mLO\nZW6xFIwYMUJld9xxR0rXtN7UcMstt6hs586dKa2TikmTJqmsffv25rEtW7ZU2erVq9Nek8V77zKy\n0H6y5d499thjVbZx48aUrllQUKAy660A1hsqrO2hi8N6u8qMGTNS
umacMn3vZst9a7Gm2h9//HGV\nVatWzTzf+r1t/vz5KuvUqZPKfvrppyglBoPP3LKpYsWKKvvoo49UtnjxYpX16NGjRGpKt6j3Lk+G\nAQAAECyaYQAAAASLZhgAAADBohkGAABAsBigM1SvXl1ln3/+ucrq168f+Zq33nqrykaNGlW8wkqY\nta1psuGU0047TWVff/112muyMMyRXLly5VQ2fPhwlQ0aNCgT5RTLZ599prKzzz5bZXv37s1EOSWC\nAbrUWJ+51nCyiMhFF10U6ZpfffWVyi6//HKVLVu2LNL1yiI+c+2twEXsoc6uXbuqbPfu3WmvqSQM\nHTpUZddff73KTj31VJVt2bKlRGpKBQN0AAAAQBFohgEAABAsmmEAAAAEi2YYAAAAwWKALqIOHTqo\n7MUXX1RZlSpVzPN37Nihstdff11l999/v8qWLFkSpcRiycvLU9mrr76qsuXLl5vnN2vWLO01RcUw\nR/FYQ3U1a9ZUWbJ717pXrKEjK7OGLN566y1zHWt3sPPOO888NlsxQJd+1pCliL1TobVDo2XBggUq\nu/HGG81jrcHPsobPXJGGDRua+b///W+VPffccyr7y1/+orINGzakXFe6WQN09957r8pOPPFEla1a\ntaokSkoJA3QAAABAEWiGAQAAECyaYQAAAASLZhgAAADBYoAuBe3atVPZgw8+aB7bvHnzSNfctWuX\nynr37q2y1atXm+dbX2DPzc1V2ejRo1Vm7bz3wgsvmOv06NHDzDOBYY7s0bp1a5UlGzhigC79Qr5v\nL730UpVNmzbtkK9nfQ6LiEyYMOGQr5kt+MwVqVu3rplbO69aw8grVqxQWd++fc1rfvjhhyorKCgo\nqsRi69Kli8oeeughleXk5KjslFNOUdnPP/+cnsLSiAE6AAAAoAg0wwAAAAgWzTAAAACCRTMMAACA\nYDFAl2bJdjjq2bOnyqwdaY466qi012Sxvoxv7X53zz33ZKKcYmGYI3scc8wxKlu2bJl57N69e1X2\nq1/9SmWlcUgjKgbo0q9fv35m/uijj6Z1naefftrMrc/2sobP3OS6du2qssmTJ6d0TWtnOqtXe+WV\nV1TWuXPnyOvUqFFDZdaw3H333aeyu+66K/I6cWKADgAAACgCzTAAAACCRTMMAACAYNEMAwAAIFg0\nwwAAAAgWb5OIkTXJaU1GW9OqLVq0iLzOmjVrVPb444+rbMSIEZGvGScmm7Obte2yiMg555yjMmsL\n1HXr1qW9pkzhbRLRWdvdDxo0SGW/+c1vzPPT/XvbjTfeaObjxo1L6zqlEZ+5yZUrV05l7du3V9nA\ngQNVlup2887pfy2p3vfjx49X2ZAhQ1T2ww8/pLROpvA2CQAAAKAINMMAAAAIFs0wAAAAgkUzDAAA\ngGAxQIeswzBHdhswYICZP/zwwyq79NJLVWZtQZotQh+gy8vLM/PrrrtOZdYQkrVVrDVEJBJ9kGj4\n8OEqW7hwocpeffXVSNcri/jMTd1hh+lnj2eeeaZ5rDU0f+6556rs7LPPVtmePXtUNmXKFHOd0aNH\nq8y69wsLC83zswEDdAAAAEARaIYBAAAQLJphAAAABItmGAAAAMFigA5Zh2GO7HbWWWeZ+ccff6yy\n9957T2UXXHBBukvKmJAG6Hr37q2yZLtcWrtxWrZs2aKyefPmmcd+8cUXKnv55ZdVtnjxYpVl88BQ\nSeAzF9mKAToAAACgCDTDAAAACBbNMAAAAIJFMwwAAIBgMUCHrMMwB7JVSAN01o5ZHTp0MI+dOXNm\npGtu3LhRZStXrixeYSg2PnORrRigAwAAAIpAMwwAAIBg0QwDAAAgWDTDAAAACBbNMAAAAILF2ySQ\ndZhsRrYK6W0SKDv4zEW24m0SAAAAQBFohgEAABAsmmEAAAAEi2YYAAAAwaIZBgAAQLBohgEAABAs\nmmEAAAAEi2YYAAAAwaIZBgAA
QLAyugMdAAAAUJrwZBgAAADBohkGAABAsGiGAQAAECyaYQAAAASL\nZhgAAADBohkGAABAsGiGAQAAECyaYQAAAASLZhgAAADBohkGAABAsGiGAQAAECyaYQAAAASLZhgA\nAADBohkGAABAsGiGAQAAECyaYQAAAASLZhgAAADBohkGAABAsGiGAQAAECyaYQAAAASLZhgAAADB\nohkGAABAsGiGAQAAEKz/BgZy4aPO2hNlAAAAAElFTkSuQmCC\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plots(x_imgs[:8], titles=preds[:8])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining Logistic Regression Ourselves"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Above, we used pytorch's `nn.Linear` to create a linear layer. This is defined by a matrix multiplication and then an addition (these are also called `affine transformations`). Let's try defining this ourselves.\n",
"\n",
"Just as Numpy has `np.matmul` for matrix multiplication (in Python 3, this is equivalent to the `@` operator), PyTorch has `torch.matmul`. \n",
"\n",
"Our PyTorch class needs two things: constructor (says what the parameters are) and a forward method (how to calculate a prediction using those parameters) The method `forward` describes how the neural net converts inputs to outputs.\n",
"\n",
"In PyTorch, the optimizer knows to try to optimize any attribute of type **Parameter**."
]
},
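{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of the shapes involved in this affine transformation, assume a mini-batch of 64 flattened images (64 is the fastai default batch size mentioned earlier; the names $X$, $W$, $b$ and $Z$ are introduced here for illustration):\n",
"\n",
"$$\\underbrace{X}_{64 \\times 784} \\; \\underbrace{W}_{784 \\times 10} + \\underbrace{b}_{10} = \\underbrace{Z}_{64 \\times 10}$$\n",
"\n",
"The bias $b$ is added to every row of $XW$ (this is broadcasting), so each image ends up with 10 outputs, one per digit."
]
},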
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_weights(*dims): return nn.Parameter(torch.randn(dims)/dims[0])\n",
"def softmax(x): return torch.exp(x)/(torch.exp(x).sum(dim=1)[:,None])\n",
"\n",
"class LogReg(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.l1_w = get_weights(28*28, 10) # Layer 1 weights\n",
" self.l1_b = get_weights(10) # Layer 1 bias\n",
"\n",
" def forward(self, x):\n",
" x = x.view(x.size(0), -1)\n",
" x = (x @ self.l1_w) + self.l1_b # Linear Layer\n",
" x = torch.log(softmax(x)) # Non-linear (LogSoftmax) Layer\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We create our neural net and the optimizer. (We will use the same loss and metrics from above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net2 = LogReg().cuda()\n",
"opt=optim.Adam(net2.parameters())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eb28b90dfa8d4d2dad99665e5e3fd9a9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"A Jupyter Widget"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.32209 0.28399 0.92088] \n",
"\n"
]
}
],
"source": [
"fit(net2, md, n_epochs=1, crit=loss, opt=opt, metrics=metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dl = iter(md.trn_dl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"xmb,ymb = next(dl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Variable containing:\n",
"-0.4245 -0.4245 -0.4245 ... -0.4245 -0.4245 -0.4245\n",
"-0.4245 -0.4245 -0.4245 ... -0.4245 -0.4245 -0.4245\n",
"-0.4245 -0.4245 -0.4245 ... -0.4245 -0.4245 -0.4245\n",
" ... ⋱ ... \n",
"-0.4245 -0.4245 -0.4245 ... -0.4245 -0.4245 -0.4245\n",
"-0.4245 -0.4245 -0.4245 ... -0.4245 -0.4245 -0.4245\n",
"-0.4245 -0.4245 -0.4245 ... -0.4245 -0.4245 -0.4245\n",
"[torch.cuda.FloatTensor of size 64x784 (GPU 0)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vxmb = Variable(xmb.cuda())\n",
"vxmb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Variable containing:\n",
"\n",
"Columns 0 to 5 \n",
" 1.6740e-03 1.0416e-05 2.5454e-05 1.9119e-02 6.5026e-05 9.7470e-01\n",
" 3.4048e-02 1.8530e-04 6.6637e-01 3.5073e-02 1.5283e-01 6.4995e-05\n",
" 3.0505e-08 4.3947e-08 1.0115e-05 2.0978e-04 9.9374e-01 6.3731e-05\n",
"\n",
"Columns 6 to 9 \n",
" 2.1126e-06 1.7638e-04 3.9351e-03 2.9154e-04\n",
" 1.1891e-03 3.2172e-02 1.4597e-02 6.3474e-02\n",
" 8.9568e-06 9.7507e-06 7.8676e-04 5.1684e-03\n",
"[torch.cuda.FloatTensor of size 3x10 (GPU 0)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preds = net2(vxmb).exp(); preds[:3]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\n",
" 5\n",
" 2\n",
" 4\n",
" 1\n",
" 8\n",
" 8\n",
" 3\n",
" 4\n",
" 2\n",
" 7\n",
" 1\n",
" 1\n",
" 7\n",
" 9\n",
" 3\n",
" 4\n",
" 3\n",
" 0\n",
" 1\n",
" 4\n",
" 3\n",
" 3\n",
" 6\n",
" 8\n",
" 9\n",
" 0\n",
" 9\n",
" 5\n",
" 2\n",
" 4\n",
" 7\n",
" 9\n",
" 0\n",
" 2\n",
" 5\n",
" 2\n",
" 9\n",
" 9\n",
" 6\n",
" 9\n",
" 4\n",
" 2\n",
" 2\n",
" 7\n",
" 1\n",
" 5\n",
" 5\n",
" 9\n",
" 0\n",
" 7\n",
" 4\n",
" 0\n",
" 9\n",
" 1\n",
" 1\n",
" 4\n",
" 7\n",
" 0\n",
" 9\n",
" 0\n",
" 4\n",
" 4\n",
" 0\n",
" 7\n",
"[torch.cuda.LongTensor of size 64 (GPU 0)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preds = preds.data.max(1)[1]; preds"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at our predictions on the first eight images:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsMAAAF0CAYAAADGqzQSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XmUFeXV7/H9CLSMiqiMCggYvKIC4mwncQChhQUSMRHF\naDMoiBpxeGVSUeSiiRFQFJOgOIHK7IBTEByQV8WAckEQMEFQJgWRUemmn/tH910vl70PXd3n9Kk+\n/Xw/a7mIv1TVs02qD9vi7Hqc914AAACAEB0WdwEAAABAXGiGAQAAECyaYQAAAASLZhgAAADBohkG\nAABAsGiGAQAAECya4Zg55zo65+Y55zY5535xzn3rnJvqnDs57tqAQ3HOne+ce8c5t8U5t8M5t9g5\n1zvuuoAonHOXOuc+cM7tKrp/P3POXRR3XcChOOcudM4tcM7tdc5tc84975yrF3ddmY5mOH51RORf\nInKTiFwiIkNEpJWIfOycaxJnYUAizrnTRGSuiFQRkX4icrmILBKRp5xzA+KsDSiOc+4GEXlFCj97\nu4vIFSIyTUSqx1kXcCjOuV+LyDsisl0KP3P/JCK/EZF3nXOHx1lbpnNsulH+OOdaishKEbnDe//X\nuOsBDuac+98icoeI1PHe7zog/1hEvPf+3NiKAw7BOddURFaIyBDv/dh4qwGic87NFZGmInKS9z6/\nKDtTRD4VkYHe+ydiLC+j8WS4fNpa9GterFUAiWVJ4f2596B8u/C5gvKtt4gUiMiTcRcClNA5IvLP\n/9cIi4h47xdJYc/QPbaqKgB+0yonnHOVnHNZzrkTReRvIrJJRF6KuSwgkWeKfn3UOdfQOVfbOddP\nRC4WkTHxlQUUK1sK/+TtSufc1865fOfcGufcwLgLA4qxX0T2GfkvInJKmmupUPiaRDnhnPtMRNoV\n/e0aEenqvV8RY0nAIRX98dwsEWlUFOWJyADv/VPxVQUcmnNupYg0lMIGYqiIfC2F3xnuLyK3eu/H\nxVgekJBz7lMp/Bra2QdkTUTkPyKS573ne8OlRDNcTjjn/peIHCEizaTwu5j1RCTbe782zroAS9Gf\nYLwrhd+9fEwKvy7RTUQGiMh13vvJMZYHJOScWyUiJ4rI5d77mQfkb4pIWxFp4PmNEeWQc+5qEXlB\nREaJyKNSOID/dxE5Twqb4WoxlpfRaIbLIedcbRFZKyIvee/7x1wOoDjnponI6VI4yJF3QD5ZRDqK\nSF3vfUFc9QGJOOf+Wwq/e3mE937nAfkgEXlERBp57zfEVR9wKM65kVL4wKyqiHgReVlEaojIKd77\nZnHWlsn4znA55L3fLoVflWgRdy1AAqeKyBcHNsJFPhWRo0WkbvpLAiJZniB3Rb/yL3Eot7z3d4vI\nMSJymhT+KUZPKfyTjgWxFpbhaIbLoaIXaJ8khd9lA8qjTSLSxjmXdVB+toj8LCLb0l8SEMmsol87\nHpR3FJFvvfeb0lwPUCLe+93e+//jvd/snOskhf0Cb0dJQuW4Cwidc26WiCwWkaUiskNEfiUig0Qk\nX0R4xzDKq/FSuEnBa865J6TwO8NdRaSniIzx3lsTz0B58IaIzBeRvznnjhGRf4tIDync9Cg3zsKA\nQ3HOtRWRHCnsGUQK34xyp4j82Xu/MLbCKgC+Mxwz59xdIvJ7EWkuhe9uXS8i74nIaIbnUJ4553JE\n5C4p3DGxqhT+ScbfReRv3vv9cdYGHIpz7ggRGS2FTfBRUviqtQe991NiLQw4BOdcKyl89eopInK4\nFA0we+8nxVpYBUAzDAAAgGDxnWEAAAAEi2YYAAAAwaIZBgAAQLBohgEAABCstL5azTnHtB6S5r13\nxR+VWty7SIV037vct0gFPnORqaLeuzwZBgAAQLBohgEAABAsmmEAAAAEi2YYAAAAwaIZBgAAQLBo\nhgEAABAsmmEAAAAEi2YYAAAAwaIZBgAAQLDS
ugMdAADlSc2aNVXWp08flXXr1s08v2vXrirbtWtX\n8oUBSBueDAMAACBYNMMAAAAIFs0wAAAAgkUzDAAAgGDRDAMAACBYvE0CABCsa6+9VmVjxoyJfH6r\nVq1U9sknnyRVE4D04skwAAAAgkUzDAAAgGDRDAMAACBYNMMAAAAIFgN0SWjdurXKBg0aZB7bvHlz\nlVWvXl1lQ4cOVdmRRx6psjfffNNcZ+fOnWYOAKG77rrrVDZ27FiV5eXlqezhhx82r7l48eKk6wIQ\nL54MAwAAIFg0wwAAAAgWzTAAAACCRTMMAACAYDnvffoWcy59i6VYzZo1VbZu3TqV1a5dOx3lyHff\nfWfm1gDf9OnTy7qctPLeu3Svmcn3rsW6T7t3724e27ZtW5VlZ2erzPoZ2bZtm8rq169vrrNp0yaV\nPfPMMyr7xz/+obL9+/eb1yxv0n3vVrT7tiS6du2qslmzZqlsz549KrvnnntUVpJd6SoaPnORqaLe\nuzwZBgAAQLBohgEAABAsmmEAAAAEi2YYAAAAwWKALqJatWqp7I033lDZ1q1bzfOXLFmiMmswqUmT\nJio7/vjjVVatWjVznc2bN6vs3HPPjXRcpmCYo2SOO+44lc2ePVtl1v2YyI4dO1Rm3eNVqlRRmfWz\nJCJSt25dldWrV09lV111lco++OADlW3cuNFcJ04M0KVeVlaWmU+aNEllPXv2VNm8efNU1r59++QL\nq0D4zEWmYoAOAAAAKAbNMAAAAIJFMwwAAIBg0QwDAAAgWAzQZYBjjjlGZXfeead5rJXn5uaq7Nln\nn02+sJgwzFEyixcvVlnr1q1VNnfuXPP822+/XWU//PCDyqwd5Eri2GOPVdmbb76pspYtW6ps8ODB\nKnv88ceTqqcsMECXesOGDTPzkSNHquyFF15QWe/evVWWn5+ffGEVCJ+5yWvQoIHKbrzxRvNYK8/L\ny1OZtQvuqFGjVGb9HiAisn79ejOvSBigAwAAAIpBMwwAAIBg0QwDAAAgWDTDAAAACBbNMAAAAILF\n2yQyVNeuXc3c2mb30UcfVdmtt96a8prShcnmxKyJ5e+++05lU6dOVdnVV19tXnP//v3JF1ZKkydP\nVtmVV16psnbt2qns888/L5OaksHbJJJzxhlnqGzBggXmsWvXrlVZq1atVBbn/Z0p+MwtmWbNmqls\nwoQJKuvQoUM6ypFffvnFzM8//3yVJXrzRKbibRIAAABAMWiGAQAAECyaYQAAAASLZhgAAADBqhx3\nASjeUUcdpbKhQ4dGPr9hw4apLAflWJs2bVTmnJ4f2LBhg8riHiQ655xzVNazZ0+VzZ8/X2XWP3d5\nHKBDdIcdpp/VWNtuZ2Vlmee/9tprKov7HkfF06hRI5UtW7ZMZZUr63ZrzJgx5jUfe+yxSOucdNJJ\nKvvLX/6istq1a5vrWIPU1ufwDz/8YJ5fkfBkGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABIsd6MqZ\n1q1bq2zatGkqa9GihXn+qlWrVGbtcrN+/fpSVFc+sBtSyRQUFKhsy5YtKjvrrLPM89etW5fSemrV\nqmXmCxcuVNnq1atVZu2UZ+34tHz58lJUV7bYgS66qLspJnLLLbeobPz48UnVFCo+cxMbN26cyvr3\n76+yfv36qey5555LeT0DBw5U2dixY81jK1WqpLKVK1eqzBqq27FjRymqSz92oAMAAACKQTMMAACA\nYNEMAwAAIFg0wwAAAAgWA3Qxuvbaa1V2//33q+z4449X2d69e81rdunSRWXWjl2ZjGGOkhkxYoTK\n7r77bpV99dVX5vkdO3ZUWTIDmO+8846Z//a3v1VZu3btVGbt7pQpGKCLLjc3V2VPPfWUyubOnWue\nn5OTozJ2oCsdPnNFjjjiCDO3hnwnTZqkMmv3xHRJ9Nl+4oknRjrf2inv9ttvT6qmdGGADgAAACgG\nzTAAAACC
RTMMAACAYNEMAwAAIFg0wwAAAAgWb5NIsZo1a5r5HXfcobLhw4er7LDD9L+fbNu2TWXZ\n2dnmOtZWihUNk80lU7VqVZU9++yzKuvRo4d5/po1a1R2wQUXqGzjxo0qe+KJJ1R2/fXXm+vceeed\nKrOmmDMZb5OwVa5cWWUrVqxQWZMmTVR2wgknmNcsydbNODQ+cxNvV//xxx+rrEOHDip79913U15T\nVN27dzfzmTNnqszqCbdv364y600UW7duLUV1ZYu3SQAAAADFoBkGAABAsGiGAQAAECyaYQAAAARL\nTy0gKc8884yZ/+53v4t0/vTp01U2duxYlYUwKIfU+Pnnn1XWt29fldWtW9c839om+f3331fZtGnT\nVNarVy+VzZgxw1ynog3LITpreLN58+YqGzBggMriHpTr1KmTyrp27aqyt956S2XW1uTWzyvi17Zt\n28jHLlmypAwrKbk33njDzK3haOvnzrond+/enXxh5QhPhgEAABAsmmEAAAAEi2YYAAAAwaIZBgAA\nQLAYoEsx68vnJTFhwgSVLVy4MKlrAgfbuXOnyrp162YeO2LECJXdeuutKhs8eHCktR977LFIxyEc\njRs3jnRcVlZWGVeS2HXXXWfm1i6L1q6P/fv3V5m1s9fs2bPNdXr37l1MhShLCxYsMPOCggKV/fOf\n/1RZly5dVGbt2lkWWrZsaebWfdqxY0eVVa9eXWUVbdCTJ8MAAAAIFs0wAAAAgkUzDAAAgGDRDAMA\nACBYDNClmLWjkIhI69atS32+NVT34IMPmudv2LAh0jrAwXbs2GHm99xzj8o6dOigspNPPjnSOu3b\ntzfzRAMqqPhatGgR6bh07bxZu3ZtlT3yyCPmsdYQUn5+vsqsoars7GyVWbs2ijBAF7fly5eb+euv\nv64yaxh5xYoVKrN2JRSxd+mcN2+eyho1aqQya1jO2sVWRKRBgwYqs+7dV155xTy/IuHJMAAAAIJF\nMwwAAIBg0QwDAAAgWDTDAAAACJbz3qdvMefSt1hMqlWrZuYvvPCCytq1a6eyqDsxbdq0ycxzc3NV\n9vbbb0e6Zqbw3rt0rxnCvZtITk6OymbNmqWyKlWqRLrevn37zPzGG29U2aRJkyJdM1Ok+97NlPt2\nzpw5Kmvbtq3KGjZsmI5yzB0WEw3QWZ/t48aNU9m6detUZg1QnXrqqeY6ce6+x2duYtbv+aNHj1bZ\nLbfcktQ627ZtU1mdOnWSuqbliiuuUJk10Jcpot67PBkGAABAsGiGAQAAECyaYQAAAASLZhgAAADB\nohkGAABAsNiOOcX27t1r5ldffbXKKlfW//Mn2hL3YPXr1zdza8r/tttuU9mTTz4ZaR3gwgsvVJn1\nFpru3burzJqAtrYvFbG3Hf/hhx9U9tprr5nnI3OdffbZKkv01pHyZsOGDSo77rjjVPb3v/9dZaef\nfrrKKtrbfyo66/d8620kU6dOVZnVFyRSr169SMfl5eWpzPr5EhE54YQTVLZnz57INVUkPBkGAABA\nsGiGAQAAECyaYQAAAASLZhgAAADBYjvmcua0005T2ZgxY1RmDTUlYm0D2rRp0xLVVZ6wNWjZsO49\nEZFFixapzBp2s4ZGLNZ2nyIiTz31lMqc0/9Xt2rVSmXWPV4esR2zzRou69Kli8rKYjtm6x6z7uW/\n/vWvSa1j/V77xBNPqGzo0KHm+Tt37kxq/WTwmZvZnn/+eTO3Bvg6deqksnfeeSflNaUL2zEDAAAA\nxaAZBgAAQLBohgEAABAsmmEAAAAEix3oIqpevbrKymKnlqVLl6qsR48eKnv66afN87t166ayxo0b\nq6xBgwYq27hxY5QSUUHVqlXLzK2dEqdPn17qdaZNm2bmTZo0UdlDDz2ksnbt2qksUwboEF3t2rVV\nZg0CvfDCC+b51n175ZVXqqxOnToqy8nJiVKiiIjs3r1bZQsWLFDZn//8Z5
XNnz8/8jpAOjRv3jzu\nEmLBk2EAAAAEi2YYAAAAwaIZBgAAQLBohgEAABAsBugM1hfIrYGIOXPmqGzZsmXmNa3htD59+qis\nSpUqKmvUqJHKWrRoYa5j+frrryPVg7C1adPGzDdt2qQy6+chWePHj1dZv379VDZw4ECVzZo1K+X1\nIH2WLFmisr59+6rM2jHLypK1Y8cOlSUa/HzggQdU9s0336S8JqC0du3aFXcJ5R5PhgEAABAsmmEA\nAAAEi2YYAAAAwaIZBgAAQLAYoDNcccUVKqtfv77KevfunfK1nXMq895HPt/6onz//v2TqglhsHYq\nFBH59NNP07L+vn37VPbjjz+q7Ne//rXKrF3Etm3blprCUOamTJmiMmvnzdWrV6usUqVK5jUT5Qeb\nPHmyytauXasyaxAZyAQffPCBmd9www0qq1u3blmXUy7xZBgAAADBohkGAABAsGiGAQAAECyaYQAA\nAASLZhgAAADB4m0ShqOPPjruEv4/M2bMUNnIkSPNY7ds2aIyaztd4GCJ3lqSnZ2tsiuvvFJl8+bN\nU1nNmjVVlpWVZa5z0kknqezMM89U2eOPP64y3hyR2X766SeVXXzxxTFUAlQ8hx1mP/e03l5l9RAh\n4MkwAAAAgkUzDAAAgGDRDAMAACBYNMMAAAAIFgN0hqFDh6ps7ty5KuvVq5fKGjZsaF7TGhCxPPbY\nYyr78MMPVZafnx/pekBUK1asMHNrq2Nr+9ytW7eqrCQDdNYwx0cffaSyESNGmOcDALSCggIzTzQ0\nHSKeDAMAACBYNMMAAAAIFs0wAAAAgkUzDAAAgGAxQGfIy8tT2dtvvx0pAzLVW2+9Zebjx49XmbUr\nXZs2bZJaf9iwYSp7+umnVcZucwBQNi655BKVTZgwIYZK0osnwwAAAAgWzTAAAACCRTMMAACAYNEM\nAwAAIFgM0AEQEZHNmzeb+Z/+9Kc0VwIASJVdu3ZFPrZy5TDbQp4MAwAAIFg0wwAAAAgWzTAAAACC\nRTMMAACAYNEMAwAAIFjOe5++xZxL32KosLz3Lt1rcu8iFdJ973LfIhX4zM1stWvXNnNra/u9e/eq\nrEaNGimvKV2i3rs8GQYAAECwaIYBAAAQLJphAAAABItmGAAAAMFigA4Zh2EOZCoG6JCJ+MxFpmKA\nDgAAACgGzTAAAACCRTMMAACAYNEMAwAAIFhpHaADAAAAyhOeDAMAACBYNMMAAAAIFs0wAAAAgkUz\nDAAAgGDRDAMAACBYNMMAAAAIFs0wAAAAgkUzDAAAgGDRDAMAACBYNMMAAAAIFs0wAAAAgkUzDAAA\ngGDRDAMAACBYNMMAAAAIFs1wOeGcu9Q594Fzbpdzbodz7jPn3EVx1wVE5Zx7yznnnXMPxF0LkIhz\n7oKi+/Tgv7bHXRtwKNy7Zady3AVAxDl3g4iML/prpBT+S0obEakeZ11AVM65niLSOu46gBK4RUQW\nHfD3+XEVApQQ926K0QzHzDnXVETGisid3vuxB/xXb8dSEFBCzrnaIjJGRAaJyJSYywGiWuG9/zju\nIoBS4N5NMb4mEb/eIlIgIk/GXQhQSn8WkeXe+xfjLgQAgJKiGY5ftoisFJErnXNfO+fynXNrnHMD\n4y4MKI5zLltE/igiN8ZdC1BCk51z+51zW51zU5xzjeMuCIiIezfF+JpE/BoW/fUXERkqIl+LyBUi\nMt45V9l7Py7O4oBEnHNVRORvIvKw9/6ruOsBIvpJRP4qIu+LyA4RaSuFn73/7Zxr673fEmdxwCFw\n75YR572Pu4agOedWiciJInK5937mAfmbUnijN/D8n4RyyDk3XAq/5tPKe7+3KPMiMsp7PzzW4oAS\ncM6dLiKfisiD3LvIJNy7qcHXJOK3tejXfx6UvyMi9USkQXrLAYpX9Mdyw0TkbhE53DlXu2iQTg74\n+0rxVQhE571fLCKrROTMuGsBSoJ7Nz
VohuO3PEHuin4tSFchQAk0E5GqIvKCiPx4wF8iIncU/edT\n4ykNKBUnIvwpHDIR926SaIbjN6vo144H5R1F5Fvv/aY01wNE8bmIXGj8JVLYIF8oImviKQ0oGefc\nGSLyKxH5JO5agJLg3k0NBuji94aIzBeRvznnjhGRf4tIDxG5RERy4ywMSMR7v11E3js4d86JiHzj\nvVf/HVAeOOcmi8h/RGSxiGyXwtmMISLynYg8FmNpwCFx75YdmuGYee+9c+4yERktIveJyFFS+Kq1\nq733bGAAAKm1TER6isjNUrjL5yYRmSki93rvf4izMKAY3LtlhLdJAAAAIFh8ZxgAAADBohkGAABA\nsGiGAQAAECyaYQAAAAQrrW+TKNqqFUiK994Vf1Rqce8iFdJ973LfIhX4zEWminrv8mQYAAAAwaIZ\nBgAAQLBohgEAABAsmmEAAAAEi2YYAAAAwaIZBgAAQLBohgEAABAsmmEAAAAEi2YYAAAAwaIZBgAA\nQLBohgEAABAsmmEAAAAEi2YYAAAAwaIZBgAAQLBohgEAABCsynEXEIqcnByVDRo0SGUdOnRQmfde\nZatXrzbXmTp1qsomTJigsg0bNpjnAwAAhIQnwwAAAAgWzTAAAACCRTMMAACAYNEMAwAAIFjOGs4q\ns8WcS99iMRkwYICZjxkzRmVZWVllXY6IiMyfP19lvXr1UtnGjRvTUU7SvPcu3WuGcO+i7KX73uW+\nRSrwmVsyzz77rMquueYalc2ZM8c8f8aMGSpbuHChytavXx+pnn379pn5/v37I52fyaLeuzwZBgAA\nQLBohgEAABAsmmEAAAAEi2YYAAAAwWIHuiR07txZZQ8//LB5rDUst2TJEpUNHjxYZcuXL49cU58+\nfVR23333qWzIkCEqu+WWWyKvg8xWo0YNlQ0dOtQ8dvjw4SqzBm9HjhypstatW6usa9euUUoEgIy0\ncuVKlRUUFKjM6iEOlZfWpEmTzPyGG25QWX5+fkrXzhQ8GQYAAECwaIYBAAAQLJphAAAABItmGAAA\nAMFiB7qIunTporIXX3xRZdZgkojI7NmzVWbtVrd58+ZSVPc/nNObrVhDdZdcconKfv/73ye1drqw\nG1LyGjdurLJvvvnGPLZdu3YqW7x4scqsAbqbb75ZZS1btjTXSfbezwTsQIcD1atXT2UtWrQwj61a\ntarKevbsqbLJkyerLNEOZB999FFxJYoIn7mpYPUQHTt2jHz+mWeeqTLrc7xatWoqO/LII81rXnzx\nxSqzdqzNZOxABwAAABSDZhgAAADBohkGAABAsGiGAQAAECx2oDNUrqz/Z7F2cbOG5ZYuXWpe09rp\n5fvvvy9FdYdmDUROnDhRZbNmzUr52sgcTZs2Tfk18/LyVGYNbpx88snm+SEM0CEMp5xyisr+8Ic/\nqKx3794qa9CggXnNqMPuubm5kY4TEalUqVLkY5Gc119/PVKWrJycHJXNmTPHPPbSSy9VWUUboIuK\nJ8MAAAAIFs0wAAAAgkUzDAAAgGDRDAMAACBYNMMAAAAIFm+TMPTr109lbdu2Vdkvv/yisuuuu868\nZlm8OSIZW7dujbsExOjcc89N+TVfeeUVlVlvYTnjjDPM80OdYkZmaNOmjZkPGjRIZe3bt1dZ/fr1\nU16TZefOnSqbN29eWtZGetWpU0dl9957r8ry8/PN8xO9ZSJEPBkGAABAsGiGAQAAECyaYQAAAASL\nZhgAAADBYoDOcPPNN0c6rn///ir7/PPPU10OkBRry9XLL79cZQUFBeb5iYYvgJKytroXEalatarK\ndu3aVdbliIg90Dlp0iSVNW/e3Dz/8MMPT3lNli+//FJlw4cPV5k1HL1gwYIyqQnJqVWrlplnZ2er\nLCsrS2XDhg1TmXU/P/fcc+Y67733XjEVhoMnwwAAAAgWzTAAAACCRTMMAACAYNEMAwAAIFgM0CXh\n22
+/jbsEoFj16tVT2Zlnnqmy//znP+b5S5cujbROXl6eyvbv36+yFi1aRLoeKh5rdywRkcsuu0xl\nM2bMUNmIESMir3Xaaaep7K677lKZNUxapUoVlTnnzHW895FrisL65xYR+eMf/6iyvXv3pnRtpEbN\nmjVVNnr0aJVZ955IcrsVfvLJJyp78MEHS329UPBkGAAAAMGiGQYAAECwaIYBAAAQLJphAAAABCvo\nATprwEJE5MQTT1TZzp07VfbVV1+lvCYgLqtXr07q/DVr1qhs/fr1KmvTpk1S6yAzHHHEESq75ppr\nzGMbN26sslatWqnMGkxq2bKlec3OnTsXV2KJJBqgs1i7wD3//PMqmzlzpsrYLS7znX/++SobOHBg\nWta2fkYS7S6K/8GTYQAAAASLZhgAAADBohkGAABAsGiGAQAAECyaYQAAAAQr6LdJVK5s/+NXqlRJ\nZXv27FEZ2zEjE1x00UWRjhszZkxS61g/T9bPUoMGDczzrbcP7NixI6maEJ86deqorEaNGuaxUbc0\nHjRokMrKYpvkRYsWqezll182j33jjTdUtmvXLpV99913pa4HmSU7Ozup87ds2aKyCRMmqOyww/Tz\nzLvvvltl1lbQIiJ9+/ZV2Y8//hilxAqHJ8MAAAAIFs0wAAAAgkUzDAAAgGDRDAMAACBYQQ/Qxe3o\no49WWZcuXVR2++23R77m2rVrVda0aVOVbdq0SWXTp09X2aRJk8x18vLyIteEeJ133nkq27x5s8o+\n/PDDpNaxhkznzJmjsv79+5vnH3nkkSpjgC5zWZ9F33//vXmsNWyXLiNHjlTZo48+qrJt27aloxxU\nAPfdd5/K/vWvf6ls9+7d5vnvv/++yvbt26cya3h02rRpKnv33XfNdSZOnKiyPn36qGz79u3m+RUJ\nT4YBAAAQLJphAAAABItmGAAAAMGiGQYAAECwGKCLyBrwOOOMM1T22Wefmee3aNFCZXPnzlVZ48aN\nVbZ3716VffHFF+Y61tCKleXm5qqsffv2KuvYsaO5zuWXX27miJe1w9ell16qMmsYI9EwRzJCGLxA\ndIkGeVq2bFnqa37wwQdmPmPGDJVNmTJFZdaOWwUFBaWuB8jPz1fZ7NmzU76OtcvismXLVNavXz/z\n/FmzZqls/vz5Khs/fnwpqsssPBkGAABAsGiGAQAAECyaYQAAAASLZhgAAADBCnqALtGOQj/99JPK\nrN2xrKxZs2bmNefNm6ey4447TmXWgMnAgQNVtmrVKnOdqF599VWVWV+mP+mkk5JaB+lVvXp1lTVp\n0kRl69dfjbV6AAAH4UlEQVSvT0c55s9SItbPU7rqRHoMGTLEzK2dN61hYssFF1yQTElAhWf9fi8i\n8tJLL6nM+hl9+eWXVZZoN8lMxZNhAAAABItmGAAAAMGiGQYAAECwaIYBAAAQrKAH6Kyd2URENm7c\nqDJruOeqq65S2cknn2xe0xqWs3ag6969u8rKYmcwa+2JEyeq7JJLLkn52ohfVlaWytq1a2ce+/PP\nP6vMGj6tVq2ayqwdkhKZMGGCyi666CKV5eXlRb4mypddu3aZuTXI06tXL5U1atRIZZs2bTKvOW3a\nNJXde++9Kks0SA1UdOPGjVNZz549VXb99derbNSoUWVSU1x4MgwAAIBg0QwDAAAgWDTDAAAACBbN\nMAAAAIJFMwwAAIBguZJMeye9mHPpWywJo0ePVtldd92V1DWtNzXceuutKtuzZ09S6yRjypQpKuvU\nqZN5bJs2bVS2bt26lNdk8d67tCx0gEy5d4899liVbdmyJalr5ufnq8x6K4D1hgpre+iSsN6uMnv2\n7KSuGad037uZct9arKn2J598UmW1atUyz7d+b1u4cKHKunbtqrIff/wxSonB4DO3YqpatarKPvro\nI5UtXbpUZbm5uWVSU6pFvXd5MgwAAIBg0QwDAAAgWDTDAAAACBbN
MAAAAILFAJ2hdu3aKvv8889V\n1rhx48jXvO2221Q2duzYkhVWxqxtTRMNp5x++ukq++qrr1Jek4VhjsQqVaqkspEjR6psyJAh6Sin\nRD777DOVnXPOOSrbv39/OsopEwzQJcf6zLWGk0VELr744kjX/PLLL1V2xRVXqGzlypWRrlcR8Zlr\nbwUuYg919ujRQ2W//PJLymsqC8OHD1fZDTfcoLJTTz1VZdu3by+TmpLBAB0AAABQDJphAAAABItm\nGAAAAMGiGQYAAECwGKCLqHPnzip76aWXVFajRg3z/N27d6vs9ddfV9moUaNUtmzZsigllkhOTo7K\nXn31VZWtWrXKPL9Vq1YprykqhjlKxhqqq1u3rsoS3bvWvWINHVmZNWTx9ttvm+tYu4Odf/755rGZ\nigG61LOGLEXsnQqtHRotixYtUtlNN91kHmsNflY0fOaKNG3a1Mz//e9/q+z5559X2X/913+pbPPm\nzUnXlWrWAN3999+vsmbNmqls7dq1ZVFSUhigAwAAAIpBMwwAAIBg0QwDAAAgWDTDAAAACBYDdEno\n2LGjyh566CHz2NNOOy3SNffu3auyvn37qmzdunXm+dYX2LOzs1U2btw4lVk777344ovmOrm5uWae\nDgxzZI527dqpLNHAEQN0qRfyfXvZZZepbMaMGaW+nvU5LCIyadKkUl8zU/CZK9KwYUMzt3ZetYaR\nV69erbL+/fub1/zwww9Vlp+fX1yJJda9e3eVPfzwwyrLyspS2SmnnKKyn376KTWFpRADdAAAAEAx\naIYBAAAQLJphAAAABItmGAAAAMFigC7FEu1w1Lt3b5VZO9IcddRRKa/JYn0Z39r97r777ktHOSXC\nMEfmOOaYY1S2cuVK89j9+/er7Fe/+pXKyuOQRlQM0KXegAEDzPzxxx9P6TrPPPOMmVuf7RUNn7mJ\n9ejRQ2VTp05N6prWznRWr/bKK6+orFu3bpHXqVOnjsqsYbkHHnhAZffcc0/kdeLEAB0AAABQDJph\nAAAABItmGAAAAMGiGQYAAECwaIYBAAAQLN4mESNrktOajLamVVu3bh15nfXr16vsySefVNno0aMj\nXzNOTDZnNmvbZRGRc889V2XWFqgbN25MeU3pwtskorO2ux8yZIjKfvOb35jnp/r3tptuusnMJ0yY\nkNJ1yiM+cxOrVKmSyjp16qSywYMHqyzZ7ead0/+3JHvfT5w4UWXDhg1T2ffff5/UOunC2yQAAACA\nYtAMAwAAIFg0wwAAAAgWzTAAAACCxQAdMg7DHJlt0KBBZv7II4+o7LLLLlOZtQVppgh9gC4nJ8fM\nr7/+epVZQ0jWVrHWEJFI9EGikSNHqmzx4sUqe/XVVyNdryLiMzd5hx2mnz2eddZZ5rHW0Px5552n\nsnPOOUdl+/btU9m0adPMdcaNG6cy694vKCgwz88EDNABAAAAxaAZBgAAQLBohgEAABAsmmEAAAAE\niwE6ZByGOTLb2WefbeYff/yxyt577z2VXXjhhakuKW1CGqDr27evyhLtcmntxmnZvn27yhYsWGAe\n+8UXX6hs5syZKlu6dKnKMnlgqCzwmYtMxQAdAAAAUAyaYQAAAASLZhgAAADBohkGAABAsBigQ8Zh\nmAOZKqQBOmvHrM6dO5vHzpkzJ9I1t2zZorI1a9aUrDCUGJ+5yFQM0AEAAADFoBkGAABAsGiGAQAA\nECyaYQAAAASLZhgAAADB4m0SyDhMNiNThfQ2CVQcfOYiU/E2CQAAAKAYNMMAAAAIFs0wAAAAgkUz\nDAAAgGDRDAMAACBYNMMAAAAIFs0wAAAAgkUzDAAAgGDRDAMAACBYad2BDgAAAChPeDIMAACAYNEM\nAwAAIFg0wwAAAAgWzTAAAACCRTMMAACAYNEMAwAAIFg0wwAAAAgWzTAAAACCRTMMAACAYNEMAwAA\nIFg0wwAAAAgWzTAAAACCRTMM
AACAYNEMAwAAIFg0wwAAAAgWzTAAAACCRTMMAACAYNEMAwAAIFg0\nwwAAAAgWzTAAAACCRTMMAACAYNEMAwAAIFg0wwAAAAjW/wWeKbMQcyi29AAAAABJRU5ErkJggg==\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"preds = predict(net2, md.val_dl).argmax(1)\n",
"plots(x_imgs[:8], titles=preds[:8])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.91910000000000003"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(preds == y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Aside about Broadcasting and Matrix Multiplication"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's dig in to what we were doing with `torch.matmul`: matrix multiplication. First, let's start with a simpler building block: **broadcasting**."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Element-wise operations "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Broadcasting and element-wise operations are supported in the same way by both numpy and pytorch.\n",
"\n",
"Operators (+,-,\\*,/,>,<,==) are usually element-wise.\n",
"\n",
"Examples of element-wise operations:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([10, 6, -4]), array([2, 8, 7]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = np.array([10, 6, -4])\n",
"b = np.array([2, 8, 7])\n",
"a,b"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([12, 14, 3])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a + b"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.66666666666666663"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(a < b).mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Broadcasting"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The term **broadcasting** describes how arrays with different shapes are treated during arithmetic operations. The term broadcasting was first used by Numpy, although is now used in other libraries such as [Tensorflow](https://www.tensorflow.org/performance/xla/broadcasting) and Matlab; the rules can vary by library.\n",
"\n",
"From the [Numpy Documentation](https://docs.scipy.org/doc/numpy-1.10.0/user/basics.broadcasting.html):\n",
"\n",
" The term broadcasting describes how numpy treats arrays with \n",
" different shapes during arithmetic operations. Subject to certain \n",
" constraints, the smaller array is “broadcast” across the larger \n",
" array so that they have compatible shapes. Broadcasting provides a \n",
" means of vectorizing array operations so that looping occurs in C\n",
" instead of Python. It does this without making needless copies of \n",
" data and usually leads to efficient algorithm implementations.\n",
" \n",
"In addition to the efficiency of broadcasting, it allows developers to write less code, which typically leads to fewer errors.\n",
"\n",
"*This section was adapted from [Chapter 4](http://nbviewer.jupyter.org/github/fastai/numerical-linear-algebra/blob/master/nbs/4.%20Compressed%20Sensing%20of%20CT%20Scans%20with%20Robust%20Regression.ipynb#4.-Compressed-Sensing-of-CT-Scans-with-Robust-Regression) of the fast.ai [Computational Linear Algebra](https://github.com/fastai/numerical-linear-algebra) course.*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Broadcasting with a scalar"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([10, 6, -4])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ True, True, False], dtype=bool)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a > 0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How are we able to do a > 0? 0 is being **broadcast** to have the same dimensions as a.\n",
"\n",
"Remember above when we normalized our dataset by subtracting the mean (a scalar) from the entire data set (a matrix) and dividing by the standard deviation (another scalar)? We were using broadcasting!\n",
"\n",
"Other examples of broadcasting with a scalar:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([11, 7, -3])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a + 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1, 2, 3],\n",
" [4, 5, 6],\n",
" [7, 8, 9]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = np.array([[1, 2, 3], [4,5,6], [7,8,9]]); m"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 2, 4, 6],\n",
" [ 8, 10, 12],\n",
" [14, 16, 18]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"2*m"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Broadcasting a vector to a matrix"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also broadcast a vector to a matrix:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([10, 20, 30])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c = np.array([10,20,30]); c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[11, 22, 33],\n",
" [14, 25, 36],\n",
" [17, 28, 39]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m + c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[11, 22, 33],\n",
" [14, 25, 36],\n",
" [17, 28, 39]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c + m"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Although numpy does this automatically, you can also use the `broadcast_to` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3,)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[10, 10, 10],\n",
" [20, 20, 20],\n",
" [30, 30, 30]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.broadcast_to(c[:,None], m.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[10, 20, 30],\n",
" [10, 20, 30],\n",
" [10, 20, 30]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.broadcast_to(np.expand_dims(c,0), (3,3))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3,)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 3)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.expand_dims(c,0).shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The numpy `expand_dims` method lets us convert the 1-dimensional array `c` into a 2-dimensional array (although one of those dimensions has value 1)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 3)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.expand_dims(c,0).shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[11, 22, 33],\n",
" [14, 25, 36],\n",
" [17, 28, 39]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m + np.expand_dims(c,0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[10],\n",
" [20],\n",
" [30]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.expand_dims(c,1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3, 1)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c[:, None].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[11, 12, 13],\n",
" [24, 25, 26],\n",
" [37, 38, 39]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m + np.expand_dims(c,1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[10, 10, 10],\n",
" [20, 20, 20],\n",
" [30, 30, 30]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.broadcast_to(np.expand_dims(c,1), (3,3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Broadcasting Rules"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[10, 20, 30]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c[None]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[10],\n",
" [20],\n",
" [30]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c[:,None]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[False, True, True],\n",
" [False, False, True],\n",
" [False, False, False]], dtype=bool)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c[None] > c[:,None]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[0],\n",
" [1],\n",
" [2],\n",
" [3],\n",
" [4]]), array([[0, 1, 2, 3, 4]]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xg,yg = np.ogrid[0:5, 0:5]; xg,yg"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0, 1, 2, 3, 4],\n",
" [1, 2, 3, 4, 5],\n",
" [2, 3, 4, 5, 6],\n",
" [3, 4, 5, 6, 7],\n",
" [4, 5, 6, 7, 8]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xg+yg"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When operating on two arrays, Numpy/PyTorch compares their shapes element-wise. It starts with the **trailing dimensions**, and works its way forward. Two dimensions are **compatible** when\n",
"\n",
"- they are equal, or\n",
"- one of them is 1\n",
"\n",
"Arrays do not need to have the same number of dimensions. For example, if you have a `256*256*3` array of RGB values, and you want to scale each color in the image by a different value, you can multiply the image by a one-dimensional array with 3 values. Lining up the sizes of the trailing axes of these arrays according to the broadcast rules, shows that they are compatible:\n",
"\n",
" Image (3d array): 256 x 256 x 3\n",
" Scale (1d array): 3\n",
" Result (3d array): 256 x 256 x 3\n",
"\n",
"The [numpy documentation](https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html#general-broadcasting-rules) includes several examples of what dimensions can and can not be broadcast together."
]
},
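{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the two compatibility rules above concrete, here is a small pure-Python sketch that works out the shape of a broadcast result. It is an illustration only, not how NumPy or PyTorch implement broadcasting internally, and the helper name `broadcast_shape` is hypothetical. Shapes are compared from the trailing dimension forward, and a missing leading dimension is treated as having length 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def broadcast_shape(shape_a, shape_b):\n",
"    # Hypothetical helper: apply the compatibility rules, trailing dimension first\n",
"    result = []\n",
"    for i in range(1, max(len(shape_a), len(shape_b)) + 1):\n",
"        a = shape_a[-i] if i <= len(shape_a) else 1  # missing leading dims act like 1\n",
"        b = shape_b[-i] if i <= len(shape_b) else 1\n",
"        if a != b and a != 1 and b != 1:\n",
"            raise ValueError(f'incompatible dimensions: {a} vs {b}')\n",
"        result.append(b if a == 1 else a)            # a length-1 dim stretches to match\n",
"    return tuple(reversed(result))\n",
"\n",
"broadcast_shape((256, 256, 3), (3,)), broadcast_shape((3, 1), (1, 3))"
]
},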
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Matrix Multiplication"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are going to use broadcasting to define matrix multiplication."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[1, 2, 3],\n",
" [4, 5, 6],\n",
" [7, 8, 9]]), array([10, 20, 30]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m, c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([140, 320, 500])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m @ c # np.matmul(m, c)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We get the same answer using `torch.matmul`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\n",
" 140\n",
" 320\n",
" 500\n",
"[torch.LongTensor of size 3]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"T(m) @ T(c)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following is **NOT** matrix multiplication. What is it?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[1, 2, 3],\n",
" [4, 5, 6],\n",
" [7, 8, 9]]), array([10, 20, 30]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m,c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 10, 40, 90],\n",
" [ 40, 100, 180],\n",
" [ 70, 160, 270]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m * c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([140, 320, 500])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(m * c).sum(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([10, 20, 30])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[10, 20, 30],\n",
" [10, 20, 30],\n",
" [10, 20, 30]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.broadcast_to(c, (3,3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From a machine learning perspective, matrix multiplication is a way of creating features by saying how much we want to weight each input column. **Different features are different weighted sums of the input columns**.\n",
"\n",
"The website [matrixmultiplication.xyz](http://matrixmultiplication.xyz/) provides a nice visualization of matrix multiplication."
]
},
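{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the weighting idea concrete, here it is sketched as an equation. The symbols $x$, $w$, and $y$ are notation introduced here for illustration (they do not appear in the code): $x$ is an input matrix with one row per example, $w$ holds one weight per input column, and $y$ is the resulting feature.\n",
"\n",
"$$y_i = \\sum_j x_{ij} w_j$$\n",
"\n",
"With `m` as $x$ and `c` as $w$ from the cells above, the first entry is $1 \\cdot 10 + 2 \\cdot 20 + 3 \\cdot 30 = 140$, which matches the first element of `(m * c).sum(axis=1)`. The weights do not need to sum to one, and each additional column of weights produces one more feature in the same way."
]
},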
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[10, 40],\n",
" [20, 0],\n",
" [30, -5]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n = np.array([[10,40],[20,0],[30,-5]]); n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1, 2, 3],\n",
" [4, 5, 6],\n",
" [7, 8, 9]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[140, 25],\n",
" [320, 130],\n",
" [500, 235]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m @ n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([140, 320, 500])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(m * n[:,0]).sum(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 25, 130, 235])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(m * n[:,1]).sum(axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Writing Our Own Training Loop"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder, this is what we did above to write our own logistic regression class (as a PyTorch neural net):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a4fe9bd617ca418889de7ab03703927b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"A Jupyter Widget"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0.31102 0.28004 0.92406] \n",
"\n"
]
}
],
"source": [
"# Our code from above\n",
"class LogReg(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.l1_w = get_weights(28*28, 10)  # Layer 1 weights\n",
"        self.l1_b = get_weights(10)         # Layer 1 bias\n",
"\n",
"    def forward(self, x):\n",
"        x = x.view(x.size(0), -1)\n",
"        x = x @ self.l1_w + self.l1_b\n",
"        return torch.log(softmax(x))\n",
"\n",
"net2 = LogReg().cuda()\n",
"opt=optim.Adam(net2.parameters())\n",
"\n",
"fit(net2, md, n_epochs=1, crit=loss, opt=opt, metrics=metrics)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Above, we are using the fastai method `fit` to train our model. Now we will try writing the training loop ourselves.\n",
"\n",
"**Review question:** What does it mean to train a model?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use the LogReg class we created, as well as the same loss function, learning rate, and optimizer as before:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"net2 = LogReg().cuda()\n",
"loss=nn.NLLLoss()\n",
"learning_rate = 1e-3\n",
"optimizer=optim.Adam(net2.parameters(), lr=learning_rate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`md` is the ImageClassifierData object we created above. We want an iterable version of our training data (**question**: what does it mean for something to be iterable?):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dl = iter(md.trn_dl) # Data loader"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we will do a **forward pass**, which means computing the predicted y by passing x to the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"xt, yt = next(dl)\n",
"y_pred = net2(Variable(xt).cuda())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can check the loss:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variable containing:\n",
" 2.3162\n",
"[torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
"\n"
]
}
],
"source": [
"l = loss(y_pred, Variable(yt).cuda())\n",
"print(l)"
]
},
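{
"cell_type": "markdown",
"metadata": {},
"source": [
"A rough sanity check on that number, sketched as an equation. This assumes `nn.NLLLoss` is averaging over the mini-batch (its default behavior), and uses notation introduced here: $p_{i,k}$ is the softmax probability the model assigns to class $k$ for example $i$, $y_i$ is the true label, and $N$ is the number of examples in the batch.\n",
"\n",
"$$L = -\\frac{1}{N} \\sum_{i=1}^{N} \\log p_{i, y_i}$$\n",
"\n",
"A model that has learned nothing and spreads its probability evenly over the 10 digits has $p_{i,k} = 1/10$, so $L = -\\log(1/10) = \\log 10 \\approx 2.30$. The loss printed above is close to this value, which is roughly what we would expect from randomly initialized weights."
]
},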
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We may also be interested in the accuracy. We don't expect our first predictions to be very good, because the weights of our network were initialized to random values. Our goal is to see the loss decrease (and the accuracy increase) as we train the network:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.078125"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(to_np(y_pred).argmax(axis=1) == to_np(yt))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we need to work out which direction to step in. That is, how should we update our weights to try to decrease the loss?\n",
"\n",
"PyTorch has an automatic differentiation package ([autograd](http://pytorch.org/docs/master/autograd.html)) that takes derivatives for us, so we don't have to calculate them ourselves! We just call `.backward()` on our loss to compute the gradient of the loss with respect to each weight. The negative gradient is the direction of steepest descent (the direction that lowers the loss the most), and the optimizer uses it to update the weights."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Before the backward pass, use the optimizer object to zero all of the\n",
"# gradients for the variables it will update (which are the learnable weights\n",
"# of the model)\n",
"optimizer.zero_grad()\n",
"\n",
"# Backward pass: compute gradient of the loss with respect to model parameters\n",
"l.backward()\n",
"\n",
"# Calling the step function on an Optimizer makes an update to its parameters\n",
"optimizer.step()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's make another set of predictions and check if our loss is lower:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"xt, yt = next(dl)\n",
"y_pred = net2(Variable(xt).cuda())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variable containing:\n",
" 2.2265\n",
"[torch.cuda.FloatTensor of size 1 (GPU 0)]\n",
"\n"
]
}
],
"source": [
"l = loss(y_pred, Variable(yt).cuda())\n",
"print(l)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that we are using **stochastic** gradient descent, so the loss is not guaranteed to decrease at every step. The stochasticity comes from the fact that we are using **mini-batches**: each prediction and weight update is based on just 64 images, not the whole dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.15625"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(to_np(y_pred).argmax(axis=1) == to_np(yt))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we run several iterations in a loop, we should see the loss decrease and the accuracy increase with time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 2.2104923725128174 \t accuracy: 0.234375\n",
"loss: 1.3094730377197266 \t accuracy: 0.625\n",
"loss: 1.0296542644500732 \t accuracy: 0.78125\n",
"loss: 0.8841525316238403 \t accuracy: 0.71875\n",
"loss: 0.6643403768539429 \t accuracy: 0.8125\n",
"loss: 0.5525785088539124 \t accuracy: 0.875\n",
"loss: 0.43296846747398376 \t accuracy: 0.890625\n",
"loss: 0.4388267695903778 \t accuracy: 0.90625\n",
"loss: 0.39874207973480225 \t accuracy: 0.890625\n",
"loss: 0.4848807752132416 \t accuracy: 0.875\n"
]
}
],
"source": [
"for t in range(100):\n",
"    xt, yt = next(dl)\n",
"    y_pred = net2(Variable(xt).cuda())\n",
"    l = loss(y_pred, Variable(yt).cuda())\n",
"\n",
"    if t % 10 == 0:\n",
"        accuracy = np.mean(to_np(y_pred).argmax(axis=1) == to_np(yt))\n",
"        print(\"loss: \", l.data[0], \"\\t accuracy: \", accuracy)\n",
"\n",
"    optimizer.zero_grad()\n",
"    l.backward()\n",
"    optimizer.step()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Put it all together in a training loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fraction of predictions in one batch that match the labels\n",
"def score(x, y):\n",
"    y_pred = to_np(net2(V(x)))\n",
"    return np.sum(y_pred.argmax(axis=1) == to_np(y))/len(y_pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.908738057325\n"
]
}
],
"source": [
"net2 = LogReg().cuda()\n",
"loss=nn.NLLLoss()\n",
"learning_rate = 1e-2\n",
"optimizer=optim.SGD(net2.parameters(), lr=learning_rate)\n",
"\n",
"for epoch in range(1):\n",
"    losses=[]\n",
"    dl = iter(md.trn_dl)\n",
"    for t in range(len(dl)):\n",
"        # Forward pass: compute predicted y and loss by passing x to the model.\n",
"        xt, yt = next(dl)\n",
"        y_pred = net2(V(xt))\n",
"        l = loss(y_pred, V(yt))\n",
"        losses.append(l)\n",
"\n",
"        # Before the backward pass, use the optimizer object to zero all of the\n",
"        # gradients for the variables it will update (which are the learnable weights of the model)\n",
"        optimizer.zero_grad()\n",
"\n",
"        # Backward pass: compute gradient of the loss with respect to model parameters\n",
"        l.backward()\n",
"\n",
"        # Calling the step function on an Optimizer makes an update to its parameters\n",
"        optimizer.step()\n",
"\n",
"    val_dl = iter(md.val_dl)\n",
"    val_scores = [score(*next(val_dl)) for i in range(len(val_dl))]\n",
"    print(np.mean(val_scores))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Stochastic Gradient Descent"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Nearly all of deep learning is powered by one very important algorithm: **stochastic gradient descent (SGD)**. SGD can be seen as an approximation of **gradient descent (GD)**. In GD you have to run through all the samples in your training set to do a single iteration. In SGD you use only a subset of the training samples to update the parameters in a particular iteration. The subset used in each iteration is called a batch or mini-batch.\n",
"\n",
"Now, instead of using the optimizer, we will do the optimization ourselves!"
]
},
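{
"cell_type": "markdown",
"metadata": {},
"source": [
"The update that the loop below performs by hand can be sketched as an equation. The notation is introduced here for illustration: $\\eta$ is the learning rate (`lr`), $B$ is the current mini-batch, $\\ell_i$ is the loss on example $i$, and $w$ and $b$ are the weights and bias of the model.\n",
"\n",
"$$w \\leftarrow w - \\eta \\nabla_w \\Big( \\frac{1}{\\lvert B \\rvert} \\sum_{i \\in B} \\ell_i(w, b) \\Big), \\qquad b \\leftarrow b - \\eta \\nabla_b \\Big( \\frac{1}{\\lvert B \\rvert} \\sum_{i \\in B} \\ell_i(w, b) \\Big)$$\n",
"\n",
"Plain gradient descent has the same form, but with the average taken over the whole training set instead of $B$. Because PyTorch accumulates gradients into `.grad` on each call to `.backward()`, the gradients have to be reset to zero after every update."
]
},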
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.908837579618\n"
]
}
],
"source": [
"net2 = LogReg().cuda()\n",
"loss_fn=nn.NLLLoss()\n",
"lr = 1e-2\n",
"w,b = net2.l1_w,net2.l1_b\n",
"\n",
"for epoch in range(1):\n",
"    losses=[]\n",
"    dl = iter(md.trn_dl)\n",
"    for t in range(len(dl)):\n",
"        xt, yt = next(dl)\n",
"        y_pred = net2(V(xt))\n",
"        l = loss_fn(y_pred, Variable(yt).cuda())\n",
"        losses.append(l)\n",
"\n",
"        # Backward pass: compute gradient of the loss with respect to model parameters\n",
"        l.backward()\n",
"        # Step each parameter against its gradient, scaled by the learning rate\n",
"        w.data -= w.grad.data * lr\n",
"        b.data -= b.grad.data * lr\n",
"\n",
"        # Reset the gradients so they do not accumulate into the next step\n",
"        w.grad.data.zero_()\n",
"        b.grad.data.zero_()\n",
"\n",
"    val_dl = iter(md.val_dl)\n",
"    val_scores = [score(*next(val_dl)) for i in range(len(val_dl))]\n",
"    print(np.mean(val_scores))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}