{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# All the Linear Algebra You Need for AI" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook explains two crucial linear algebra operations used when coding neural networks: matrix multiplication and broadcasting." ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "**Matrix multiplication** is a way of combining two matrices by multiplying and summing their entries in a particular way. **Broadcasting** refers to how libraries such as NumPy and PyTorch can perform operations on matrices/vectors with mismatched dimensions (in specific cases, following fixed rules). We will use broadcasting to show an alternative way of thinking about matrix multiplication, different from the way it is usually taught." ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "In keeping with the [fast.ai teaching philosophy](http://www.fast.ai/2016/10/08/teaching-philosophy/) of [\"the whole game\"](https://www.amazon.com/Making-Learning-Whole-Principles-Transform/dp/0470633719/ref=sr_1_1?ie=UTF8&qid=1505094653), we will:\n", "\n", "- first use a pre-defined class for our neural network\n", "- then define the net ourselves to see where it uses matrix multiplication & broadcasting\n", "- and finally dig into the details of how those operations work\n", "\n", "This is different from how most math courses are taught, where you have to learn all the individual elements before you can combine them (Harvard professor David Perkins calls this *elementitis*), but it is similar to how topics like *driving* and *baseball* are taught. 
That is, you can start driving without [knowing how an internal combustion engine works](https://medium.com/towards-data-science/thoughts-after-taking-the-deeplearning-ai-courses-8568f132153), and children begin playing baseball before they learn all the formal rules.\n", "\n", "
\n", "(source: [Demba Ba](https://github.com/zalandoresearch/fashion-mnist) and [Arvind Nagaraj](https://medium.com/towards-data-science/thoughts-after-taking-the-deeplearning-ai-courses-8568f132153))\n", "
" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "### More linear algebra resources" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "This notebook was originally created for a 40 minute talk I gave at the [O'Reilly AI conference in San Francisco](https://conferences.oreilly.com/artificial-intelligence/ai-ca). If you want further resources for linear algebra, here are a few recommendations:\n", "\n", "- [3Blue1Brown Essence of Linear Algebra](https://www.youtube.com/playlist?list=PLZHQObOWTQDPD3MizzM2xVFitgF8hE_ab) videos about *geometric intuition*, which are gorgeous and great for visual learners\n", "- [Khan Academy Linear Algebra](https://www.khanacademy.org/math/linear-algebra) videos covering traditional linear algebra material\n", "- [Immersive linear algebra](http://immersivemath.com/ila/) free online textbook with interactive graphics\n", "- [Chapter 2](http://www.deeplearningbook.org/contents/linear_algebra.html) of Ian Goodfellow's Deep Learning Book for a fairly academic take\n", "- [Computational Linear Algebra](http://www.fast.ai/2017/07/17/num-lin-alg/): a free, online fast.ai course, originally taught in the University of San Francisco's Masters in Analytics program. It includes a free [online textbook](https://github.com/fastai/numerical-linear-algebra/blob/master/README.md) and [series of videos](https://www.youtube.com/playlist?list=PLtmWHNX-gukIc92m1K0P6bIOnZb-mg0hY). This course is very different from standard linear algebra (which often focuses on how **humans** do matrix calculations), because it is about how to get **computers** to do matrix computations with speed and accuracy, and incorporates modern tools and algorithms. All the material is taught in Python and centered around solving practical problems such as removing the background from a surveillance video or implementing Google's PageRank search algorithm on Wikipedia pages." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Our Tools" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will be using the open-source [deep learning library, fastai](https://github.com/fastai/fastai), which provides high-level abstractions and best practices on top of PyTorch. This is the highest-level, simplest way to get started with deep learning. Please note that fastai requires Python 3 to function. It is currently in pre-alpha, so items may move around and more documentation will be added in the future." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "from fastai.imports import *\n", "from fastai.torch_imports import *\n", "from fastai.io import *" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### PyTorch" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "The fastai deep learning library uses [PyTorch](http://pytorch.org/), a Python framework for dynamic neural networks with GPU acceleration, which was released by Facebook's AI team.\n", "\n", "PyTorch has two overlapping, yet distinct, purposes. As described in the [PyTorch documentation](http://pytorch.org/tutorials/beginner/blitz/tensor_tutorial.html):\n", "\n", "(Image: pytorch)\n", "\n", "The neural network functionality of PyTorch is built on top of the NumPy-like functionality for fast matrix computations on a GPU. Although the neural network purpose receives far more attention, both are very useful. 
We'll implement a neural net from scratch today using PyTorch.\n", "\n", "**Further learning**: If you are curious to learn what *dynamic* neural networks are, you may want to watch [this talk](https://www.youtube.com/watch?v=Z15cBAuY7Sc) by Soumith Chintala, Facebook AI researcher and core PyTorch contributor.\n", "\n", "If you want to learn more PyTorch, you can try this [introductory tutorial](http://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html) or this [tutorial to learn by examples](http://pytorch.org/tutorials/beginner/pytorch_with_examples.html)." ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### About GPUs" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Graphics processing units (GPUs) allow matrix computations to be done much faster, as long as you have a library such as PyTorch that takes advantage of them. Advances in GPU technology in the last 10-20 years have been a key part of why neural networks are proving so much more powerful now than they were a few decades ago.\n", "\n", "You may own a computer that has a GPU which can be used. For the many people who either don't have a GPU or have one that can't easily be accessed from Python, there are a few different options:\n", "\n", "- **Don't use a GPU**: For the sake of this tutorial, you don't have to use a GPU, although some computations will be slower. The only change needed to the code is to remove `.cuda()` wherever it appears.\n", "- **Use Crestle, through your browser**: [Crestle](https://www.crestle.com/) is a service that gives you a ready-made cloud environment with all the popular scientific and deep learning frameworks pre-installed and configured to run on a GPU in the cloud. It is easily accessed through your browser. New users get 10 hours and 1 GB of storage for free. After this, GPU usage is 34 cents per hour. 
I recommend this option to those who are new to AWS or new to using the console.\n", "- **Set up an AWS instance through your console**: You can create an AWS instance with a GPU by following the steps in this [fast.ai setup lesson](http://course.fast.ai/lessons/aws.html). AWS charges 90 cents per hour for this." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### About The Data" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Today we will be working with MNIST, a classic data set of hand-written digits. Solutions to this problem are used by banks to automatically recognize the amounts on checks, and by the postal service to automatically recognize zip codes on mail." ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "A matrix can represent an image: each entry of the grid corresponds to a different pixel.\n", "\n", "(Image: digit)\n", "\n", "(Source: [Adam Geitgey](https://medium.com/@ageitgey/machine-learning-is-fun-part-3-deep-learning-and-convolutional-neural-networks-f40359318721))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's download, unzip, and format the data." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "path = '../data/'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import os\n", "os.makedirs(path, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "URL='http://deeplearning.net/data/mnist/'\n", "FILENAME='mnist.pkl.gz'\n", "\n", "def load_mnist(filename):\n", " return pickle.load(gzip.open(filename, 'rb'), encoding='latin-1')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "get_data(URL+FILENAME, path+FILENAME)\n", "((x, y), (x_valid, y_valid), _) = load_mnist(path+FILENAME)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Normalize" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Many machine learning algorithms behave better when the data is *normalized*, that is, when the mean is 0 and the standard deviation is 1. We will subtract the mean of our training set and then divide by its standard deviation in order to normalize the data:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "mean = x.mean()\n", "std = x.std()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(-3.1638146e-07, 0.99999934)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x=(x-mean)/std\n", "x.mean(), x.std()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that for consistency (with the parameters we learn when training), we normalize the validation set using the mean and standard deviation of the training set, not its own statistics. 
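\n", "\n", "As a minimal sketch of this idea (using small random arrays as stand-ins for the real data; the helper name `normalize` and the toy arrays are illustrative assumptions, not part of fastai), both sets are shifted and scaled with statistics computed from the training set only:\n", "\n", "```python\n", "import numpy as np\n", "\n", "def normalize(data, mean, std):\n", "    # Shift by the mean, then scale by the standard deviation.\n", "    return (data - mean) / std\n", "\n", "# Toy stand-ins for the training and validation sets (random, not MNIST).\n", "rng = np.random.RandomState(0)\n", "train = rng.uniform(0.0, 1.0, size=(1000, 4)).astype(np.float32)\n", "valid = rng.uniform(0.0, 1.0, size=(200, 4)).astype(np.float32)\n", "\n", "# The statistics come from the training set only.\n", "train_mean, train_std = train.mean(), train.std()\n", "\n", "train_n = normalize(train, train_mean, train_std)\n", "valid_n = normalize(valid, train_mean, train_std)\n", "\n", "# Print (mean, std) for each normalized set.\n", "print(train_n.mean(), train_n.std())\n", "print(valid_n.mean(), valid_n.std())\n", "```\n", "\n", "With this approach the training set ends up with a mean close to 0 and a standard deviation close to 1, while the validation set is only approximately normalized, which is the same pattern as in the outputs of the cells around this note.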
" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(-0.0058509219, 0.99243325)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_valid = (x_valid-mean)/std\n", "x_valid.mean(), x_valid.std()" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Look at the data" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "In any sort of data science work, it's important to look at your data, to make sure you understand the format, how it's stored, what type of values it holds, etc. To make it easier to work with, let's reshape it into 2d images from the flattened 1d format." ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "#### Helper methods" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "%matplotlib inline\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "def show(img, title=None):\n", " plt.imshow(img, interpolation='none', cmap=\"gray\")\n", " if title is not None: plt.title(title)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "def plots(ims, figsize=(12,6), rows=2, titles=None):\n", " f = plt.figure(figsize=figsize)\n", " cols = len(ims)//rows\n", " for i in range(len(ims)):\n", " sp = f.add_subplot(rows, cols, i+1)\n", " sp.axis('Off')\n", " if titles is not None: sp.set_title(titles[i], fontsize=16)\n", " plt.imshow(ims[i], interpolation='none', cmap='gray')" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "#### Plots " ] }, { "cell_type": "code", "execution_count": 97, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(10000, 784)" ] }, "execution_count": 97, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"x_valid.shape" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(10000, 28, 28)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_imgs = np.reshape(x_valid, (-1,28,28)); x_imgs.shape" ] }, { "cell_type": "code", "execution_count": 99, "metadata": { "hidden": true, "scrolled": true }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPoAAAEHCAYAAACHl1tOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEDpJREFUeJzt3X2MVfWdx/H3+BSefBjRLcKiRJh8t80diWX/KLu4RdRq\njCtRqNtUDTuSmE3UGB/IInUTxaRqjcGIrqZRi5V0Y4GkasWHiqYKTVhjcDPT1K/oNmJ4EAGrUpEF\ndvaPuczee2fOuXfuPffeM/1+Xskk93d+95zz5cx8OE/33F9Hf38/IvKX7Zh2FyAizaegiwSgoIsE\noKCLBKCgiwSgoIsEcFy7C5B8MbMFwL8BY4A9wL+4e197q5JGaY8ug8zsTOBxYL67/w2wBniqvVVJ\nFhR0KXUI+KG7f1RsbwCsjfVIRnToLoPcfSewE8DMjgP+GXiunTVJNrRHlyHM7GbgE+A84F/bXI5k\noEOfdZfhmFkH8APgx8C33P1Am0uSBmiPLoPM7JtmdiGAu/e7+38AJ6Hz9FFPQZdSpwM/N7PJAGb2\n98DxwH+3tSppmA7dpYyZ3QDcwMBO4CBwh7uvb29V0igFXSQAHbqLBKCgiwSgoIsEoKCLRNDf39/0\nH6C/9Ke3t7e/clpeflSbahutdaVlsO6r7ma2AvhOcSU3u/vbSe/t6OgoW0l/fz8dHR11rbfZVFt9\nVNvIZV1Xf39/4sLqOnQ3s+8CXe4+G1gMPFxnbSLSAvWeo18A/ArA3f8AdJrZSZlVJSKZqvcx1UnA\nOyXtT4vTvhjuzb29vRQKhbJpef6gjmqrj2obuVbVldXz6KknGt3d3WXtvJ4zgWqrl2obuSacoyf2\n1XvovoOBPfhRkyl+YYGI5E+9QX8VWAhgZt8Gdrj7l5lVJSKZqivo7v474B0z+x0DV9xvyLQqEclU\nS55e0330bKi2+uS1ttzfRxeR0UVBFwlAQRcJQEEXCUBBFwlAQRcJQEEXCUBBFwlAQRcJQEEXCUBB\nFwlAQRcJQEEXCUBBFwlAQRcJQEEXCUBBFwlAQRcJQEEXCUBBFwlAQRcJIKuRWqQJZs6cmTrtlltu\nSZx3+vTpqcseN25cav+yZctS+08++eQh06666qrB1y+99FLivF9+qSEAWk17dJEAFHSRABR0kQAU\ndJEAFHSRABR0kQAUdJEANJpqhVbWNmHChNT+bdu2lbU7Ozv57LPPBtunnHJKU+qqR0dHB6V/S9u3\nb098b9r9f4C1a9dmVhfk9++tlaOp1vWBGTObC6wBfl+c1OvuN9WzLBFpvkY+Gfdbd1+YWSUi0jQ6\nRxcJoK5z9OKh+78DHwCnAne7+2+S3t/X19dfKBTqrVFEapN4jl5v0KcAc4BfAmcDbwAz3P1/hl2J\nLsYNSxfjhqeLcXUvL9uLce6+HXi22PzQzHYBU4A/1rM8EWmuus7RzexqM7u9+HoS8A0g+b9wEWmr\neg/dTwR+AZwCnMDAOfr6xJXo0H1YJ554Ymr/+vXlm3TOnDls3LhxsL13797Eebds2ZK67H
PPPTe1\n/6yzzkrtnzp1all74sSJZfWMHTs2cd5PPvkkddmzZ89O7a82f6W8/r2NhkP3L4F/rLsiEWkp3V4T\nCUBBFwlAQRcJQEEXCUBBFwlAj6lWUG21Oe2008ran376Kaeffvpge8mSJYnzpvUB9PT0pPY//fTT\nNVT4//K03Uq18vaa9ugiASjoIgEo6CIBKOgiASjoIgEo6CIBKOgiAWjYZKnLnj17Uqdt2rQpcd5q\n99GrPUI70vvooj26SAgKukgACrpIAAq6SAAKukgACrpIAAq6SAC6jy516ezsTJ22bNmyupc9efLk\nuueV4WmPLhKAgi4SgIIuEoCCLhKAgi4SgIIuEoCCLhKA7qPLsGbOnJnav2bNmiHTNm/ePPh6xowZ\nifO+//77qcu+7bbbqlQnI1VT0M2sADwHrHD3R8xsKvAMcCywE7jW3Q82r0wRaUTVQ3czGw+sBDaU\nTF4OPOru5wEfANc1pzwRyUIt5+gHgUuBHSXT5gLPF1+/AFyYbVkikqWqh+7ufhg4bGalk8eXHKrv\nBs5IW0Zvby+FQqFsWivGfKuXaqtPV1dXTe+r+FsaYtu2bVmUUyav261VdWVxMa7qKHHd3d1l7bwO\negeq7aiRXozr6upi69atg+1GLsZddNFFqf0ff/xxan+lvP5OmzDIYmJfvbfX9pvZ2OLrKZQf1otI\nztQb9NeABcXXC4CXsylHRJqh6vjoZjYLeBCYBhwCtgNXA6uAMcBHQI+7H0pcicZHz0SWtS1atCi1\nf/ny5an9U6dOLWt3dHSUHToeOHAgcd7LLrssddlvvPFGav9I5fV32srx0Wu5GPcOA1fZK6WfSIlI\nbugjsCIBKOgiASjoIgEo6CIBKOgiAegx1VFswoQJiX2333576rx33nlnav8xx6TvA/bt21fWnjhx\nYtm0OXPmJM773nvvpS5bsqc9ukgACrpIAAq6SAAKukgACrpIAAq6SAAKukgAuo8+iq1atSqx78or\nr2xo2WvXrk3tf+ihh8ramzZt4vLLLx9s6155vmiPLhKAgi4SgIIuEoCCLhKAgi4SgIIuEoCCLhJA\n1a97zmQl+rrnTFTWtmXLlsT3VhtppZoLLrggtb/yK5lH03bLi1Z+3bP26CIBKOgiASjoIgEo6CIB\nKOgiASjoIgEo6CIB6Hn0UezVV19N7Gv0PnrasgEee+yxIdMefvjhwdf33Xdf4rw7duyovzCpS01B\nN7MC8Bywwt0fMbNVwCxgb/EtD7j7i80pUUQaVTXoZjYeWAlsqOi6w91/3ZSqRCRTtZyjHwQuBXS8\nJTJK1fxZdzO7C9hTcug+CTgB2A3c6O57kubt6+vrLxQKjVcrImkSP+te78W4Z4C97v6umS0F7gJu\nTHpzd3d3WTuvDxnA6Krt/vvvT3zvkiVLGlrXkSNHUvsrL8bddNNNrFy5crCdp4txef2dNuGhlsS+\nuoLu7qXn688DQy/Bikhu1HUf3czWmdnZxeZcoC+zikQkc1XP0c1sFvAgMA04BGxn4Cr8UuArYD/Q\n4+67E1ei59EzUVnb2LFjE9+7evXq1GXNmjUrtf/MM88cUW0dHR1lh467du1KfG9PT0/qsl555ZUR\nrbuavP5OW/k8etVDd3d/h4G9dqV1DdQkIi2kj8CKBKCgiwSgoIsEoKCLBKCgiwSgr3uu8JdS25gx\nY1L7jzsu/YbLF198UXNdMPT2Wpqvv/46tf/WW29N7X/88cdrrgvy+zvV1z2LSKYUdJEAFHSRABR0\nkQAUdJEAFHSRABR0kQB0H72CahtwzjnnpPavWLGirD1v3jxef/31wfb5559f97q3bduW2j9t2rQR\nLS+vv1PdRxeRTCnoIgEo6CIBKOgiASjoIgEo6CIBKOgiAeg+eoVW1jZu3LjU/q+++qqsnaft1tnZ\nWdbet28fp5566mD7qaeeSpx3/vz5Da17ypQpqf07d+
4sa+dpu5XSfXQRyZSCLhKAgi4SgIIuEoCC\nLhKAgi4SgIIuEkDV0VQBzOwnwHnF998LvA08AxwL7ASudfeDzSpytJo+fXpq/8aNG1P7X3zxxSHT\nnnjiicHXfX3Jw9JX3kuutHjx4tT+448/PrV/uHvZmzdvHnw9Y8aM1PnTfPjhh6n91f5tMlTVPbqZ\nnQ8U3H02cAnwELAceNTdzwM+AK5rapUi0pBaDt3fBL5ffP0nYDwD46U/X5z2AnBh5pWJSGaqHrq7\n+xHgz8XmYmA9cHHJofpu4IzmlCciWaj5s+5mNh9YBnwP2Oruf1WcPgP4ubv/XdK8fX19/YVCIYNy\nRSRF4mfda70YdzHwI+ASd//czPab2Vh3PwBMAXakzd/d3V3WzutDBpBtbVlfjFu8eDFPPvnkYDtP\nF+O6urrYunXrYLuZF+O6urpGtLy8/r014aGWxL5aLsadDDwAXObu+4qTXwMWFF8vAF5usEYRaaKq\nh+5mdj1wF/B+yeRFwBPAGOAjoMfdDyWuJOhjqkuXLk3tv/fee6vWUmokQxM3qto2aKS2/fv3p/Zf\nccUVqf0bNmyoaT1H5fXvrZWPqdZyMe6nwE+H6bqokaJEpHX0yTiRABR0kQAUdJEAFHSRABR0kQAU\ndJEAavpknNRn4sSJ7S6hadatW1fWXrhwYdm0e+65J3He3bt3py57165djRUnQ2iPLhKAgi4SgIIu\nEoCCLhKAgi4SgIIuEoCCLhKAhk2ukGVt1b6lZd68ean911xzzZD26tWrB9uTJ09OnPfzzz+vocJk\nK1euTO1/6623ytqHDh0q+/cePny4ofVnKa9/bxo2WUQypaCLBKCgiwSgoIsEoKCLBKCgiwSgoIsE\noPvoFVRbfVTbyOk+uohkSkEXCUBBFwlAQRcJQEEXCUBBFwlAQRcJoKbvdTeznwDnFd9/L3A5MAvY\nW3zLA+7+YlMqFJGGVQ26mZ0PFNx9tplNBLYArwN3uPuvm12giDSulj36m8B/Fl//CRgPHNu0ikQk\ncyP6CKyZXc/AIfwRYBJwArAbuNHd9yTN19fX118oFBosVUSqSPwIbM1BN7P5wDLge8DfAnvd/V0z\nWwr8tbvfmLgSfdY9E6qtPnmtrZWfda/1YtzFwI+AS9z9c2BDSffzwGMNVSgiTVX19pqZnQw8AFzm\n7vuK09aZ2dnFt8wF+ppWoYg0rJY9+j8BpwG/NLOj034GPGtmXwH7gZ7mlCciWdDz6BVUW31U28jp\neXQRyZSCLhKAgi4SgIIuEoCCLhKAgi4SgIIuEoCCLhKAgi4SgIIuEoCCLhKAgi4SgIIuEoCCLhJA\nSx5TFZH20h5dJAAFXSQABV0kAAVdJAAFXSQABV0kAAVdJICaRmrJkpmtAL4D9AM3u/vbra5hOGY2\nF1gD/L44qdfdb2pfRWBmBeA5YIW7P2JmU4FnGBjkcidwrbsfzEltq8jJUNrDDPP9NjnYbu0cfryl\nQTez7wJdxSGYvwk8BcxuZQ1V/NbdF7a7CAAzGw+spHz4q+XAo+6+xsx+DFxHG4bDSqgNcjCUdsIw\n3xto83Zr9/DjrT50vwD4FYC7/wHoNLOTWlzDaHEQuBTYUTJtLgNj3QG8AFzY4pqOGq62vHgT+H7x\n9dFhvufS/u02XF0tG3681Yfuk4B3StqfFqd90eI6knzLzJ4HTgXudvfftKsQdz8MHC4ZBgtgfMkh\n527gjJYXRmJtADea2a3UMJR2E2s7Avy52FwMrAcubvd2S6jrCC3aZu2+GJencXK2AncD84FFwJNm\ndkJ7S0qVp20HA+fAS919HvAucFc7iykO870YqBzOu63braKulm2zVu/RdzCwBz9qMgMXR9rO3bcD\nzxabH5rZLmAK8Mf2VTXEfjMb6+4HGKgtN4fO7p6bobQrh/k2s1xst3YOP97qPfqrwEIAM/s2sMPd\nv2xxDcMys6vN7P
bi60nAN4Dt7a1qiNeABcXXC4CX21hLmbwMpT3cMN/kYLu1e/jxlj+mamb3Af8A\n/C9wg7v/V0sLSGBmJwK/AE4BTmDgHH19G+uZBTwITAMOMfCfztXAKmAM8BHQ4+6HclLbSmApMDiU\ntrvvbkNt1zNwCPx+yeRFwBO0cbsl1PUzBg7hm77N9Dy6SADtvhgnIi2goIsEoKCLBKCgiwSgoIsE\noKCLBKCgiwTwfzrYKoW62D0ZAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show(x_imgs[0], y_valid[0])" ] }, { "cell_type": "code", "execution_count": 117, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(10000,)" ] }, "execution_count": 117, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_valid.shape" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "It's the digit 3! And that's stored in the y value:" ] }, { "cell_type": "code", "execution_count": 100, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "3" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_valid[0]" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "We can look at part of an image:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[-0.4245, -0.4245, -0.4245, -0.4245, 0.1729],\n", " [-0.4245, -0.4245, -0.4245, 0.7831, 2.4357],\n", " [-0.4245, -0.272 , 1.2026, 2.7789, 2.8043],\n", " [-0.4245, 1.7619, 2.8043, 2.8043, 1.7365],\n", " [-0.4245, 2.2069, 2.8043, 2.8043, 0.4018]], dtype=float32)" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_imgs[0,10:15,10:15]" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "hidden": true }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPQAAAD4CAYAAADb7cuFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACg9JREFUeJzt3V2InOUZgOF7ogf+REVoq4hLpSKPyCwRPdGCWqO01p8K\nYaxQqBEsPVGwUISCqag5EBR/0CK2UFlKT5RVLGoPpM2BQlqwgZVdKM+BWKpEaEXUUIpoMz3YXbB2\nM/Nl9vt2dh7u62hnMr7zbNib95vZ+E5vOBwiqYYd0x5AUnsMWirEoKVCDFoqxKClQk5se8Fer9fJ\n2+bLy8vMz893sXTrZmlWmK15u5p1165dra8JsLi4yGAwaH3dpaWl3kb399r+tVVXQQ+HQ3q9Db+H\nbWeWZoXZmrerWbsKemlpiYsvvriLdTf8S/CSWyrEoKVCDFoqxKClQgxaKsSgpUIMWirEoKVCDFoq\nxKClQgxaKsSgpUIMWirEoKVCDFoqxKClQgxaKqTREUQR8ThwGTAE7s7MNzudStJExu7QEXEVcEFm\nXg7cATzZ+VSSJtLkkvsa4CWAzPwrcGZEnN7pVJIm0uSS+2zg0Bdu/3Ptvk82evDy8jL9fr+F0f7f\nLH0O1yzNCrM17yzNCqsHBbZp1KGDkxzjO/LIxa6Og/Vkyu7M0rye+jlak0vuw6zuyOvOAd7vZhxJ\nm9Ek6NeAAUBEXAIczswjnU4laSJjg87Mg8ChiDjI6jvcd3Y+laSJNHoNnZk/63oQSZvnvxSTCjFo\nqRCDlgoxaKkQg5YKMWipEIOWCjFoqRCDlgoxaKkQg5YKMWipEIOWCjFoqRCDlgoxaKmQSQ4JlBrZ\nu3fvzKy7f//+1tdc9/LLL3e29pe5Q0uFGLRUiEFLhRi0VIhBS4UYtFSIQUuFGLRUiEFLhRi0VIhB\nS4UYtFSIQUuFGLRUiEFLhRi0VIhBS4U0Cjoi+hHxdkTc1fVAkiY3NuiIOBV4Cvhj9+NI2owmO/Sn\nwPXA4Y5nkbRJYw8JzMzPgc8jYgvGkbQZrZ/6uby8TL/fb3tZAIbDYSfrdmGWZoXZmndhYWHaIxyX\nubm5Vtd79913j/lnrQc9Pz/f9pLA6g9cr9frZO22zdKs0N28XRy3u7CwwO233976ul0d4zs3Nzcy\nwLb5ayupkLE7dERcCjwKnAd8FhEDYE9mftjxbJKOU5M3xQ4B3+p+FEmb5SW3VIhBS4UYtFSIQUuF\nGLRUiEFLhRi0VIhBS4UYtFSIQUuFGLRUiEFLhRi0VIhBS4UYtFSIQUuFtH6m2CzZuXPnTK19zz33\ntL7mugceeKD1Nfft29f6mgDPPvts62vu2NHd3tb2IYGjuENLhRi0VIhBS4UYtFSIQUuFGLRUiEFL\nhRi0VIhBS4UYtFSIQUuFGLRUiEFLhRi0VIhBS4UYtFSIQUuFGLRUSKMjiCLiYeCKtcc/lJkvdjqV\npImM3aEj4mqgn5mXA9cBT3Q+laSJNLnkfh24Ze3rj4BTI+KE7kaSNKnecDhs/OCI+DFwRWb+8FiP\nWVlZGfb7/TZmk3RsvY3ubHyMb0TcDNwBfHvU4+bn549vrIaGwyG93obfw8S6Osb3yJEjnHbaaa2v\n29Uxvvfddx8PPvhg6+t2cYzvjh07OHr0aCfrVtD0TbHvAPcC12Xmx92OJGlSY4OOiDOAR4BrM/PD\n7keSNKkmO/StwFeA5yNi/b7bMvPvnU0laSJjg87MXwG/2oJZJG1SjXcCJAEGLZVi0FIhBi0VYtBS\nIQYtFWLQUiEGLRVi0FIhBi0VYtBSIQYtFWLQUiEGLRVi0FIhBi0VclynfjZasNdrd8E1XRwS+MIL\nL7S63ro9e/bw4ovtfxbBnj17Wl9TqxYXFztZdzAYdLL2YDDYMAZ3aKkQg5YKMWipEIOWCjFoqRCD\nlgoxaKkQg5YKMWipEIOWCjFoqRCDlgoxaKkQg5YKMWipEIOWC
jFoqZATxz0gIk4BFoCzgJOA/Zn5\nSsdzSZpAkx36JuAvmXkV8H3gsW5HkjSpsTt0Zj73hZtzwHvdjSNpM8YGvS4iDgLnAjd2N46kzTiu\nUz8j4mLgN8CuzNzwP1xZWRn2+/2WxpP0ZYuLi8c89bPJm2KXAv/IzHczcykiTgS+Cvxjo8fPz89v\nathj8Rhfj/Ht0qwd43ssTd4UuxL4KUBEnAXsBD7ocihJk2kS9DPA1yLiDeBV4M7MPNrtWJIm0eRd\n7n8DP9iCWSRtkv9STCrEoKVCDFoqxKClQgxaKsSgpUIMWirEoKVCDFoqxKClQgxaKsSgpUIMWirE\noKVCDFoqxKClQhqf+lnR+eefP5Nrq31PP/10J+sOBoNO1h4MBhve7w4tFWLQUiEGLRVi0FIhBi0V\nYtBSIQYtFWLQUiEGLRVi0FIhBi0VYtBSIQYtFWLQUiEGLRVi0FIhBi0VYtBSIY2CjoiTI+LtiLi9\n43kkbULTHXof8GGXg0javLFBR8SFwEXAq92PI2kzesPhcOQDIuJV4C5gL/C3zFwY9fiVlZVhv99v\nbUBJ/2v37t0cOHCgt9GfjTzGNyJuA/6Ume9ERKMnm5+fP/4JGxgOh/R6G34PE1taWmp1vXW7du3i\nrbfe6mRddWP37t2drHvgwIHO1t7IuHO5bwC+ERE3AucCn0bEe5n5h+5Hk3S8Rgadmbeufx0R97N6\nyW3M0jbl76GlQhp/FE5m3t/hHJJa4A4tFWLQUiEGLRVi0FIhBi0VYtBSIQYtFWLQUiEGLRVi0FIh\nBi0VYtBSIQYtFWLQUiEGLRVi0FIhY0/9lDQ73KGlQgxaKsSgpUIMWirEoKVCDFoqxKClQhoftD8t\nEfE4cBkwBO7OzDenPNJIEdEHfgc8npm/mPY8o0TEw8AVrP4cPJSZL055pGOKiFOABeAs4CRgf2a+\nMtWhxoiIk4EVVmdd2Irn3NY7dERcBVyQmZcDdwBPTnmkkSLiVOAp4I/TnmWciLga6K/93V4HPDHl\nkca5CfhLZl4FfB94bMrzNLEP+HArn3BbBw1cA7wEkJl/Bc6MiNOnO9JInwLXA4enPUgDrwO3rH39\nEXBqRJwwxXlGysznMvPhtZtzwHvTnGeciLgQuAh4dSufd7tfcp8NHPrC7X+u3ffJdMYZLTM/Bz5v\n+lna05SZ/wH+tXbzDuD3a/dtaxFxkNWPNr5x2rOM8ShwF7B3K590u+/QX9buJ76LiLiZ1aDvmvYs\nTWTmN4HvAb+NiG358xARtwF/ysx3tvq5t3vQh1ndkdedA7w/pVnKiYjvAPcC383Mj6c9zygRcWlE\nzAFk5hKrV5dfne5Ux3QDcHNE/Bn4EfDziLh2K554u19yvwY8APwyIi4BDmfmkSnPVEJEnAE8Alyb\nmVv6xs2ErgS+DvwkIs4CdgIfTHekjWXmretfR8T9wN8y8w9b8dzbOujMPBgRh9ZeNx0F7pz2TKNE\nxKWsvnY6D/gsIgbAnm0azK3AV4Dnv/Ca/7bM/Pv0RhrpGeDXEfEGcDJwZ2YenfJM247/P7RUyHZ/\nDS3pOBi0VIhBS4UYtFSIQUuFGLRUiEFLhfwXImQs9VpbtwwAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show(x_imgs[0,10:15,10:15])" ] }, { "cell_type": "code", "execution_count": 101, "metadata": { "hidden": true, "scrolled": false }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAr4AAAF0CAYAAADFHDo6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl0ldW9//FPGFJGRVQQUUDBbhdUAcGpUusMCBekYguK\nlUnFoVa0ViYVRIu2KqAo3paKOBYQwQEHitAqsmq1oCxUHtGWkiqTIEIAISH5/XHouvz47kNOck7O\nk5P9fq11V3o/PMPXdid8fXK+z84rLS0VAAAAUN3ViLsAAAAAIBtofAEAABAEGl8AAAAEgcYXAAAA\nQaDxBQAAQBBofAEAABCEWnEXEDLnXCNJd0vqI6mppAJJ0yX9JoqikjhrAw7GOZcvaYSkyyW1krRJ\n0jRJ90VRtDvG0oAyOefOlPSApFMkfSNphqTR/NxFVeacqyXpLklXKdEzfCppZBRFr8daWI7hiW+8\nZkrqKmmgpBMlTVSiER4eY01AKu6TdIukUZLaSvqlEut2QpxFAWVxzrWV9GdJryuxdm+WdJOk2+Os\nC0jBQ5JulTROibX7hqSXnHMdY60qx+SxgUU8nHPHSloh6fL9/23NObdAUsMois6MrTigDM65TZKe\njaLo5v2yiUqs56bxVQYcnHPueUm1oii6bL/sIknfRlH0XnyVAck55+pK2ippYhRFI/bLl0gqiKKo\nf2zF5Rg+6hCTKIoKJB2W5I+Ls1kLUAGlsut0974cqJKcczUk9ZA0ZP88iqIF8VQEpKyNpHxJ7xyQ\nvyLpV9kvJ3fR+FYRzrnakgZI+pGkfjGXA5TlMUnDnHMzJX2gxK/dBkj631irAg6ulaSGkgqdc7Ml\nnS2pUNLDURRNjrMwoAw193098IHDJklHOOcOiaJoW5Zrykl8xrcKcM4tlfSdEp+b7BdF0UsxlwSU\nZZykOZL+rsST3pWSXpU0NsaagLIcue/rZElvSeom6QlJDzrnRsVWFVC2zyXtldTpgLz9vq8Ns1tO\n7qLxrRp+JqmzpKmSZjrnroi5HqAstyuxbgdKOlXSlUq8neTuGGsCylJ739dnoyh6PIqi5VEU3Svp\neUm/dM7lxVgbkFQURYWSnpF0i3PuLOdcTedcb0k/3XdIUXzV5RY+6lAF7Pu8b4Gk5c65BpIecc49\nz6t1UBU55xpr39tHoiiasS/+yDlXR9JU59zkKIq+jq9CIKnt+74uOyBfosRHdZpKWp/VioDU3STp\nECXWa4mkdyXdqcRHz76Jsa6cwhPfmDjnWjrnrtj3Xr79rVRi6K1JDGUBqWitxJOzVQfknyvxL9PH\nZb0iIDVfKNEwND4g/+/fhXxGElVWFEXboij6iaQjJDWNouhsSfUkrYqiiCe+KaLxjc8JSvza4uwD\n8pMk7ZK0JesVAan5z76v3z8gP/GAPweqlH2/Ln5X0v8c8EdnSfoiiqKd2a8KSI1zro9z7rQoijbv\n91u1fpKYCyoHPuoQn8VKTMP/3jl3o6TVks6RdJ2kJ6Io2hNjbUBSURStc87NkXSnc26dEu+jbivp\nDkkLoihaF2uBwMGNk7TAOTdS0ixJPZX4vPr1sVYFlO1KSac4534u6UslNl9ppcSwJlLEE9+YRFG0\nV4kfuG8rsV3mSv3fjizs3IaqbqASv7F4TNJnkn4vaZ6kyw5yDhC7KIreUmKdXi7pEyWah+ujKPpD\nrIUBZRsiaamkuUo8cDhB0jlRFG2Ktaocw85tAAAACAJPfAEAABAEGl8AAAAEgcYXAAAAQaDxBQAA\nQBCy8jqzvLw8JuiQttLS0qxvJ8raRSawdpGrsr12WbfIhIOtW574AgAAIAg0vgAAAAgCjS8AAACC\nQOMLAACAIND4AgAAIAg0vgAAAAgCjS8AAACCQOMLAACAIND4AgAAIAhZ2bkNAICqpkGDBiYbMmSI\nyXr37u09v1evXiYrLCxMvzAAlYYnvgAAAAgCjS8AAACCQOMLA
ACAIND4AgAAIAg0vgAAAAgCb3UA\nAATpqquuMtnEiRNTPr9du3Yme++999KqCUDl4okvAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAsNt\nFdC+fXuTDR8+3Hts69atTVavXj2TjRo1ymSHHnqoyV5//XXvfbZv3+7NAQDSwIEDTTZp0iSTFRUV\nmeyBBx7wXnPZsmVp1wUgu3jiCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgpBXWlpa+TfJy6v8m1SS\nBg0amGzt2rUma9SoUTbK0ZdffunNfcN1L7zwQmWXk1WlpaV52b5nLq9dH9867dOnj/fYjh07mqxL\nly4m832PbNmyxWRHHXWU9z7r16832ZNPPmmyP/zhDybbu3ev95pVDWs3u3r16mWyuXPnmmznzp0m\nu/POO01Wnt3cqptsr92Q1y0y52Drlie+AAAACAKNLwAAAIJA4wsAAIAg0PgCAAAgCAy3laFhw4Ym\ne+2110y2efNm7/nLly83mW9oqGXLliY79thjTVa3bl3vfTZs2GCyM888M6XjcgUDQuVzzDHHmGze\nvHkm863HZLZt22Yy3xqvXbu2yXzfS5LUpEkTkzVt2tRkl19+ucnefvttk61bt857nzixditHfn6+\nN58+fbrJ+vfvb7JFixaZ7IILLki/sGqE4TbkIobbAAAAEDwaXwAAAASBxhcAAABBoPEFAABAEBhu\nq8KOOOIIk912223eY335oEGDTDZjxoz0C4sJA0Lls2zZMpO1b9/eZAsXLvSef+utt5rs66+/Nplv\n57XyOPLII032+uuvm8w5Z7IRI0aY7NFHH02rnsrA2q0co0eP9ubjx4832TPPPGOywYMHm6y4uDj9\nwqoRhtvS06xZM5Ndf/313mN9eVFRkcl8u8fee++9JvP9HSBJBQUF3rw6YbgNAAAAwaPxBQAAQBBo\nfAEAABAEGl8AAAAEgcYXAAAAQeCtDjmmV69e3ty3Fe3DDz9ssptvvjnjNWULk/HJ+SaHv/zyS5PN\nmjXLZFdccYX3mnv37k2/sAp69tlnTdavXz+TderUyWQffvhhpdSUDtZu+jp37myyJUuWeI9ds2aN\nydq1a2eyONd4ruCtDqk7/vjjTTZ16lSTXXjhhdkoR7t37/bmZ511lsmSvQEiV/FWBwAAAASPxhcA\nAABBoPEFAABAEGh8AQAAEIRacReA5A477DCTjRo1KuXzjz766EyWgyqsQ4cOJsvLs5/t/+qrr0wW\n94DPGWecYbL+/fubbPHixSbz/XNXxeE2lE+NGvaZjG976vz8fO/5r7zyisniXueoXpo3b26ylStX\nmqxWLdtmTZw40XvNRx55JKX7nHjiiSb73e9+Z7JGjRp57+Mbcvb9HPZtUV8d8MQXAAAAQaDxBQAA\nQBBofAEAABAEGl8AAAAEgZ3bqoj27dubbPbs2SZr06aN9/zPPvvMZL7dYQoKCipQXdXA7lflU1JS\nYrKNGzea7LTTTvOev3bt2ozW07BhQ2++dOlSk61evdpkvh3mfDslffzxxxWornKxdssn1Z0Ik7np\npptMNmXKlLRqChU7t/lNnjzZZMOGDTPZ1VdfbbKnnnoq4/XccMMNJps0aZL32Jo1a5ps1apVJvMN\nvG3btq0C1WUfO7cBAAAgeDS+AAAACAKNLwAAAIJA4wsAAIAgMNwWg6uuuspkd999t8mOPfZYk+3a\ntct7zZ49e5rMt9NVLmNAqHzGjh1rsjvuuMNkURR5z+/atavJ0hmOXLBggTf/8Y9/bLJOnTqZzLcr\nUq5g7ZbPoEGDTPbHP/7RZAsXLvSe3717d5Oxc1vFhD7cdsghh3hz3wDu9OnTTebbcTBbkv1sP+GE\nE1I637fD3K233ppWTdnCcBsAAACCR+MLAACAIND4AgAAIAg0vgAAAAgCjS8AAACCwFsdMqRBgwbe\n/Fe/+pXJxowZY7IaNey/g
2zZssVkXbp08d7Ht91gdcNkfPnUqVPHZDNmzDBZ3759ved//vnnJjvn\nnHNMtm7dOpM99thjJrvmmmu897nttttM5psmzmWs3eRq1aplsk8//dRkLVu2NNlxxx3nvWZ5tjfG\nwYX+VodkW7r/7W9/M9mFF15osrfeeivjNaWqT58+3vzFF180ma8X3Lp1q8l8b4TYvHlzBaqrXLzV\nAQAAAMGj8QUAAEAQaHwBAAAQBBpfAAAABMFOFaBCnnzySW/+k5/8JKXzX3jhBZNNmjTJZCEMsSEz\nvvvuO5MNHTrUZE2aNPGe79tK+K9//avJZs+ebbIBAwaYbM6cOd77VLdBNpSPb7iydevWJrvuuutM\nFvcQW7du3UzWq1cvk73xxhsm823h7fueRbw6duyY8rHLly+vxErK77XXXvPmvsFl3/ecbz3u2LEj\n/cJixhNfAAAABIHGFwAAAEGg8QUAAEAQaHwBAAAQBIbbMsT3wfDymDp1qsmWLl2a1jWBA23fvt1k\nvXv39h47duxYk918880mGzFiREr3fuSRR1I6DmFp0aJFSsfl5+dXciXJDRw40Jv7dij07Zg4bNgw\nk/l2xZo3b573PoMHDy6jQlSWJUuWePOSkhKT/fnPfzZZz549Tebb7bIyOOe8uW+Ndu3a1WT16tUz\nWXUYwOSJLwAAAIJA4wsAAIAg0PgCAAAgCDS+AAAACALDbRni24VHktq3b1/h830Db/fdd5/3/K++\n+iql+wAH2rZtmze/8847TXbhhRearG3btind54ILLvDmyYZHEIY2bdqkdFy2dq1s1KiRyR566CHv\nsb4hoeLiYpP5hp66dOliMt+OhxLDbXH6+OOPvfmrr75qMt+g8Keffmoy305+kn93y0WLFpmsefPm\nJvMNsvl2f5WkZs2amcy3bl966SXv+bmOJ74AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAIeaWlpZV/\nk7y8yr9JzOrWrevNn3nmGZN16tTJZKnuXrR+/XpvPmjQIJO9+eabKV0zV5SWluZl+54hrN1kunfv\nbrK5c+earHbt2ildb8+ePd78+uuvN9n06dNTumauYO0mN3/+fJN17NjRZEcffXQ2yvHuTphsuM33\n833y5MkmW7t2rcl8A04nnXSS9z5x7lqX7bWbK+vW93f+hAkTTHbTTTeldZ8tW7aYrHHjxmld0+ey\nyy4zmW/YLlccbN3yxBcAAABBoPEFAABAEGh8AQAAEAQaXwAAAASBxhcAAABBYMviDNm1a5c3v+KK\nK0xWq5b9rz3ZtrEHOuqoo7y5b9r+lltuMdnjjz+e0n2Ac88912S+t8D06dPHZL5JZN8Wn5J/a+6v\nv/7aZK+88or3fOS2008/3WTJ3gBS1fi2ij/mmGNM9vvf/95kp5xyismq25t4qjPf3/m+N4LMmjXL\nZL6+IJmmTZumdFxRUZHJfN9bknTccceZbOfOnSnXlOt44gsAAIAg0PgCAAAgCDS+AAAACAKNLwAA\nAILAlsVVxMknn2yyiRMnmsw3cJSMb5vMVq1alauuqoRtXyuHb+1J0vvvv28y3yCab6DDx7clpiT9\n8Y9/NFlenv2ful27dibzrfGqiLWbnG/wq2fPniarjC2LfevMt54ffPDBtO7j+3v2scceM9moUaO8\n52/fvj2t+6eDLYtz19NPP+3NfcN13bp1M9mCBQsyXlO2sGUxAAAAgkfjCwAAgCDQ+AIAACAINL4A\nAAAIAju3laFevXomq4wdTlasWGGyvn37muyJJ57wnt+7d2+TtWjRwmTNmjUz2bp161IpEdVUw4YN\nvblvh8EXXnihwveZPXu2N2/ZsqXJ7r//fpN16tTJZLky3IbyadSokcl8gzrPPPOM93zf2u3Xr5/J\nGjdubLLu3bunUqIkaceOHSZbsmSJyX7729+abPHixSnfB6hsrVu3jruErOGJLwAAAIJA4ws
AAIAg\n0PgCAAAgCDS+AAAACALDbfvxfbjbN6gwf/58k61cudJ7Td/g2JAhQ0xWu3ZtkzVv3txkbdq08d7H\n54svvkipHoStQ4cO3nz9+vUm830/pGvKlCkmu/rqq012ww03mGzu3LkZrwfZtXz5cpMNHTrUZL7d\npnxZurZt22ayZIOZ99xzj8n+/e9/Z7wmoCIKCwvjLqFK4okvAAAAgkDjCwAAgCDQ+AIAACAINL4A\nAAAIAsNt+7nssstMdtRRR5ls8ODBGb93Xl6eyUpLS1M+3/ch9mHDhqVVE8Lg2+FPkv7+979n5f57\n9uwx2TfffGOyH/3oRybz7by1ZcuWzBSGrHjuuedM5tu1cvXq1SarWbOm95rJ8gM9++yzJluzZo3J\nfIPCQFX39ttve/Nrr73WZE2aNKnscqoMnvgCAAAgCDS+AAAACAKNLwAAAIJA4wsAAIAg0PgCAAAg\nCLzVYT+HH3543CX8f+bMmWOy8ePHe4/duHGjyXxbzgIHSvb2kC5dupisX79+Jlu0aJHJGjRoYLL8\n/HzvfU488USTnXrqqSZ79NFHTcYbHHLft99+a7Lzzz8/hkqA6qVGDf+zTd9bpHw9RHXFE18AAAAE\ngcYXAAAAQaDxBQAAQBBofAEAABAEhtv2M2rUKJMtXLjQZAMGDDDZ0Ucf7b2mb3DD55FHHjHZO++8\nY7Li4uKUrgek6tNPP/Xmvu2AfdvLbt682WTlGW7zDVq8++67Jhs7dqz3fACAVVJS4s2TDTSHgie+\nAAAACAKNLwAAAIJA4wsAAIAg0PgCAAAgCAy37aeoqMhkb775ZkoZkKveeOMNbz5lyhST+XZz69Ch\nQ1r3Hz16tMmeeOIJk7FLGwBUjosuushkU6dOjaGSyscTXwAAAASBxhcAAABBoPEFAABAEGh8AQAA\nEASG24DAbdiwwZv/8pe/zHIlAIBMKSwsTPnYWrXCaQd54gsAAIAg0PgCAAAgCDS+AAAACAKNLwAA\nAIJA4wsAAIAg5JWWllb+TfLyKv8mqPZKS0vzsn1P1i4ygbWLXJXttcu6zZxGjRp5c9/277t27TJZ\n/fr1M15Tthxs3fLEFwAAAEGg8QUAAEAQaHwBAAAQBBpfAAAABIHhNuQMBoSQq1i7yFUMtyEXMdwG\nAACA4NH4AgAAIAg0vgAAAAgCjS8AAACCkJXhNgAAACBuPPEFAABAEGh8AQAAEAQaXwAAAASBxhcA\nAABBoPEFAABAEGh8AQAAEAQaXwAAAASBxhcAAABBoPEFAABAEGh8AQAAEAQaXwAAAASBxhcAAABB\noPEFAABAEGh8AQAAEIRacRcAyTl3pqQHJJ0i6RtJMySNjqKoJNbCgBQ45w6R9KmkoiiKWsVcDnBQ\nzrk1klp6/ujRKIpuzG41QGqcc40k3S2pj6SmkgokTZf0G3qF8qHxjZlzrq2kP0u6T9IASacqsZi3\nSZoQY2lAqu6RdKSkr+IuBEjRg0o8bNjfjjgKAVI0U1IrSQMl/UvSxZIelrRLifWMFNH4xu8OSa9H\nUXTPvv//X865rZK+jbEmICXOuc6Shkp6XtKPYy4HSFVhFEXr4y4CSIVz7lhJp0m6PIqit/bFU5xz\nvST1FY1vudD4xsg5V0NSD0lD9s+jKFoQT0VA6pxzNSX9r6TfSSoVjS8AZFwURQWSDkvyx8XZrKU6\noPGNVytJDSUVOudmSzpbUqGkh6MomhxnYUAKblRi/f5G0siYawGAIDjnaivx0cgfSeoXczk5h8Y3\nXkfu+zpZ0kNKNBAXS3rQOVc/iqLfxFYZcBDOueaSxkv6SRRFu51zcZcElEdn59wCSScr8dnepyVN\niKJod7xlAQfnnFsq6XRJX0vqF0XRSzGXlHN4nVm8au/7+mwURY9HUbQ8iqJ7lfi85C+dc3kx1gYc\nzMOSXo6iaGHchQDltElSPSU+F3mRpEmSblPiYztAVfc
zSZ0lTZU00zl3Rcz15Bye+MZr+76vyw7I\nlyjxa4ymkhjAQJXinOupxMdy2sVdC1BeURSdekC0Yt8r+e5xzo2Joug/cdQFpGLf530LJC13zjWQ\n9Ihz7nleaZY6Gt94fSGpRFLjA/L/Ponflt1ygJRcKulwSV/t9xGHGpLynHPFku6OoujuuIoDKuDD\nfV+bSaLxRZXinGspqYukmVEU7T/MtlKJobcm4iFZyvioQ4yiKCqU9K6k/zngj86S9EUURTuzXxVQ\npjFKfDayw37/97gS7/H9738GqhyX8JRz7vgD/ugUSXsl/TOGsoCynCDpGSV+07a/k5R4j++WrFeU\nw3jiG79xkhY450ZKmiWppxKf4bk+1qqAJKIo+lLSl/tnzrmNSuzctjKeqoCUFCjRPMx0zt2qxNPd\nH0v6taRpURRtjrM4IInFkj6Q9Hvn3I2SVks6R9J1kp6IomhPjLXlHJ74xmzfy6gvk3S5pE8k3Szp\n+iiK/hBrYQBQzez7Ldq5SnzMbKakVUr8BuN3SryeD6hyoijaq8RDsbclzVDiIw63KvHgbHiMpeWk\nvNLS0rhrAAAAACodT3wBAAAQBBpfAAAABIHGFwAAAEGg8QUAAEAQsvI6s7y8PCbokLbS0tKsb+HM\n2kUmsHaRq7K9dlm3yISDrVue+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAI\nAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0v\nAAAAglAr7gKqu+7du5ts+PDhJrvwwgtNVlpaarLVq1d77zNr1iyTTZ061WRfffWV93wAAIDqjie+\nAAAACAKNLwAAAIJA4wsAAIAg0PgCAAAgCHm+AaqM3yQvr/JvErPrrrvOm0+cONFk+fn5lV2OJGnx\n4sUmGzBggMnWrVuXjXLSVlpampfte4awdlH5WLvIVdleu7m8bmfMmGGyK6+80mTz58/3nj9nzhyT\nLV261GQFBQUp1bNnzx5vvnfv3pTOz2UHW7c88QUAAEAQaHwBAAAQBBpfAAAABIHGFwAAAEFg57YK\n6NGjh8keeOAB77G+Qbbly5ebbMSIESb7+OOPU65pyJAhJhs3bpzJRo4cabKbbrop5fsgt9WvX99k\no0aN8h47ZswYk/mGYcePH2+y9u3bm6xXr16plAgAOWnVqlUmKykpMZmvhzhYXlHTp0/35tdee63J\niouLM3rvqownvgAAAAgCjS8AAACCQOMLAACAIND4AgAAIAjs3FaGnj17muz55583mW9oSJLmzZtn\nMt8ubxs2bKhAdf8nL89uUuIbeLvoootM9tOf/jSte2cLu1+lr0WLFib797//7T22U6dOJlu2bJnJ\nfMNtv/jFL0zmnPPeJ921nwtYuzhQ06ZNTdamTRvvsXXq1DFZ//79Tfbss8+aLNnuXe+++25ZJUpi\n57Z0+XqIrl27pnz+qaeeajLfz/G6deua7NBDD/Ve8/zzzzeZb6fXXMbObQAAAAgejS8AAACCQOML\nAACAIND4AgAAIAjs3LafWrXsfx2+3c98g2wrVqzwXtO3Q8qmTZsqUN3B+YYUp02bZrK5c+dm/N7I\nHa1atcr4NYuKikzmG6po27at9/wQhtsQjh/84Acm+9nPfmaywYMHm6xZs2bea6Y6hD5o0KCUjpOk\nmjVrpnwsKu7VV19NKUtX9+7dTTZ//nzvsRdffLHJqttw28HwxBcAAABBoPEFAABAEGh8AQAAEAQa\nXwAAAASBxhcAAABB4K0O+7n66qtN1rFjR5Pt3r3bZAMHDvReszLe4JCOzZs3x10CYnTmmWdm/Jov\nvfSSyXxvQ+ncubP3/JCmiZGbOnTo4M2HDx9usgsuuMBkRx11VMZr8tm+fbvJFi1alJV7I3saN25s\nsrvuustkxcXF3vO
Tve0hFDzxBQAAQBBofAEAABAEGl8AAAAEgcYXAAAAQWC4bT+/+MUvUjpu2LBh\nJvvwww8zXQ6QFt+WpJdeeqnJSkpKvOcnG4wAKsK3Jbwk1alTx2SFhYWVXY4k/8Dl9OnTTda6dWvv\n+d/73vcyXpPPJ598YrIxY8aYzDe8vGTJkkqpCRXXsGFDb96lSxeT5efnm2z06NEm863lp556ynuf\nv/zlL2VUWL3xxBcAAABBoPEFAABAEGh8AQAAEAQaXwAAAASB4bYK+M9//hN3CUCZmjZtarJTTz3V\nZP/617+8569YsSKl+xQVFZls7969JmvTpk1K10P15NtZSpIuueQSk82ZM8dkY8eOTfleJ598sslu\nv/12k/mGPWvXrm2yvLw8731KS0tTrikVvn9uSfr5z39usl27dmX03khfgwYNTDZhwgST+dadlN4O\nf++9957J7rvvvgpfrzrjiS8AAACCQOMLAACAIND4AgAAIAg0vgAAAAhCkMNtvsEHSTrhhBNMtn37\ndpNFUZTxmoC4rF69Oq3zP//8c5MVFBSYrEOHDmndB7njkEMOMdmVV17pPbZFixYma9euncl8g0PO\nOe81e/ToUVaJ5ZJsuM3Ht3va008/bbIXX3zRZOyyltvOOussk91www1Zubfv+yPZrpyh44kvAAAA\ngkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAghDkWx1q1fL/Y9esWdNkO3fuNBlbFiMXnHfeeSkd\nN3HixLTu4/t+8n0vNWvWzHu+7w0A27ZtS6smxKtx48Ymq1+/vvfYVLf9HT58uMkqYyvh999/32Qz\nZ870Hvvaa6+ZrLCw0GRffvllhetB7ujSpUta52/cuNFkU6dONVmNGvaZ5R133GEy33bJkjR06FCT\nffPNN6mUWC3wxBcAAABBoPEFAABAEGh8AQAAEAQaXwAAAAQhyOG2uB1++OEm69mzp8luvfXWlK+5\nZs0ak7Vq1cpk69evN9kLL7xgsunTp3vvU1RUlHJNiNcPf/hDk23YsMFk77zzTlr38Q2Azp8/32TD\nhg3znn/ooYeajOG23Ob7ebRp0ybvsb5BuGwZP368yR5++GGTbdmyJRvlIMeNGzfOZP/4xz9MtmPH\nDu/5f/3rX022Z88ek/mGOmfPnm2yt956y3ufadOmmWzIkCEm27p1q/f8XMcTXwAAAASBxhcAAABB\noPEFAABAEGh8AQAAEASG28rgG7zo3LmzyT744APv+W3atDHZwoULTdaiRQuT7dq1y2QfffSR9z6+\nYRJfNmjpU0jmAAAIzUlEQVTQIJNdcMEFJuvatav3Ppdeeqk3R7x8u2JdfPHFJvMNSiQbtEhHdR2K\nQMUlG7RxzlX4mm+//bY3nzNnjsmee+45k/l2qyopKalwPQhbcXGxyebNm5fx+/h2Jly5cqXJrr76\nau/5c+fONdnixYtNNmXKlApUV/XxxBcAAABBoPEFAABAEGh8AQAAEAQaXwAAAAQhyOG2ZLvwfPvt\ntybz7Srly44//njvNRctWmSyY445xmS+wY8bbrjBZJ999pn3Pql6+eWXTeb7oPuJJ56Y1n2QXfXq\n1TNZy5YtTVZQUJCNcrzfS8n4vp+yVSeyZ+TIkd7ct2ulb9jX55xzzkmnJKBa8/19L0l/+tOfTOb7\n/pw5c6bJku3AmEt44gsAAIAg0PgCAAAgCDS+AAAACAKNLwAAAIIQ5HCbb0czSVq3bp3JfIM3l19+\nucnatm3rvaZvkM23c1ufPn1MVhk7avnuPW3aNJNddNFFGb834pefn2+yTp06eY/97rvvTOYbDK1b\nt67JfDsLJTN16lSTnXfeeSYrKipK+ZqoegoLC725b9BmwIABJmvevLnJ1q9f773m7NmzTXbXXXeZ\nLNmgM1CdTZ482WT9+/c32TXXXGOye++9t1Jqyiae+AIAACAINL4AAAAIAo0vAAAAg
kDjCwAAgCDQ\n+AIAACAIeeWZvq7wTfLyKv8mGTBhwgST3X777Wld0/fGhJtvvtlkO3fuTOs+6XjuuedM1q1bN++x\nHTp0MNnatWszXpNPaWlpXlZutJ9cWbtHHnmkyTZu3JjWNYuLi03mm8z3vSnCt4VyefjecjJv3ry0\nrhkn1m75+CbMH3/8cZM1bNjQe77v77WlS5earFevXib75ptvUikxGNleu7m8bnNFnTp1TPbuu++a\nbMWKFSYbNGhQpdSUaQdbtzzxBQAAQBBofAEAABAEGl8AAAAEgcYXAAAAQWC4bT+NGjUy2Ycffmiy\nFi1apHzNW265xWSTJk0qX2GVzLftZ7KhkVNOOcVkURRlvCYfBoSSq1mzpsnGjx9vspEjR2ajnHL5\n4IMPTHbGGWeYbO/evdkop1KwdtPn+7nrGx6WpPPPPz+la37yyScmu+yyy0y2atWqlK5XHYU+3Obb\nKlvyD1v27dvXZLt37854TZVhzJgxJrv22mtNdtJJJ5ls69atlVJTOhhuAwAAQPBofAEAABAEGl8A\nAAAEgcYXAAAAQWC4rQw9evQw2Z/+9CeT1a9f33v+jh07TPbqq6+a7N577zXZypUrUymxXLp3726y\nl19+2WSfffaZ9/x27dplvKZUMSBUPr6BtyZNmpgs2dr1rRXfMJAv8w1AvPnmm977+HbUOuuss7zH\n5irWbuXwDUFK/l3+fLsb+rz//vsmu/HGG73H+gYzq5vQh9tatWrlzf/5z3+a7OmnnzbZr3/9a5Nt\n2LAh7boyzTfcdvfdd5vs+OOPN9maNWsqo6S0MNwGAACA4NH4AgAAIAg0vgAAAAgCjS8AAACCwHBb\nBXTt2tVk999/v/fYk08+OaVr7tq1y2RDhw412dq1a73n+z5c3qVLF5NNnjzZZL4d655//nnvfQYN\nGuTNs4EBodzRqVMnkyUbBGK4rXKEvHYvueQSk82ZM6fC1/P9LJak6dOnV/iauSL04bajjz7am/t2\nLPUNCq9evdpkw4YN817znXfeMVlxcXFZJZZbnz59TPbAAw+YLD8/32Q/+MEPTPbtt99mprAMYrgN\nAAAAwaPxBQAAQBBofAEAABAEGl8AAAAEgeG2DEm2K9DgwYNN5tvJ5bDDDst4TT6+D8r7do0bN25c\nNsopFwaEcscRRxxhslWrVnmP3bt3r8m+//3vm6wqDlCkirVbOa677jpv/uijj2b0Pk8++aQ39/18\nr25CH25Lpm/fviabNWtWWtf07ejm69Feeuklk/Xu3Tvl+zRu3NhkvkG2e+65x2R33nlnyveJE8Nt\nAAAACB6NLwAAAIJA4wsAAIAg0PgCAAAgCDS+AAAACAJvdYiBb6LSN53smxpt3759yvcpKCgw2eOP\nP26yCRMmpHzNODEZn9t8WxNL0plnnmky3zah69aty3hN2cLaLR/ftvAjR4402dlnn+09P9N/r914\n443efOrUqRm9T1XEWx38atasabJu3bqZbMSIESZLd0v2vDz7P0m6a37atGkmGz16tMk2bdqU1n2y\nhbc6AAAAIHg0vgAAAAgCjS8AAACCQOMLAACAIDDchpzBgFBuGz58uDd/6KGHTHbJJZeYzLdNZ65g\n7Urdu3f35tdcc43JfENCvi1VfUM+UuqDPuPHjzfZsmXLTPbyyy+ndL3qiOG29NSoYZ8vnnbaad5j\nfQPtP/zhD012xhlnmGzPnj0mmz17tvc+kydPNplv3ZeUlHjPzwUMtwEAACB4NL4AAAAIAo0vAAAA\ngkDjCwAAgCAw3IacwYBQbjv99NO9+d/+9jeT/eUvfzHZueeem+mSsia0tTt06FCTJdsh0reTpc/W\nrVtNtmTJEu+xH330kclefPFFk61YscJkuTzQUxkYbkMuYrgNAAAAwaPxBQAAQBBofAEAABAEGl8A\nAAAEgeE25IzQBoRQfYS2dn27TfXo0cN77Pz58
1O65saNG032+eefl68wlBvDbchFDLcBAAAgeDS+\nAAAACAKNLwAAAIJA4wsAAIAg0PgCAAAgCLzVATkjtMl4VB+sXeQq3uqAXMRbHQAAABA8Gl8AAAAE\ngcYXAAAAQaDxBQAAQBBofAEAABAEGl8AAAAEgcYXAAAAQaDxBQAAQBBofAEAABCErOzcBgAAAMSN\nJ74AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4A\nAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAI\nAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCD8P5Eh05z4VcPCAAAAAElF\nTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plots(x_imgs[:8], titles=y_valid[:8])" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "## The Most Important Machine Learning Concepts" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "### Functions, parameters, and training" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "A **function** takes inputs and returns outputs. For instance, $f(x) = 3x + 5$ is an example of a function. If we input $2$, the output is $3\\times 2 + 5 = 11$, or if we input $-1$, the output is $3\\times -1 + 5 = 2$\n", "\n", "Functions have **parameters**. The above function $f$ is $ax + b$, with parameters a and b set to $a=3$ and $b=5$.\n", "\n", "Machine learning is often about learning the best values for those parameters. For instance, suppose we have the data points on the chart below. What values should we choose for $a$ and $b$?" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "\"\"" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "In the above gif fast.ai Practical Deep Learning for Coders course, [intro to SGD notebook](https://github.com/fastai/courses/blob/master/deeplearning1/nbs/sgd-intro.ipynb)), an algorithm called stochastic gradient descent is being used to learn the best parameters to fit the line to the data (note: in the gif, the algorithm is stopping before the absolute best parameters are found). 
This process is called **training** or **fitting**.\n", "\n", "Most datasets will not be well-represented by a line. We could use a more complicated function, such as $g(x) = ax^2 + bx + c + \\sin d$. Now we have 4 parameters to learn: $a$, $b$, $c$, and $d$. This function is more flexible than $f(x) = ax + b$ and will be able to accurately model more datasets.\n", "\n", "Neural networks take this to an extreme and are extremely flexible: with enough parameters, they can approximate a very wide range of functions. They often have thousands, or even hundreds of thousands, of parameters. However, the core idea is the same as above. The neural network is a function, and we will learn the best parameters for modeling our data." ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "### Training & Validation data sets" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Possibly **the most important idea** in machine learning is that of having separate training & validation data sets.\n", "\n", "As motivation, suppose you don't divide up your data, but instead use all of it. And suppose you have lots of parameters:" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "With enough parameters, a model can match the data it was trained on almost perfectly and still do poorly on new data. This is called overfitting. A validation set helps us detect this problem." ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "\"\"\n", "
\n", "[Underfitting and Overfitting](https://datascience.stackexchange.com/questions/361/when-is-a-model-underfitted)\n", "
" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "The error for the pictured data points is lowest for the model on the far right (the blue curve passes through the red points almost perfectly), yet it's not the best choice. Why is that? If you were to gather some new data points, they most likely would not be on that curve in the graph on the right, but would be closer to the curve in the middle graph.\n", "\n", "This illustrates how training on all our data, with nothing held out for validation, can lead to **overfitting** that goes unnoticed." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Neural Net (with torch.nn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Imports " ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from fastai.metrics import *\n", "from fastai.model import *\n", "from fastai.dataset import *\n", "from fastai.core import *\n", "\n", "import torch.nn as nn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Neural networks " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use fastai's ImageClassifierData, which holds our training and validation sets and will provide batches of that data in a form ready for use by a PyTorch model." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "md = ImageClassifierData.from_arrays(path, (x,y), (x_valid, y_valid))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will begin with the highest level abstraction: using a neural net defined by PyTorch's Sequential class. 
" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "net = nn.Sequential(\n", "    nn.Linear(28*28, 256),\n", "    nn.ReLU(),\n", "    nn.Linear(256, 10)\n", ").cuda()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each input is a vector of size $28\\times 28 = 784$ (one entry per pixel of a flattened $28\\times 28$ image) and our output is of size $10$ (since there are 10 digits: 0, 1, ..., 9). \n", "\n", "We use the output of the final layer to generate our predictions. Often for classification problems (like MNIST digit classification), the final layer has the same number of outputs as there are classes. In our case, this is 10: one for each digit from 0 to 9. These can be converted to probabilities (for example, with a softmax). For instance, it may be determined that a particular hand-written image is 80% likely to be a 4, 18% likely to be a 9, and 2% likely to be a 3. In our case, we are not interested in viewing the probabilities, and just want to see what the most likely guess is." ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Layers" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "`Sequential` defines the layers of our network, so let's talk about layers. Neural networks consist of **linear layers alternating with non-linear layers**. This creates functions which are incredibly flexible. Deeper layers are able to capture more complex patterns." ] }, { "cell_type": "markdown", "metadata": { "collapsed": true, "hidden": true }, "source": [ "Layer 1 of a convolutional neural network:\n", "\"pytorch\"\n", "
\n", "[Matthew Zeiler and Rob Fergus](http://www.matthewzeiler.com/wp-content/uploads/2017/07/arxive2013.pdf)\n", "
" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true, "hidden": true }, "source": [ "Layer 2:\n", "\"pytorch\"\n", "
\n", "[Matthew Zeiler and Rob Fergus](http://www.matthewzeiler.com/wp-content/uploads/2017/07/arxive2013.pdf)\n", "
" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true, "hidden": true }, "source": [ "Deeper layers can learn about more complicated shapes (although we are only using 2 layers in our network):\n", "\"pytorch\"\n", "
\n", "[Matthew Zeiler and Rob Fergus](http://www.matthewzeiler.com/wp-content/uploads/2017/07/arxive2013.pdf)\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training the network " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we will set a few inputs for our *fit* method:\n", "- **Optimizer**: the algorithm for finding the minimum of the loss. Typically these are variations on *stochastic gradient descent*, which take a step in the direction that appears to reduce the function, based on its gradient.\n", "- **Loss**: what function is the optimizer trying to minimize? We need to say how we're defining the error.\n", "- **Metrics**: other calculations you want printed out as you train" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "loss=F.cross_entropy\n", "metrics=[accuracy]\n", "opt=optim.Adam(net.parameters())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*Fitting* is the process by which the neural net learns the best parameters for the dataset." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e78d3f5cd8f94049afdc813f1a1ff0dc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 0.1446 0.1427 0.9583] \n", "\n" ] } ], "source": [ "fit(net, md, n_epochs=1, crit=loss, opt=opt, metrics=metrics)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "GPUs are great at handling lots of data at once (with too little data per step, you don't get the performance benefit). We break the data up into **batches**; the batch size specifies how many samples from our dataset we want to send to the GPU at a time. The fastai library defaults to a batch size of 64. 
On each iteration of the training loop, the error on 1 batch of data will be calculated, and the optimizer will update the parameters based on that.\n", "\n", "An **epoch** is completed once each data sample has been used once in the training loop." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have the parameters for our model, we can make predictions on our validation set." ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "collapsed": true }, "outputs": [], "source": [ "preds = predict(net, md.val_dl)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "collapsed": true, "scrolled": false }, "outputs": [], "source": [ "preds = np.argmax(preds, axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see how some of our preditions look!" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAr4AAAF0CAYAAADFHDo6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl0ldW9//FPGFJGRVQQUUDBbhdUAcGpUusMCBekYguK\nlUnFoVa0ViYVRIu2KqAo3paKOBYQwQEHitAqsmq1oCxUHtGWkiqTIEIAISH5/XHouvz47kNOck7O\nk5P9fq11V3o/PMPXdid8fXK+z84rLS0VAAAAUN3ViLsAAAAAIBtofAEAABAEGl8AAAAEgcYXAAAA\nQaDxBQAAQBBofAEAABCEWnEXEDLnXCNJd0vqI6mppAJJ0yX9JoqikjhrAw7GOZcvaYSkyyW1krRJ\n0jRJ90VRtDvG0oAyOefOlPSApFMkfSNphqTR/NxFVeacqyXpLklXKdEzfCppZBRFr8daWI7hiW+8\nZkrqKmmgpBMlTVSiER4eY01AKu6TdIukUZLaSvqlEut2QpxFAWVxzrWV9GdJryuxdm+WdJOk2+Os\nC0jBQ5JulTROibX7hqSXnHMdY60qx+SxgUU8nHPHSloh6fL9/23NObdAUsMois6MrTigDM65TZKe\njaLo5v2yiUqs56bxVQYcnHPueUm1oii6bL/sIknfRlH0XnyVAck55+pK2ippYhRFI/bLl0gqiKKo\nf2zF5Rg+6hCTKIoKJB2W5I+Ls1kLUAGlsut0974cqJKcczUk9ZA0ZP88iqIF8VQEpKyNpHxJ7xyQ\nvyLpV9kvJ3fR+FYRzrnakgZI+pGkfjGXA5TlMUnDnHMzJX2gxK/dBkj631irAg6ulaSGkgqdc7Ml\nnS2pUNLDURRNjrMwoAw193098IHDJklHOOcOiaJoW5Zrykl8xrcKcM4tlfSdEp+b7BdF0UsxlwSU\nZZykOZL+rsST3pWSXpU0NsaagLIcue/rZElvSeom6QlJDzrnRsVWFVC2zyXtldTpgLz9vq8Ns1tO\n7qLxrRp+JqmzpKmSZjrnroi5HqAstyuxbgdKOlXSlUq8neTuGGsCylJ
739dnoyh6PIqi5VEU3Svp\neUm/dM7lxVgbkFQURYWSnpF0i3PuLOdcTedcb0k/3XdIUXzV5RY+6lAF7Pu8b4Gk5c65BpIecc49\nz6t1UBU55xpr39tHoiiasS/+yDlXR9JU59zkKIq+jq9CIKnt+74uOyBfosRHdZpKWp/VioDU3STp\nECXWa4mkdyXdqcRHz76Jsa6cwhPfmDjnWjrnrtj3Xr79rVRi6K1JDGUBqWitxJOzVQfknyvxL9PH\nZb0iIDVfKNEwND4g/+/fhXxGElVWFEXboij6iaQjJDWNouhsSfUkrYqiiCe+KaLxjc8JSvza4uwD\n8pMk7ZK0JesVAan5z76v3z8gP/GAPweqlH2/Ln5X0v8c8EdnSfoiiqKd2a8KSI1zro9z7rQoijbv\n91u1fpKYCyoHPuoQn8VKTMP/3jl3o6TVks6RdJ2kJ6Io2hNjbUBSURStc87NkXSnc26dEu+jbivp\nDkkLoihaF2uBwMGNk7TAOTdS0ixJPZX4vPr1sVYFlO1KSac4534u6UslNl9ppcSwJlLEE9+YRFG0\nV4kfuG8rsV3mSv3fjizs3IaqbqASv7F4TNJnkn4vaZ6kyw5yDhC7KIreUmKdXi7pEyWah+ujKPpD\nrIUBZRsiaamkuUo8cDhB0jlRFG2Ktaocw85tAAAACAJPfAEAABAEGl8AAAAEgcYXAAAAQaDxBQAA\nQBCy8jqzvLw8JuiQttLS0qxvJ8raRSawdpGrsr12WbfIhIOtW574AgAAIAg0vgAAAAgCjS8AAACC\nQOMLAACAIND4AgAAIAg0vgAAAAgCjS8AAACCQOMLAACAIND4AgAAIAhZ2bkNAICqpkGDBiYbMmSI\nyXr37u09v1evXiYrLCxMvzAAlYYnvgAAAAgCjS8AAACCQOMLAACAIND4AgAAIAg0vgAAAAgCb3UA\nAATpqquuMtnEiRNTPr9du3Yme++999KqCUDl4okvAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAsNt\nFdC+fXuTDR8+3Hts69atTVavXj2TjRo1ymSHHnqoyV5//XXvfbZv3+7NAQDSwIEDTTZp0iSTFRUV\nmeyBBx7wXnPZsmVp1wUgu3jiCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgpBXWlpa+TfJy6v8m1SS\nBg0amGzt2rUma9SoUTbK0ZdffunNfcN1L7zwQmWXk1WlpaV52b5nLq9dH9867dOnj/fYjh07mqxL\nly4m832PbNmyxWRHHXWU9z7r16832ZNPPmmyP/zhDybbu3ev95pVDWs3u3r16mWyuXPnmmznzp0m\nu/POO01Wnt3cqptsr92Q1y0y52Drlie+AAAACAKNLwAAAIJA4wsAAIAg0PgCAAAgCAy3laFhw4Ym\ne+2110y2efNm7/nLly83mW9oqGXLliY79thjTVa3bl3vfTZs2GCyM888M6XjcgUDQuVzzDHHmGze\nvHkm863HZLZt22Yy3xqvXbu2yXzfS5LUpEkTkzVt2tRkl19+ucnefvttk61bt857nzixditHfn6+\nN58+fbrJ+vfvb7JFixaZ7IILLki/sGqE4TbkIobbAAAAEDwaXwAAAASBxhcAAABBoPEFAABAEBhu\nq8KOOOIIk912223eY335oEGDTDZjxoz0C4sJA0Lls2zZMpO1b9/eZAsXLvSef+utt5rs66+/Nplv\n57XyOPLII032+uuvm8w5Z7IRI0aY7NFHH02rnsrA2q0co0eP9ubjx4832TPPPGOywYMHm6y4uDj9\nwqoRhtvS06xZM5Ndf/313mN9eVFRkcl8u8fee++9JvP9HSBJBQUF3rw6YbgNAAAAwaPxBQAAQBBo\nfAEAABAEGl8AAAAEgcYXAAAAQeCtDjmmV69e3ty3Fe3DDz9ssptvvjnjNWULk/HJ+SaHv/zyS5PN\nmjXLZFdccYX3mnv37k2/sAp69tl
nTdavXz+TderUyWQffvhhpdSUDtZu+jp37myyJUuWeI9ds2aN\nydq1a2eyONd4ruCtDqk7/vjjTTZ16lSTXXjhhdkoR7t37/bmZ511lsmSvQEiV/FWBwAAAASPxhcA\nAABBoPEFAABAEGh8AQAAEIRacReA5A477DCTjRo1KuXzjz766EyWgyqsQ4cOJsvLs5/t/+qrr0wW\n94DPGWecYbL+/fubbPHixSbz/XNXxeE2lE+NGvaZjG976vz8fO/5r7zyisniXueoXpo3b26ylStX\nmqxWLdtmTZw40XvNRx55JKX7nHjiiSb73e9+Z7JGjRp57+Mbcvb9HPZtUV8d8MQXAAAAQaDxBQAA\nQBBofAEAABAEGl8AAAAEgZ3bqoj27dubbPbs2SZr06aN9/zPPvvMZL7dYQoKCipQXdXA7lflU1JS\nYrKNGzea7LTTTvOev3bt2ozW07BhQ2++dOlSk61evdpkvh3mfDslffzxxxWornKxdssn1Z0Ik7np\npptMNmXKlLRqChU7t/lNnjzZZMOGDTPZ1VdfbbKnnnoq4/XccMMNJps0aZL32Jo1a5ps1apVJvMN\nvG3btq0C1WUfO7cBAAAgeDS+AAAACAKNLwAAAIJA4wsAAIAgMNwWg6uuuspkd999t8mOPfZYk+3a\ntct7zZ49e5rMt9NVLmNAqHzGjh1rsjvuuMNkURR5z+/atavJ0hmOXLBggTf/8Y9/bLJOnTqZzLcr\nUq5g7ZbPoEGDTPbHP/7RZAsXLvSe3717d5Oxc1vFhD7cdsghh3hz3wDu9OnTTebbcTBbkv1sP+GE\nE1I637fD3K233ppWTdnCcBsAAACCR+MLAACAIND4AgAAIAg0vgAAAAgCjS8AAACCwFsdMqRBgwbe\n/Fe/+pXJxowZY7IaNey/g2zZssVkXbp08d7Ht91gdcNkfPnUqVPHZDNmzDBZ3759ved//vnnJjvn\nnHNMtm7dOpM99thjJrvmmmu897nttttM5psmzmWs3eRq1aplsk8//dRkLVu2NNlxxx3nvWZ5tjfG\nwYX+VodkW7r/7W9/M9mFF15osrfeeivjNaWqT58+3vzFF180ma8X3Lp1q8l8b4TYvHlzBaqrXLzV\nAQAAAMGj8QUAAEAQaHwBAAAQBBpfAAAABMFOFaBCnnzySW/+k5/8JKXzX3jhBZNNmjTJZCEMsSEz\nvvvuO5MNHTrUZE2aNPGe79tK+K9//avJZs+ebbIBAwaYbM6cOd77VLdBNpSPb7iydevWJrvuuutM\nFvcQW7du3UzWq1cvk73xxhsm823h7fueRbw6duyY8rHLly+vxErK77XXXvPmvsFl3/ecbz3u2LEj\n/cJixhNfAAAABIHGFwAAAEGg8QUAAEAQaHwBAAAQBIbbMsT3wfDymDp1qsmWLl2a1jWBA23fvt1k\nvXv39h47duxYk918880mGzFiREr3fuSRR1I6DmFp0aJFSsfl5+dXciXJDRw40Jv7dij07Zg4bNgw\nk/l2xZo3b573PoMHDy6jQlSWJUuWePOSkhKT/fnPfzZZz549Tebb7bIyOOe8uW+Ndu3a1WT16tUz\nWXUYwOSJLwAAAIJA4wsAAIAg0PgCAAAgCDS+AAAACALDbRni24VHktq3b1/h830Db/fdd5/3/K++\n+iql+wAH2rZtmze/8847TXbhhRearG3btind54ILLvDmyYZHEIY2bdqkdFy2dq1s1KiRyR566CHv\nsb4hoeLiYpP5hp66dOliMt+OhxLDbXH6+OOPvfmrr75qMt+g8Keffmoy305+kn93y0WLFpmsefPm\nJvMNsvl2f5WkZs2amcy3bl966SXv+bmOJ74AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAIeaWlpZV/\nk7y8yr9JzOrWrevNn3nmGZN16tTJZKnuXrR+/XpvPmjQIJO9+eabKV0zV5SWluZl+54hrN1kunfv\
nbrK5c+earHbt2ildb8+ePd78+uuvN9n06dNTumauYO0mN3/+fJN17NjRZEcffXQ2yvHuTphsuM33\n833y5MkmW7t2rcl8A04nnXSS9z5x7lqX7bWbK+vW93f+hAkTTHbTTTeldZ8tW7aYrHHjxmld0+ey\nyy4zmW/YLlccbN3yxBcAAABBoPEFAABAEGh8AQAAEAQaXwAAAASBxhcAAABBYMviDNm1a5c3v+KK\nK0xWq5b9rz3ZtrEHOuqoo7y5b9r+lltuMdnjjz+e0n2Ac88912S+t8D06dPHZL5JZN8Wn5J/a+6v\nv/7aZK+88or3fOS2008/3WTJ3gBS1fi2ij/mmGNM9vvf/95kp5xyismq25t4qjPf3/m+N4LMmjXL\nZL6+IJmmTZumdFxRUZHJfN9bknTccceZbOfOnSnXlOt44gsAAIAg0PgCAAAgCDS+AAAACAKNLwAA\nAILAlsVVxMknn2yyiRMnmsw3cJSMb5vMVq1alauuqoRtXyuHb+1J0vvvv28y3yCab6DDx7clpiT9\n8Y9/NFlenv2ful27dibzrfGqiLWbnG/wq2fPniarjC2LfevMt54ffPDBtO7j+3v2scceM9moUaO8\n52/fvj2t+6eDLYtz19NPP+3NfcN13bp1M9mCBQsyXlO2sGUxAAAAgkfjCwAAgCDQ+AIAACAINL4A\nAAAIAju3laFevXomq4wdTlasWGGyvn37muyJJ57wnt+7d2+TtWjRwmTNmjUz2bp161IpEdVUw4YN\nvblvh8EXXnihwveZPXu2N2/ZsqXJ7r//fpN16tTJZLky3IbyadSokcl8gzrPPPOM93zf2u3Xr5/J\nGjdubLLu3bunUqIkaceOHSZbsmSJyX7729+abPHixSnfB6hsrVu3jruErOGJLwAAAIJA4wsAAIAg\n0PgCAAAgCDS+AAAACALDbfvxfbjbN6gwf/58k61cudJ7Td/g2JAhQ0xWu3ZtkzVv3txkbdq08d7H\n54svvkipHoStQ4cO3nz9+vUm830/pGvKlCkmu/rqq012ww03mGzu3LkZrwfZtXz5cpMNHTrUZL7d\npnxZurZt22ayZIOZ99xzj8n+/e9/Z7wmoCIKCwvjLqFK4okvAAAAgkDjCwAAgCDQ+AIAACAINL4A\nAAAIAsNt+7nssstMdtRRR5ls8ODBGb93Xl6eyUpLS1M+3/ch9mHDhqVVE8Lg2+FPkv7+979n5f57\n9uwx2TfffGOyH/3oRybz7by1ZcuWzBSGrHjuuedM5tu1cvXq1SarWbOm95rJ8gM9++yzJluzZo3J\nfIPCQFX39ttve/Nrr73WZE2aNKnscqoMnvgCAAAgCDS+AAAACAKNLwAAAIJA4wsAAIAg0PgCAAAg\nCLzVYT+HH3543CX8f+bMmWOy8ePHe4/duHGjyXxbzgIHSvb2kC5dupisX79+Jlu0aJHJGjRoYLL8\n/HzvfU488USTnXrqqSZ79NFHTcYbHHLft99+a7Lzzz8/hkqA6qVGDf+zTd9bpHw9RHXFE18AAAAE\ngcYXAAAAQaDxBQAAQBBofAEAABAEhtv2M2rUKJMtXLjQZAMGDDDZ0Ucf7b2mb3DD55FHHjHZO++8\nY7Li4uKUrgek6tNPP/Xmvu2AfdvLbt682WTlGW7zDVq8++67Jhs7dqz3fACAVVJS4s2TDTSHgie+\nAAAACAKNLwAAAIJA4wsAAIAg0PgCAAAgCAy37aeoqMhkb775ZkoZkKveeOMNbz5lyhST+XZz69Ch\nQ1r3Hz16tMmeeOIJk7FLGwBUjosuushkU6dOjaGSyscTXwAAAASBxhcAAABBoPEFAABAEGh8AQAA\nEASG24DAbdiwwZv/8pe/zHIlAIBMKSwsTPnYWrXCaQd54gsAAIAg0PgCAAAgCDS+AAAACAKNLwAA\nAIJA4wsAAIAg5JWWllb+TfLyKv8mqPZKS0vzsn1P1i4ygbWLX
JXttcu6zZxGjRp5c9/277t27TJZ\n/fr1M15Tthxs3fLEFwAAAEGg8QUAAEAQaHwBAAAQBBpfAAAABIHhNuQMBoSQq1i7yFUMtyEXMdwG\nAACA4NH4AgAAIAg0vgAAAAgCjS8AAACCkJXhNgAAACBuPPEFAABAEGh8AQAAEAQaXwAAAASBxhcA\nAABBoPEFAABAEGh8AQAAEAQaXwAAAASBxhcAAABBoPEFAABAEGh8AQAAEAQaXwAAAASBxhcAAABB\noPEFAABAEGh8AQAAEIRacRcAyTl3pqQHJJ0i6RtJMySNjqKoJNbCgBQ45w6R9KmkoiiKWsVcDnBQ\nzrk1klp6/ujRKIpuzG41QGqcc40k3S2pj6SmkgokTZf0G3qF8qHxjZlzrq2kP0u6T9IASacqsZi3\nSZoQY2lAqu6RdKSkr+IuBEjRg0o8bNjfjjgKAVI0U1IrSQMl/UvSxZIelrRLifWMFNH4xu8OSa9H\nUXTPvv//X865rZK+jbEmICXOuc6Shkp6XtKPYy4HSFVhFEXr4y4CSIVz7lhJp0m6PIqit/bFU5xz\nvST1FY1vudD4xsg5V0NSD0lD9s+jKFoQT0VA6pxzNSX9r6TfSSoVjS8AZFwURQWSDkvyx8XZrKU6\noPGNVytJDSUVOudmSzpbUqGkh6MomhxnYUAKblRi/f5G0siYawGAIDjnaivx0cgfSeoXczk5h8Y3\nXkfu+zpZ0kNKNBAXS3rQOVc/iqLfxFYZcBDOueaSxkv6SRRFu51zcZcElEdn59wCSScr8dnepyVN\niKJod7xlAQfnnFsq6XRJX0vqF0XRSzGXlHN4nVm8au/7+mwURY9HUbQ8iqJ7lfi85C+dc3kx1gYc\nzMOSXo6iaGHchQDltElSPSU+F3mRpEmSblPiYztAVfczSZ0lTZU00zl3Rcz15Bye+MZr+76vyw7I\nlyjxa4ymkhjAQJXinOupxMdy2sVdC1BeURSdekC0Yt8r+e5xzo2Joug/cdQFpGLf530LJC13zjWQ\n9Ihz7nleaZY6Gt94fSGpRFLjA/L/Ponflt1ygJRcKulwSV/t9xGHGpLynHPFku6OoujuuIoDKuDD\nfV+bSaLxRZXinGspqYukmVEU7T/MtlKJobcm4iFZyvioQ4yiKCqU9K6k/zngj86S9EUURTuzXxVQ\npjFKfDayw37/97gS7/H9738GqhyX8JRz7vgD/ugUSXsl/TOGsoCynCDpGSV+07a/k5R4j++WrFeU\nw3jiG79xkhY450ZKmiWppxKf4bk+1qqAJKIo+lLSl/tnzrmNSuzctjKeqoCUFCjRPMx0zt2qxNPd\nH0v6taRpURRtjrM4IInFkj6Q9Hvn3I2SVks6R9J1kp6IomhPjLXlHJ74xmzfy6gvk3S5pE8k3Szp\n+iiK/hBrYQBQzez7Ldq5SnzMbKakVUr8BuN3SryeD6hyoijaq8RDsbclzVDiIw63KvHgbHiMpeWk\nvNLS0rhrAAAAACodT3wBAAAQBBpfAAAABIHGFwAAAEGg8QUAAEAQsvI6s7y8PCbokLbS0tKsb+HM\n2kUmsHaRq7K9dlm3yISDrVue+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAI\nAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0v\nAAAAglAr7gKqu+7du5ts+PDhJrvwwgtNVlpaarLVq1d77zNr1iyTTZ061WRfffWV93wAAIDqjie+\nAAAACAKNLwAAAIJA4wsAAIAg0PgCAAAgCHm+AaqM3yQvr/JvErPrrrvOm0+cONFk+fn5lV2OJGnx\n4sUmGzBggMnWrVuXjXLSVlpampfte4awdlH5WLvIVdleu7m8bmfMmGGyK6+80mTz58/3nj9nzhyT\nLV261GQFBQUp1bNnzx5vv
nfv3pTOz2UHW7c88QUAAEAQaHwBAAAQBBpfAAAABIHGFwAAAEFg57YK\n6NGjh8keeOAB77G+Qbbly5ebbMSIESb7+OOPU65pyJAhJhs3bpzJRo4cabKbbrop5fsgt9WvX99k\no0aN8h47ZswYk/mGYcePH2+y9u3bm6xXr16plAgAOWnVqlUmKykpMZmvhzhYXlHTp0/35tdee63J\niouLM3rvqownvgAAAAgCjS8AAACCQOMLAACAIND4AgAAIAjs3FaGnj17muz55583mW9oSJLmzZtn\nMt8ubxs2bKhAdf8nL89uUuIbeLvoootM9tOf/jSte2cLu1+lr0WLFib797//7T22U6dOJlu2bJnJ\nfMNtv/jFL0zmnPPeJ921nwtYuzhQ06ZNTdamTRvvsXXq1DFZ//79Tfbss8+aLNnuXe+++25ZJUpi\n57Z0+XqIrl27pnz+qaeeajLfz/G6deua7NBDD/Ve8/zzzzeZb6fXXMbObQAAAAgejS8AAACCQOML\nAACAIND4AgAAIAjs3LafWrXsfx2+3c98g2wrVqzwXtO3Q8qmTZsqUN3B+YYUp02bZrK5c+dm/N7I\nHa1atcr4NYuKikzmG6po27at9/wQhtsQjh/84Acm+9nPfmaywYMHm6xZs2bea6Y6hD5o0KCUjpOk\nmjVrpnwsKu7VV19NKUtX9+7dTTZ//nzvsRdffLHJqttw28HwxBcAAABBoPEFAABAEGh8AQAAEAQa\nXwAAAASBxhcAAABB4K0O+7n66qtN1rFjR5Pt3r3bZAMHDvReszLe4JCOzZs3x10CYnTmmWdm/Jov\nvfSSyXxvQ+ncubP3/JCmiZGbOnTo4M2HDx9usgsuuMBkRx11VMZr8tm+fbvJFi1alJV7I3saN25s\nsrvuustkxcXF3vOTve0hFDzxBQAAQBBofAEAABAEGl8AAAAEgcYXAAAAQWC4bT+/+MUvUjpu2LBh\nJvvwww8zXQ6QFt+WpJdeeqnJSkpKvOcnG4wAKsK3Jbwk1alTx2SFhYWVXY4k/8Dl9OnTTda6dWvv\n+d/73vcyXpPPJ598YrIxY8aYzDe8vGTJkkqpCRXXsGFDb96lSxeT5efnm2z06NEm863lp556ynuf\nv/zlL2VUWL3xxBcAAABBoPEFAABAEGh8AQAAEAQaXwAAAASB4bYK+M9//hN3CUCZmjZtarJTTz3V\nZP/617+8569YsSKl+xQVFZls7969JmvTpk1K10P15NtZSpIuueQSk82ZM8dkY8eOTfleJ598sslu\nv/12k/mGPWvXrm2yvLw8731KS0tTrikVvn9uSfr5z39usl27dmX03khfgwYNTDZhwgST+dadlN4O\nf++9957J7rvvvgpfrzrjiS8AAACCQOMLAACAIND4AgAAIAg0vgAAAAhCkMNtvsEHSTrhhBNMtn37\ndpNFUZTxmoC4rF69Oq3zP//8c5MVFBSYrEOHDmndB7njkEMOMdmVV17pPbZFixYma9euncl8g0PO\nOe81e/ToUVaJ5ZJsuM3Ht3va008/bbIXX3zRZOyyltvOOussk91www1Zubfv+yPZrpyh44kvAAAA\ngkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAghDkWx1q1fL/Y9esWdNkO3fuNBlbFiMXnHfeeSkd\nN3HixLTu4/t+8n0vNWvWzHu+7w0A27ZtS6smxKtx48Ymq1+/vvfYVLf9HT58uMkqYyvh999/32Qz\nZ870Hvvaa6+ZrLCw0GRffvllhetB7ujSpUta52/cuNFkU6dONVmNGvaZ5R133GEy33bJkjR06FCT\nffPNN6mUWC3wxBcAAABBoPEFAABAEGh8AQAAEAQaXwAAAAQhyOG2uB1++OEm69mzp8luvfXWlK+5\nZs0ak7Vq1cpk69evN9kLL7xgsunTp3vvU1RUlHJNiNcPf/hDk23YsMFk77zzTlr38Q2Azp8
/32TD\nhg3znn/ooYeajOG23Ob7ebRp0ybvsb5BuGwZP368yR5++GGTbdmyJRvlIMeNGzfOZP/4xz9MtmPH\nDu/5f/3rX022Z88ek/mGOmfPnm2yt956y3ufadOmmWzIkCEm27p1q/f8XMcTXwAAAASBxhcAAABB\noPEFAABAEGh8AQAAEASG28rgG7zo3LmzyT744APv+W3atDHZwoULTdaiRQuT7dq1y2QfffSR9z6+\nYRJfNmjpU0jmAAAIzUlEQVTQIJNdcMEFJuvatav3Ppdeeqk3R7x8u2JdfPHFJvMNSiQbtEhHdR2K\nQMUlG7RxzlX4mm+//bY3nzNnjsmee+45k/l2qyopKalwPQhbcXGxyebNm5fx+/h2Jly5cqXJrr76\nau/5c+fONdnixYtNNmXKlApUV/XxxBcAAABBoPEFAABAEGh8AQAAEAQaXwAAAAQhyOG2ZLvwfPvt\ntybz7Srly44//njvNRctWmSyY445xmS+wY8bbrjBZJ999pn3Pql6+eWXTeb7oPuJJ56Y1n2QXfXq\n1TNZy5YtTVZQUJCNcrzfS8n4vp+yVSeyZ+TIkd7ct2ulb9jX55xzzkmnJKBa8/19L0l/+tOfTOb7\n/pw5c6bJku3AmEt44gsAAIAg0PgCAAAgCDS+AAAACAKNLwAAAIIQ5HCbb0czSVq3bp3JfIM3l19+\nucnatm3rvaZvkM23c1ufPn1MVhk7avnuPW3aNJNddNFFGb834pefn2+yTp06eY/97rvvTOYbDK1b\nt67JfDsLJTN16lSTnXfeeSYrKipK+ZqoegoLC725b9BmwIABJmvevLnJ1q9f773m7NmzTXbXXXeZ\nLNmgM1CdTZ482WT9+/c32TXXXGOye++9t1Jqyiae+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ\n+AIAACAIeeWZvq7wTfLyKv8mGTBhwgST3X777Wld0/fGhJtvvtlkO3fuTOs+6XjuuedM1q1bN++x\nHTp0MNnatWszXpNPaWlpXlZutJ9cWbtHHnmkyTZu3JjWNYuLi03mm8z3vSnCt4VyefjecjJv3ry0\nrhkn1m75+CbMH3/8cZM1bNjQe77v77WlS5earFevXib75ptvUikxGNleu7m8bnNFnTp1TPbuu++a\nbMWKFSYbNGhQpdSUaQdbtzzxBQAAQBBofAEAABAEGl8AAAAEgcYXAAAAQWC4bT+NGjUy2Ycffmiy\nFi1apHzNW265xWSTJk0qX2GVzLftZ7KhkVNOOcVkURRlvCYfBoSSq1mzpsnGjx9vspEjR2ajnHL5\n4IMPTHbGGWeYbO/evdkop1KwdtPn+7nrGx6WpPPPPz+la37yyScmu+yyy0y2atWqlK5XHYU+3Obb\nKlvyD1v27dvXZLt37854TZVhzJgxJrv22mtNdtJJJ5ls69atlVJTOhhuAwAAQPBofAEAABAEGl8A\nAAAEgcYXAAAAQWC4rQw9evQw2Z/+9CeT1a9f33v+jh07TPbqq6+a7N577zXZypUrUymxXLp3726y\nl19+2WSfffaZ9/x27dplvKZUMSBUPr6BtyZNmpgs2dr1rRXfMJAv8w1AvPnmm977+HbUOuuss7zH\n5irWbuXwDUFK/l3+fLsb+rz//vsmu/HGG73H+gYzq5vQh9tatWrlzf/5z3+a7OmnnzbZr3/9a5Nt\n2LAh7boyzTfcdvfdd5vs+OOPN9maNWsqo6S0MNwGAACA4NH4AgAAIAg0vgAAAAgCjS8AAACCwHBb\nBXTt2tVk999/v/fYk08+OaVr7tq1y2RDhw412dq1a73n+z5c3qVLF5NNnjzZZL4d655//nnvfQYN\nGuTNs4EBodzRqVMnkyUbBGK4rXKEvHYvueQSk82ZM6fC1/P9LJak6dOnV/iauSL04bajjz7am/t2\nLPUNCq9evdpkw4YN817znXfeMVlxcXFZJZZbnz59TPb
AAw+YLD8/32Q/+MEPTPbtt99mprAMYrgN\nAAAAwaPxBQAAQBBofAEAABAEGl8AAAAEgeG2DEm2K9DgwYNN5tvJ5bDDDst4TT6+D8r7do0bN25c\nNsopFwaEcscRRxxhslWrVnmP3bt3r8m+//3vm6wqDlCkirVbOa677jpv/uijj2b0Pk8++aQ39/18\nr25CH25Lpm/fviabNWtWWtf07ejm69Feeuklk/Xu3Tvl+zRu3NhkvkG2e+65x2R33nlnyveJE8Nt\nAAAACB6NLwAAAIJA4wsAAIAg0PgCAAAgCDS+AAAACAJvdYiBb6LSN53smxpt3759yvcpKCgw2eOP\nP26yCRMmpHzNODEZn9t8WxNL0plnnmky3zah69aty3hN2cLaLR/ftvAjR4402dlnn+09P9N/r914\n443efOrUqRm9T1XEWx38atasabJu3bqZbMSIESZLd0v2vDz7P0m6a37atGkmGz16tMk2bdqU1n2y\nhbc6AAAAIHg0vgAAAAgCjS8AAACCQOMLAACAIDDchpzBgFBuGz58uDd/6KGHTHbJJZeYzLdNZ65g\n7Urdu3f35tdcc43JfENCvi1VfUM+UuqDPuPHjzfZsmXLTPbyyy+ndL3qiOG29NSoYZ8vnnbaad5j\nfQPtP/zhD012xhlnmGzPnj0mmz17tvc+kydPNplv3ZeUlHjPzwUMtwEAACB4NL4AAAAIAo0vAAAA\ngkDjCwAAgCAw3IacwYBQbjv99NO9+d/+9jeT/eUvfzHZueeem+mSsia0tTt06FCTJdsh0reTpc/W\nrVtNtmTJEu+xH330kclefPFFk61YscJkuTzQUxkYbkMuYrgNAAAAwaPxBQAAQBBofAEAABAEGl8A\nAAAEgeE25IzQBoRQfYS2dn27TfXo0cN77Pz581O65saNG032+eefl68wlBvDbchFDLcBAAAgeDS+\nAAAACAKNLwAAAIJA4wsAAIAg0PgCAAAgCLzVATkjtMl4VB+sXeQq3uqAXMRbHQAAABA8Gl8AAAAE\ngcYXAAAAQaDxBQAAQBBofAEAABAEGl8AAAAEgcYXAAAAQaDxBQAAQBBofAEAABCErOzcBgAAAMSN\nJ74AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4A\nAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAI\nAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCD8P5Eh05z4VcPCAAAAAElF\nTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plots(x_imgs[:8], titles=preds[:8])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These predictions are pretty good!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Coding the Neural Net ourselves" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Recall that above we used PyTorch's `Sequential` to define a neural network with a linear layer, a non-linear layer (`ReLU`), and then another linear layer." 
] }, { "cell_type": "code", "execution_count": 115, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Our code from above\n", "net = nn.Sequential(\n", "    nn.Linear(28*28, 256),\n", "    nn.ReLU(),\n", "    nn.Linear(256, 10)\n", ").cuda()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It turns out that `Linear` is defined by a matrix multiplication and then an addition. Let's try defining this ourselves. This will allow us to see exactly where matrix multiplication is used (we will dive into how matrix multiplication works in the next section). \n", "\n", "Just as NumPy has `np.matmul` for matrix multiplication (in Python 3.5 and later, the `@` operator does the same thing), PyTorch has `torch.matmul`. \n", "\n", "A PyTorch module class needs two things: a constructor, which declares the parameters, and a `forward` method, which describes how to calculate a prediction using those parameters, that is, how the neural net converts inputs to outputs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In PyTorch, any attribute of an `nn.Module` that has type **Parameter** is registered automatically and returned by `parameters()`, so an optimizer built from `net.parameters()` knows to optimize it." 
] }, { "cell_type": "code", "execution_count": 114, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Random initial weights, scaled down by the size of the first dimension\n", "def get_weights(*dims): return nn.Parameter(torch.randn(*dims)/dims[0])" ] }, { "cell_type": "code", "execution_count": 104, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class SimpleMnist(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.l1_w = get_weights(28*28, 256)  # Layer 1 weights\n", "        self.l1_b = get_weights(256)         # Layer 1 bias\n", "        self.l2_w = get_weights(256, 10)     # Layer 2 weights\n", "        self.l2_b = get_weights(10)          # Layer 2 bias\n", "\n", "    def forward(self, x):\n", "        x = x.view(x.size(0), -1)                   # Flatten each image into a vector\n", "        x = torch.matmul(x, self.l1_w) + self.l1_b  # Linear Layer\n", "        x = x * (x > 0).float()                     # Non-linear Layer (ReLU)\n", "        x = torch.matmul(x, self.l2_w) + self.l2_b  # Linear Layer\n", "        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create our neural net and the optimizer. (We will use the same loss and metrics as above.)" ] }, { "cell_type": "code", "execution_count": 102, "metadata": { "collapsed": true }, "outputs": [], "source": [ "net2 = SimpleMnist().cuda()\n", "opt = optim.Adam(net2.parameters())" ] }, { "cell_type": "code", "execution_count": 103, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7de8c8622d2f4d438dd8543bd24c7ff4" } }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2457a8edc94b429eadcbcb0928e1b748" } }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 
0.1635 0.1474 0.9559]\n", "\n" ] } ], "source": [ "fit(net2, md, n_epochs=1, crit=loss, opt=opt, metrics=metrics)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can check our predictions:" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAr4AAAF0CAYAAADFHDo6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl0ldW9//FPGFJGRVQQUUDBbhdUAcGpUusMCBekYguK\nlUnFoVa0ViYVRIu2KqAo3paKOBYQwQEHitAqsmq1oCxUHtGWkiqTIEIAISH5/XHouvz47kNOck7O\nk5P9fq11V3o/PMPXdid8fXK+z84rLS0VAAAAUN3ViLsAAAAAIBtofAEAABAEGl8AAAAEgcYXAAAA\nQaDxBQAAQBBofAEAABCEWnEXEDLnXCNJd0vqI6mppAJJ0yX9JoqikjhrAw7GOZcvaYSkyyW1krRJ\n0jRJ90VRtDvG0oAyOefOlPSApFMkfSNphqTR/NxFVeacqyXpLklXKdEzfCppZBRFr8daWI7hiW+8\nZkrqKmmgpBMlTVSiER4eY01AKu6TdIukUZLaSvqlEut2QpxFAWVxzrWV9GdJryuxdm+WdJOk2+Os\nC0jBQ5JulTROibX7hqSXnHMdY60qx+SxgUU8nHPHSloh6fL9/23NObdAUsMois6MrTigDM65TZKe\njaLo5v2yiUqs56bxVQYcnHPueUm1oii6bL/sIknfRlH0XnyVAck55+pK2ippYhRFI/bLl0gqiKKo\nf2zF5Rg+6hCTKIoKJB2W5I+Ls1kLUAGlsut0974cqJKcczUk9ZA0ZP88iqIF8VQEpKyNpHxJ7xyQ\nvyLpV9kvJ3fR+FYRzrnakgZI+pGkfjGXA5TlMUnDnHMzJX2gxK/dBkj631irAg6ulaSGkgqdc7Ml\nnS2pUNLDURRNjrMwoAw193098IHDJklHOOcOiaJoW5Zrykl8xrcKcM4tlfSdEp+b7BdF0UsxlwSU\nZZykOZL+rsST3pWSXpU0NsaagLIcue/rZElvSeom6QlJDzrnRsVWFVC2zyXtldTpgLz9vq8Ns1tO\n7qLxrRp+JqmzpKmSZjrnroi5HqAstyuxbgdKOlXSlUq8neTuGGsCylJ739dnoyh6PIqi5VEU3Svp\neUm/dM7lxVgbkFQURYWSnpF0i3PuLOdcTedcb0k/3XdIUXzV5RY+6lAF7Pu8b4Gk5c65BpIecc49\nz6t1UBU55xpr39tHoiiasS/+yDlXR9JU59zkKIq+jq9CIKnt+74uOyBfosRHdZpKWp/VioDU3STp\nECXWa4mkdyXdqcRHz76Jsa6cwhPfmDjnWjrnrtj3Xr79rVRi6K1JDGUBqWitxJOzVQfknyvxL9PH\nZb0iIDVfKNEwND4g/+/fhXxGElVWFEXboij6iaQjJDWNouhsSfUkrYqiiCe+KaLxjc8JSvza4uwD\n8pMk7ZK0JesVAan5z76v3z8gP/GAPweqlH2/Ln5X0v8c8EdnSfoiiqKd2a8KSI1zro9z7rQoijbv\n91u1fpKYCyoHPuoQn8VKTMP/3jl3o6TVks6RdJ2kJ6Io2hNjbUBSURStc87NkXSnc26dEu+jbivp\nDkkLoihaF2uBwMGNk7TAOTdS0ixJPZX4vPr1sVYFlO1KSac4534u6UslNl9ppcSwJlLEE9+YRFG0\nV4kfuG8rsV3mSv3fjizs3IaqbqASv7F4TNJnkn4vaZ6ky
w5yDhC7KIreUmKdXi7pEyWah+ujKPpD\nrIUBZRsiaamkuUo8cDhB0jlRFG2Ktaocw85tAAAACAJPfAEAABAEGl8AAAAEgcYXAAAAQaDxBQAA\nQBCy8jqzvLw8JuiQttLS0qxvJ8raRSawdpGrsr12WbfIhIOtW574AgAAIAg0vgAAAAgCjS8AAACC\nQOMLAACAIND4AgAAIAg0vgAAAAgCjS8AAACCQOMLAACAIND4AgAAIAhZ2bkNAICqpkGDBiYbMmSI\nyXr37u09v1evXiYrLCxMvzAAlYYnvgAAAAgCjS8AAACCQOMLAACAIND4AgAAIAg0vgAAAAgCb3UA\nAATpqquuMtnEiRNTPr9du3Yme++999KqCUDl4okvAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAsNt\nFdC+fXuTDR8+3Hts69atTVavXj2TjRo1ymSHHnqoyV5//XXvfbZv3+7NAQDSwIEDTTZp0iSTFRUV\nmeyBBx7wXnPZsmVp1wUgu3jiCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgpBXWlpa+TfJy6v8m1SS\nBg0amGzt2rUma9SoUTbK0ZdffunNfcN1L7zwQmWXk1WlpaV52b5nLq9dH9867dOnj/fYjh07mqxL\nly4m832PbNmyxWRHHXWU9z7r16832ZNPPmmyP/zhDybbu3ev95pVDWs3u3r16mWyuXPnmmznzp0m\nu/POO01Wnt3cqptsr92Q1y0y52Drlie+AAAACAKNLwAAAIJA4wsAAIAg0PgCAAAgCAy3laFhw4Ym\ne+2110y2efNm7/nLly83mW9oqGXLliY79thjTVa3bl3vfTZs2GCyM888M6XjcgUDQuVzzDHHmGze\nvHkm863HZLZt22Yy3xqvXbu2yXzfS5LUpEkTkzVt2tRkl19+ucnefvttk61bt857nzixditHfn6+\nN58+fbrJ+vfvb7JFixaZ7IILLki/sGqE4TbkIobbAAAAEDwaXwAAAASBxhcAAABBoPEFAABAEBhu\nq8KOOOIIk912223eY335oEGDTDZjxoz0C4sJA0Lls2zZMpO1b9/eZAsXLvSef+utt5rs66+/Nplv\n57XyOPLII032+uuvm8w5Z7IRI0aY7NFHH02rnsrA2q0co0eP9ubjx4832TPPPGOywYMHm6y4uDj9\nwqoRhtvS06xZM5Ndf/313mN9eVFRkcl8u8fee++9JvP9HSBJBQUF3rw6YbgNAAAAwaPxBQAAQBBo\nfAEAABAEGl8AAAAEgcYXAAAAQeCtDjmmV69e3ty3Fe3DDz9ssptvvjnjNWULk/HJ+SaHv/zyS5PN\nmjXLZFdccYX3mnv37k2/sAp69tlnTdavXz+TderUyWQffvhhpdSUDtZu+jp37myyJUuWeI9ds2aN\nydq1a2eyONd4ruCtDqk7/vjjTTZ16lSTXXjhhdkoR7t37/bmZ511lsmSvQEiV/FWBwAAAASPxhcA\nAABBoPEFAABAEGh8AQAAEIRacReA5A477DCTjRo1KuXzjz766EyWgyqsQ4cOJsvLs5/t/+qrr0wW\n94DPGWecYbL+/fubbPHixSbz/XNXxeE2lE+NGvaZjG976vz8fO/5r7zyisniXueoXpo3b26ylStX\nmqxWLdtmTZw40XvNRx55JKX7nHjiiSb73e9+Z7JGjRp57+Mbcvb9HPZtUV8d8MQXAAAAQaDxBQAA\nQBBofAEAABAEGl8AAAAEgZ3bqoj27dubbPbs2SZr06aN9/zPPvvMZL7dYQoKCipQXdXA7lflU1JS\nYrKNGzea7LTTTvOev3bt2ozW07BhQ2++dOlSk61evdpkvh3mfDslffzxxxWornKxdssn1Z0Ik7np\npptMNmXKlLRqChU7t/lNnjzZZMOGDTPZ1VdfbbKnnnoq4/XccMMNJps0aZL32Jo1a5ps1apVJvMN\nvG3btq0C1WUfO7cBA
AAgeDS+AAAACAKNLwAAAIJA4wsAAIAgMNwWg6uuuspkd999t8mOPfZYk+3a\ntct7zZ49e5rMt9NVLmNAqHzGjh1rsjvuuMNkURR5z+/atavJ0hmOXLBggTf/8Y9/bLJOnTqZzLcr\nUq5g7ZbPoEGDTPbHP/7RZAsXLvSe3717d5Oxc1vFhD7cdsghh3hz3wDu9OnTTebbcTBbkv1sP+GE\nE1I637fD3K233ppWTdnCcBsAAACCR+MLAACAIND4AgAAIAg0vgAAAAgCjS8AAACCwFsdMqRBgwbe\n/Fe/+pXJxowZY7IaNey/g2zZssVkXbp08d7Ht91gdcNkfPnUqVPHZDNmzDBZ3759ved//vnnJjvn\nnHNMtm7dOpM99thjJrvmmmu897nttttM5psmzmWs3eRq1aplsk8//dRkLVu2NNlxxx3nvWZ5tjfG\nwYX+VodkW7r/7W9/M9mFF15osrfeeivjNaWqT58+3vzFF180ma8X3Lp1q8l8b4TYvHlzBaqrXLzV\nAQAAAMGj8QUAAEAQaHwBAAAQBBpfAAAABMFOFaBCnnzySW/+k5/8JKXzX3jhBZNNmjTJZCEMsSEz\nvvvuO5MNHTrUZE2aNPGe79tK+K9//avJZs+ebbIBAwaYbM6cOd77VLdBNpSPb7iydevWJrvuuutM\nFvcQW7du3UzWq1cvk73xxhsm823h7fueRbw6duyY8rHLly+vxErK77XXXvPmvsFl3/ecbz3u2LEj\n/cJixhNfAAAABIHGFwAAAEGg8QUAAEAQaHwBAAAQBIbbMsT3wfDymDp1qsmWLl2a1jWBA23fvt1k\nvXv39h47duxYk918880mGzFiREr3fuSRR1I6DmFp0aJFSsfl5+dXciXJDRw40Jv7dij07Zg4bNgw\nk/l2xZo3b573PoMHDy6jQlSWJUuWePOSkhKT/fnPfzZZz549Tebb7bIyOOe8uW+Ndu3a1WT16tUz\nWXUYwOSJLwAAAIJA4wsAAIAg0PgCAAAgCDS+AAAACALDbRni24VHktq3b1/h830Db/fdd5/3/K++\n+iql+wAH2rZtmze/8847TXbhhRearG3btind54ILLvDmyYZHEIY2bdqkdFy2dq1s1KiRyR566CHv\nsb4hoeLiYpP5hp66dOliMt+OhxLDbXH6+OOPvfmrr75qMt+g8Keffmoy305+kn93y0WLFpmsefPm\nJvMNsvl2f5WkZs2amcy3bl966SXv+bmOJ74AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAIeaWlpZV/\nk7y8yr9JzOrWrevNn3nmGZN16tTJZKnuXrR+/XpvPmjQIJO9+eabKV0zV5SWluZl+54hrN1kunfv\nbrK5c+earHbt2ildb8+ePd78+uuvN9n06dNTumauYO0mN3/+fJN17NjRZEcffXQ2yvHuTphsuM33\n833y5MkmW7t2rcl8A04nnXSS9z5x7lqX7bWbK+vW93f+hAkTTHbTTTeldZ8tW7aYrHHjxmld0+ey\nyy4zmW/YLlccbN3yxBcAAABBoPEFAABAEGh8AQAAEAQaXwAAAASBxhcAAABBYMviDNm1a5c3v+KK\nK0xWq5b9rz3ZtrEHOuqoo7y5b9r+lltuMdnjjz+e0n2Ac88912S+t8D06dPHZL5JZN8Wn5J/a+6v\nv/7aZK+88or3fOS2008/3WTJ3gBS1fi2ij/mmGNM9vvf/95kp5xyismq25t4qjPf3/m+N4LMmjXL\nZL6+IJmmTZumdFxRUZHJfN9bknTccceZbOfOnSnXlOt44gsAAIAg0PgCAAAgCDS+AAAACAKNLwAA\nAILAlsVVxMknn2yyiRMnmsw3cJSMb5vMVq1alauuqoRtXyuHb+1J0vvvv28y3yCab6DDx7clpiT9\n8Y9/NFlenv2ful27dibzrfGqiLWbnG/wq2fPniarjC2LfevMt54ffPDBtO7j+3v2scc
eM9moUaO8\n52/fvj2t+6eDLYtz19NPP+3NfcN13bp1M9mCBQsyXlO2sGUxAAAAgkfjCwAAgCDQ+AIAACAINL4A\nAAAIAju3laFevXomq4wdTlasWGGyvn37muyJJ57wnt+7d2+TtWjRwmTNmjUz2bp161IpEdVUw4YN\nvblvh8EXXnihwveZPXu2N2/ZsqXJ7r//fpN16tTJZLky3IbyadSokcl8gzrPPPOM93zf2u3Xr5/J\nGjdubLLu3bunUqIkaceOHSZbsmSJyX7729+abPHixSnfB6hsrVu3jruErOGJLwAAAIJA4wsAAIAg\n0PgCAAAgCDS+AAAACALDbfvxfbjbN6gwf/58k61cudJ7Td/g2JAhQ0xWu3ZtkzVv3txkbdq08d7H\n54svvkipHoStQ4cO3nz9+vUm830/pGvKlCkmu/rqq012ww03mGzu3LkZrwfZtXz5cpMNHTrUZL7d\npnxZurZt22ayZIOZ99xzj8n+/e9/Z7wmoCIKCwvjLqFK4okvAAAAgkDjCwAAgCDQ+AIAACAINL4A\nAAAIAsNt+7nssstMdtRRR5ls8ODBGb93Xl6eyUpLS1M+3/ch9mHDhqVVE8Lg2+FPkv7+979n5f57\n9uwx2TfffGOyH/3oRybz7by1ZcuWzBSGrHjuuedM5tu1cvXq1SarWbOm95rJ8gM9++yzJluzZo3J\nfIPCQFX39ttve/Nrr73WZE2aNKnscqoMnvgCAAAgCDS+AAAACAKNLwAAAIJA4wsAAIAg0PgCAAAg\nCLzVYT+HH3543CX8f+bMmWOy8ePHe4/duHGjyXxbzgIHSvb2kC5dupisX79+Jlu0aJHJGjRoYLL8\n/HzvfU488USTnXrqqSZ79NFHTcYbHHLft99+a7Lzzz8/hkqA6qVGDf+zTd9bpHw9RHXFE18AAAAE\ngcYXAAAAQaDxBQAAQBBofAEAABAEhtv2M2rUKJMtXLjQZAMGDDDZ0Ucf7b2mb3DD55FHHjHZO++8\nY7Li4uKUrgek6tNPP/Xmvu2AfdvLbt682WTlGW7zDVq8++67Jhs7dqz3fACAVVJS4s2TDTSHgie+\nAAAACAKNLwAAAIJA4wsAAIAg0PgCAAAgCAy37aeoqMhkb775ZkoZkKveeOMNbz5lyhST+XZz69Ch\nQ1r3Hz16tMmeeOIJk7FLGwBUjosuushkU6dOjaGSyscTXwAAAASBxhcAAABBoPEFAABAEGh8AQAA\nEASG24DAbdiwwZv/8pe/zHIlAIBMKSwsTPnYWrXCaQd54gsAAIAg0PgCAAAgCDS+AAAACAKNLwAA\nAIJA4wsAAIAg5JWWllb+TfLyKv8mqPZKS0vzsn1P1i4ygbWLXJXttcu6zZxGjRp5c9/277t27TJZ\n/fr1M15Tthxs3fLEFwAAAEGg8QUAAEAQaHwBAAAQBBpfAAAABIHhNuQMBoSQq1i7yFUMtyEXMdwG\nAACA4NH4AgAAIAg0vgAAAAgCjS8AAACCkJXhNgAAACBuPPEFAABAEGh8AQAAEAQaXwAAAASBxhcA\nAABBoPEFAABAEGh8AQAAEAQaXwAAAASBxhcAAABBoPEFAABAEGh8AQAAEAQaXwAAAASBxhcAAABB\noPEFAABAEGh8AQAAEIRacRcAyTl3pqQHJJ0i6RtJMySNjqKoJNbCgBQ45w6R9KmkoiiKWsVcDnBQ\nzrk1klp6/ujRKIpuzG41QGqcc40k3S2pj6SmkgokTZf0G3qF8qHxjZlzrq2kP0u6T9IASacqsZi3\nSZoQY2lAqu6RdKSkr+IuBEjRg0o8bNjfjjgKAVI0U1IrSQMl/UvSxZIelrRLifWMFNH4xu8OSa9H\nUXTPvv//X865rZK+jbEmICXOuc6Shkp6XtKPYy4HSFVhFEXr4y4CSIVz7lhJp0m6PIqit/bFU5xz\nvST1FY1vudD4xsg5V0NSD0lD9s+jKFoQT0VA6px
zNSX9r6TfSSoVjS8AZFwURQWSDkvyx8XZrKU6\noPGNVytJDSUVOudmSzpbUqGkh6MomhxnYUAKblRi/f5G0siYawGAIDjnaivx0cgfSeoXczk5h8Y3\nXkfu+zpZ0kNKNBAXS3rQOVc/iqLfxFYZcBDOueaSxkv6SRRFu51zcZcElEdn59wCSScr8dnepyVN\niKJod7xlAQfnnFsq6XRJX0vqF0XRSzGXlHN4nVm8au/7+mwURY9HUbQ8iqJ7lfi85C+dc3kx1gYc\nzMOSXo6iaGHchQDltElSPSU+F3mRpEmSblPiYztAVfczSZ0lTZU00zl3Rcz15Bye+MZr+76vyw7I\nlyjxa4ymkhjAQJXinOupxMdy2sVdC1BeURSdekC0Yt8r+e5xzo2Joug/cdQFpGLf530LJC13zjWQ\n9Ihz7nleaZY6Gt94fSGpRFLjA/L/Ponflt1ygJRcKulwSV/t9xGHGpLynHPFku6OoujuuIoDKuDD\nfV+bSaLxRZXinGspqYukmVEU7T/MtlKJobcm4iFZyvioQ4yiKCqU9K6k/zngj86S9EUURTuzXxVQ\npjFKfDayw37/97gS7/H9738GqhyX8JRz7vgD/ugUSXsl/TOGsoCynCDpGSV+07a/k5R4j++WrFeU\nw3jiG79xkhY450ZKmiWppxKf4bk+1qqAJKIo+lLSl/tnzrmNSuzctjKeqoCUFCjRPMx0zt2qxNPd\nH0v6taRpURRtjrM4IInFkj6Q9Hvn3I2SVks6R9J1kp6IomhPjLXlHJ74xmzfy6gvk3S5pE8k3Szp\n+iiK/hBrYQBQzez7Ldq5SnzMbKakVUr8BuN3SryeD6hyoijaq8RDsbclzVDiIw63KvHgbHiMpeWk\nvNLS0rhrAAAAACodT3wBAAAQBBpfAAAABIHGFwAAAEGg8QUAAEAQsvI6s7y8PCbokLbS0tKsb+HM\n2kUmsHaRq7K9dlm3yISDrVue+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAI\nAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0v\nAAAAglAr7gKqu+7du5ts+PDhJrvwwgtNVlpaarLVq1d77zNr1iyTTZ061WRfffWV93wAAIDqjie+\nAAAACAKNLwAAAIJA4wsAAIAg0PgCAAAgCHm+AaqM3yQvr/JvErPrrrvOm0+cONFk+fn5lV2OJGnx\n4sUmGzBggMnWrVuXjXLSVlpampfte4awdlH5WLvIVdleu7m8bmfMmGGyK6+80mTz58/3nj9nzhyT\nLV261GQFBQUp1bNnzx5vvnfv3pTOz2UHW7c88QUAAEAQaHwBAAAQBBpfAAAABIHGFwAAAEFg57YK\n6NGjh8keeOAB77G+Qbbly5ebbMSIESb7+OOPU65pyJAhJhs3bpzJRo4cabKbbrop5fsgt9WvX99k\no0aN8h47ZswYk/mGYcePH2+y9u3bm6xXr16plAgAOWnVqlUmKykpMZmvhzhYXlHTp0/35tdee63J\niouLM3rvqownvgAAAAgCjS8AAACCQOMLAACAIND4AgAAIAjs3FaGnj17muz55583mW9oSJLmzZtn\nMt8ubxs2bKhAdf8nL89uUuIbeLvoootM9tOf/jSte2cLu1+lr0WLFib797//7T22U6dOJlu2bJnJ\nfMNtv/jFL0zmnPPeJ921nwtYuzhQ06ZNTdamTRvvsXXq1DFZ//79Tfbss8+aLNnuXe+++25ZJUpi\n57Z0+XqIrl27pnz+qaeeajLfz/G6deua7NBDD/Ve8/zzzzeZb6fXXMbObQAAAAgejS8AAACCQOML\nAACAIND4AgAAIAjs3LafWrXsfx2+3c98g2wrVqzwXtO3Q8qmTZsqUN3B+YYUp02bZrK5c+dm/N7I\nHa1atcr4NYu
KikzmG6po27at9/wQhtsQjh/84Acm+9nPfmaywYMHm6xZs2bea6Y6hD5o0KCUjpOk\nmjVrpnwsKu7VV19NKUtX9+7dTTZ//nzvsRdffLHJqttw28HwxBcAAABBoPEFAABAEGh8AQAAEAQa\nXwAAAASBxhcAAABB4K0O+7n66qtN1rFjR5Pt3r3bZAMHDvReszLe4JCOzZs3x10CYnTmmWdm/Jov\nvfSSyXxvQ+ncubP3/JCmiZGbOnTo4M2HDx9usgsuuMBkRx11VMZr8tm+fbvJFi1alJV7I3saN25s\nsrvuustkxcXF3vOTve0hFDzxBQAAQBBofAEAABAEGl8AAAAEgcYXAAAAQWC4bT+/+MUvUjpu2LBh\nJvvwww8zXQ6QFt+WpJdeeqnJSkpKvOcnG4wAKsK3Jbwk1alTx2SFhYWVXY4k/8Dl9OnTTda6dWvv\n+d/73vcyXpPPJ598YrIxY8aYzDe8vGTJkkqpCRXXsGFDb96lSxeT5efnm2z06NEm863lp556ynuf\nv/zlL2VUWL3xxBcAAABBoPEFAABAEGh8AQAAEAQaXwAAAASB4bYK+M9//hN3CUCZmjZtarJTTz3V\nZP/617+8569YsSKl+xQVFZls7969JmvTpk1K10P15NtZSpIuueQSk82ZM8dkY8eOTfleJ598sslu\nv/12k/mGPWvXrm2yvLw8731KS0tTrikVvn9uSfr5z39usl27dmX03khfgwYNTDZhwgST+dadlN4O\nf++9957J7rvvvgpfrzrjiS8AAACCQOMLAACAIND4AgAAIAg0vgAAAAhCkMNtvsEHSTrhhBNMtn37\ndpNFUZTxmoC4rF69Oq3zP//8c5MVFBSYrEOHDmndB7njkEMOMdmVV17pPbZFixYma9euncl8g0PO\nOe81e/ToUVaJ5ZJsuM3Ht3va008/bbIXX3zRZOyyltvOOussk91www1Zubfv+yPZrpyh44kvAAAA\ngkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAghDkWx1q1fL/Y9esWdNkO3fuNBlbFiMXnHfeeSkd\nN3HixLTu4/t+8n0vNWvWzHu+7w0A27ZtS6smxKtx48Ymq1+/vvfYVLf9HT58uMkqYyvh999/32Qz\nZ870Hvvaa6+ZrLCw0GRffvllhetB7ujSpUta52/cuNFkU6dONVmNGvaZ5R133GEy33bJkjR06FCT\nffPNN6mUWC3wxBcAAABBoPEFAABAEGh8AQAAEAQaXwAAAAQhyOG2uB1++OEm69mzp8luvfXWlK+5\nZs0ak7Vq1cpk69evN9kLL7xgsunTp3vvU1RUlHJNiNcPf/hDk23YsMFk77zzTlr38Q2Azp8/32TD\nhg3znn/ooYeajOG23Ob7ebRp0ybvsb5BuGwZP368yR5++GGTbdmyJRvlIMeNGzfOZP/4xz9MtmPH\nDu/5f/3rX022Z88ek/mGOmfPnm2yt956y3ufadOmmWzIkCEm27p1q/f8XMcTXwAAAASBxhcAAABB\noPEFAABAEGh8AQAAEASG28rgG7zo3LmzyT744APv+W3atDHZwoULTdaiRQuT7dq1y2QfffSR9z6+\nYRJfNmjpU0jmAAAIzUlEQVTQIJNdcMEFJuvatav3Ppdeeqk3R7x8u2JdfPHFJvMNSiQbtEhHdR2K\nQMUlG7RxzlX4mm+//bY3nzNnjsmee+45k/l2qyopKalwPQhbcXGxyebNm5fx+/h2Jly5cqXJrr76\nau/5c+fONdnixYtNNmXKlApUV/XxxBcAAABBoPEFAABAEGh8AQAAEAQaXwAAAAQhyOG2ZLvwfPvt\ntybz7Srly44//njvNRctWmSyY445xmS+wY8bbrjBZJ999pn3Pql6+eWXTeb7oPuJJ56Y1n2QXfXq\n1TNZy5YtTVZQUJCNcrzfS8n4vp+yVSeyZ+TIkd7ct2ulb9jX55xzzkmnJKBa8
/19L0l/+tOfTOb7\n/pw5c6bJku3AmEt44gsAAIAg0PgCAAAgCDS+AAAACAKNLwAAAIIQ5HCbb0czSVq3bp3JfIM3l19+\nucnatm3rvaZvkM23c1ufPn1MVhk7avnuPW3aNJNddNFFGb834pefn2+yTp06eY/97rvvTOYbDK1b\nt67JfDsLJTN16lSTnXfeeSYrKipK+ZqoegoLC725b9BmwIABJmvevLnJ1q9f773m7NmzTXbXXXeZ\nLNmgM1CdTZ482WT9+/c32TXXXGOye++9t1Jqyiae+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ\n+AIAACAIeeWZvq7wTfLyKv8mGTBhwgST3X777Wld0/fGhJtvvtlkO3fuTOs+6XjuuedM1q1bN++x\nHTp0MNnatWszXpNPaWlpXlZutJ9cWbtHHnmkyTZu3JjWNYuLi03mm8z3vSnCt4VyefjecjJv3ry0\nrhkn1m75+CbMH3/8cZM1bNjQe77v77WlS5earFevXib75ptvUikxGNleu7m8bnNFnTp1TPbuu++a\nbMWKFSYbNGhQpdSUaQdbtzzxBQAAQBBofAEAABAEGl8AAAAEgcYXAAAAQWC4bT+NGjUy2Ycffmiy\nFi1apHzNW265xWSTJk0qX2GVzLftZ7KhkVNOOcVkURRlvCYfBoSSq1mzpsnGjx9vspEjR2ajnHL5\n4IMPTHbGGWeYbO/evdkop1KwdtPn+7nrGx6WpPPPPz+la37yyScmu+yyy0y2atWqlK5XHYU+3Obb\nKlvyD1v27dvXZLt37854TZVhzJgxJrv22mtNdtJJJ5ls69atlVJTOhhuAwAAQPBofAEAABAEGl8A\nAAAEgcYXAAAAQWC4rQw9evQw2Z/+9CeT1a9f33v+jh07TPbqq6+a7N577zXZypUrUymxXLp3726y\nl19+2WSfffaZ9/x27dplvKZUMSBUPr6BtyZNmpgs2dr1rRXfMJAv8w1AvPnmm977+HbUOuuss7zH\n5irWbuXwDUFK/l3+fLsb+rz//vsmu/HGG73H+gYzq5vQh9tatWrlzf/5z3+a7OmnnzbZr3/9a5Nt\n2LAh7boyzTfcdvfdd5vs+OOPN9maNWsqo6S0MNwGAACA4NH4AgAAIAg0vgAAAAgCjS8AAACCwHBb\nBXTt2tVk999/v/fYk08+OaVr7tq1y2RDhw412dq1a73n+z5c3qVLF5NNnjzZZL4d655//nnvfQYN\nGuTNs4EBodzRqVMnkyUbBGK4rXKEvHYvueQSk82ZM6fC1/P9LJak6dOnV/iauSL04bajjz7am/t2\nLPUNCq9evdpkw4YN817znXfeMVlxcXFZJZZbnz59TPbAAw+YLD8/32Q/+MEPTPbtt99mprAMYrgN\nAAAAwaPxBQAAQBBofAEAABAEGl8AAAAEgeG2DEm2K9DgwYNN5tvJ5bDDDst4TT6+D8r7do0bN25c\nNsopFwaEcscRRxxhslWrVnmP3bt3r8m+//3vm6wqDlCkirVbOa677jpv/uijj2b0Pk8++aQ39/18\nr25CH25Lpm/fviabNWtWWtf07ejm69Feeuklk/Xu3Tvl+zRu3NhkvkG2e+65x2R33nlnyveJE8Nt\nAAAACB6NLwAAAIJA4wsAAIAg0PgCAAAgCDS+AAAACAJvdYiBb6LSN53smxpt3759yvcpKCgw2eOP\nP26yCRMmpHzNODEZn9t8WxNL0plnnmky3zah69aty3hN2cLaLR/ftvAjR4402dlnn+09P9N/r914\n443efOrUqRm9T1XEWx38atasabJu3bqZbMSIESZLd0v2vDz7P0m6a37atGkmGz16tMk2bdqU1n2y\nhbc6AAAAIHg0vgAAAAgCjS8AAACCQOMLAACAIDDchpzBgFBuGz58uDd/6KGHTHbJJZeYzLdNZ65g\n7Urdu3f35tdcc43JfENCvi1VfUM+UuqDP
uPHjzfZsmXLTPbyyy+ndL3qiOG29NSoYZ8vnnbaad5j\nfQPtP/zhD012xhlnmGzPnj0mmz17tvc+kydPNplv3ZeUlHjPzwUMtwEAACB4NL4AAAAIAo0vAAAA\ngkDjCwAAgCAw3IacwYBQbjv99NO9+d/+9jeT/eUvfzHZueeem+mSsia0tTt06FCTJdsh0reTpc/W\nrVtNtmTJEu+xH330kclefPFFk61YscJkuTzQUxkYbkMuYrgNAAAAwaPxBQAAQBBofAEAABAEGl8A\nAAAEgeE25IzQBoRQfYS2dn27TfXo0cN77Pz581O65saNG032+eefl68wlBvDbchFDLcBAAAgeDS+\nAAAACAKNLwAAAIJA4wsAAIAg0PgCAAAgCLzVATkjtMl4VB+sXeQq3uqAXMRbHQAAABA8Gl8AAAAE\ngcYXAAAAQaDxBQAAQBBofAEAABAEGl8AAAAEgcYXAAAAQaDxBQAAQBBofAEAABCErOzcBgAAAMSN\nJ74AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4A\nAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAI\nAo0vAAAAgkDjCwAAgCDQ+AIAACAINL4AAAAIAo0vAAAAgkDjCwAAgCD8P5Eh05z4VcPCAAAAAElF\nTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "preds = predict(net2, md.val_dl)\n", "preds = np.argmax(preds, axis=1)\n", "plots(x_imgs[:8], titles=preds[:8])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## what torch.matmul (matrix multiplication) is doing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's dig in to what we were doing with `torch.matmul`: matrix multiplication. First, let's start with a simpler building block: **broadcasting**." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Element-wise operations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Broadcasting and element-wise operations are supported in the same way by both Numpy and PyTorch.\n", "\n", "Operators (+,-,\\*,/,>,<,==) are usually element-wise.\n", "\n", "Examples of element-wise operations:" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "collapsed": true }, "outputs": [], "source": [ "a = np.array([10, 6, -4])\n", "b = np.array([2, 8, 7])" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([12, 14,  3])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a + b" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([False,  True,  True], dtype=bool)" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a < b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Broadcasting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The term **broadcasting** describes how arrays with different shapes are treated during arithmetic operations. The term broadcasting was first used by Numpy, although it is now used in other libraries such as [TensorFlow](https://www.tensorflow.org/performance/xla/broadcasting) and MATLAB; the rules can vary by library.\n", "\n", "From the [Numpy Documentation](https://docs.scipy.org/doc/numpy-1.10.0/user/basics.broadcasting.html):\n", "\n", " The term broadcasting describes how numpy treats arrays with \n", " different shapes during arithmetic operations. Subject to certain \n", " constraints, the smaller array is “broadcast” across the larger \n", " array so that they have compatible shapes. Broadcasting provides a \n", " means of vectorizing array operations so that looping occurs in C\n", " instead of Python. 
It does this without making needless copies of \n", " data and usually leads to efficient algorithm implementations.\n", " \n", "In addition to being efficient, broadcasting allows developers to write less code, which typically leads to fewer errors." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*This section was adapted from [Chapter 4](http://nbviewer.jupyter.org/github/fastai/numerical-linear-algebra/blob/master/nbs/4.%20Compressed%20Sensing%20of%20CT%20Scans%20with%20Robust%20Regression.ipynb#4.-Compressed-Sensing-of-CT-Scans-with-Robust-Regression) of the fast.ai [Computational Linear Algebra](https://github.com/fastai/numerical-linear-algebra) course.*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Broadcasting with a scalar" ] }, { "cell_type": "code", "execution_count": 105, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([10,  6, -4])" ] }, "execution_count": 105, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ True,  True, False], dtype=bool)" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a > 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How are we able to compute `a > 0`? The scalar `0` is being **broadcast** to have the same shape as `a`.\n", "\n", "Remember above when we normalized our dataset by subtracting the mean (a scalar) from the entire dataset (a matrix) and dividing by the standard deviation (another scalar)? 
We were using broadcasting!\n", "\n", "Other examples of broadcasting with a scalar:" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([11,  7, -3])" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a + 1" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1, 2, 3],\n", "       [4, 5, 6],\n", "       [7, 8, 9]])" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]); m" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 2,  4,  6],\n", "       [ 8, 10, 12],\n", "       [14, 16, 18]])" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m * 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Broadcasting a vector to a matrix" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also broadcast a vector to a matrix:" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([10, 20, 30])" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c = np.array([10, 20, 30]); c" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[11, 22, 33],\n", "       [14, 25, 36],\n", "       [17, 28, 39]])" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m + c" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Although numpy does this automatically, you can also call the `broadcast_to` function explicitly:" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[10, 20, 30],\n", "       [10, 20, 30],\n", "       [10, 20, 30]])" ] }, "execution_count": 72, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "np.broadcast_to(c, (3,3))" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(3,)" ] }, "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The numpy `expand_dims` function lets us convert the 1-dimensional array `c` into a 2-dimensional array (although one of those dimensions has size 1)." ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1, 3)" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.expand_dims(c,0).shape" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[11, 22, 33],\n", "       [14, 25, 36],\n", "       [17, 28, 39]])" ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m + np.expand_dims(c,0)" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(3, 1)" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.expand_dims(c,1).shape" ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[11, 12, 13],\n", "       [24, 25, 26],\n", "       [37, 38, 39]])" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m + np.expand_dims(c,1)" ] }, { "cell_type": "code", "execution_count": 78, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "array([[10, 10, 10],\n", "       [20, 20, 20],\n", "       [30, 30, 30]])" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.broadcast_to(np.expand_dims(c,1), (3,3))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Broadcasting Rules" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "When operating on two arrays, Numpy/PyTorch compares their shapes element-wise. It starts with the **trailing dimensions** and works its way forward. Two dimensions are **compatible** when\n", "\n", "- they are equal, or\n", "- one of them is 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Arrays do not need to have the same number of dimensions. For example, if you have a $256 \\times 256 \\times 3$ array of RGB values, and you want to scale each color in the image by a different value, you can multiply the image by a one-dimensional array with 3 values. Lining up the sizes of the trailing axes of these arrays according to the broadcast rules shows that they are compatible:\n", "\n", "    Image  (3d array): 256 x 256 x 3\n", "    Scale  (1d array):             3\n", "    Result (3d array): 256 x 256 x 3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The [numpy documentation](https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html#general-broadcasting-rules) includes several examples of which dimensions can and cannot be broadcast together." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Matrix Multiplication" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are going to use broadcasting to define matrix multiplication." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Matrix-Vector Multiplication" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([[1, 2, 3],\n", " [4, 5, 6],\n", " [7, 8, 9]]), array([10, 20, 30]))" ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m, c" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([140, 320, 500])" ] }, "execution_count": 93, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m @ c # np.matmul(m, c)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We get the same answer using `torch.matmul`:" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\n", " 140\n", " 320\n", " 500\n", "[torch.LongTensor of size 3]" ] }, "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.matmul(torch.from_numpy(m), torch.from_numpy(c))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following is **NOT** matrix multiplication. What is it?" 
] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 10,  40,  90],\n", "       [ 40, 100, 180],\n", "       [ 70, 160, 270]])" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m * c" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([140, 320, 500])" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(m * c).sum(axis=1)" ] }, { "cell_type": "code", "execution_count": 115, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([10, 20, 30])" ] }, "execution_count": 115, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c" ] }, { "cell_type": "code", "execution_count": 114, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[10, 20, 30],\n", "       [10, 20, 30],\n", "       [10, 20, 30]])" ] }, "execution_count": 114, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.broadcast_to(c, (3,3))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From a machine learning perspective, matrix multiplication is a way of creating features by saying how much we want to weight each input column. **Different features are different weighted sums of the input columns**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The website [matrixmultiplication.xyz](http://matrixmultiplication.xyz/) provides a nice visualization of matrix multiplication." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Draw a picture" ] }, { "cell_type": "code", "execution_count": 113, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[10, 40],\n", "       [20,  0],\n", "       [30, -5]])" ] }, "execution_count": 113, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n = np.array([[10,40],[20,0],[30,-5]]); n" ] }, { "cell_type": "code", "execution_count": 111, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[140,  25],\n", "       [320, 130],\n", "       [500, 235]])" ] }, "execution_count": 111, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m @ n" ] }, { "cell_type": "code", "execution_count": 118, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([140, 320, 500])" ] }, "execution_count": 118, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(m * n[:,0]).sum(axis=1)" ] }, { "cell_type": "code", "execution_count": 119, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 25, 130, 235])" ] }, "execution_count": 119, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(m * n[:,1]).sum(axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Homework: another use of broadcasting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you want to test your understanding of the above tutorial, I encourage you to work through it again, only this time using **CIFAR 10**, a dataset that consists of 32x32 *color* images in 10 different categories. Color images have an extra dimension, containing RGB values, compared to black & white images.\n", "\n", "[Image: sample images from CIFAR 10]\n", "
\n", "(source: [Cifar 10](https://www.cs.toronto.edu/~kriz/cifar.html))\n", "
\n", "\n", "Fortunately, broadcasting will make it relatively easy to add this extra dimension (for color RGB), but you will have to make some changes to the code." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Other applications of Matrix and Tensor Products" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are some other examples of where matrix multiplication arises. This material is taken from [Chapter 1](http://nbviewer.jupyter.org/github/fastai/numerical-linear-algebra/blob/master/nbs/1.%20Why%20are%20we%20here.ipynb) of my [Computational Linear Algebra](https://github.com/fastai/numerical-linear-algebra) course. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Matrix-Vector Products:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The matrix below gives the probabilities of moving from 1 health state to another in 1 year. If the current health states for a group are:\n", "- 85% asymptomatic\n", "- 10% symptomatic\n", "- 5% AIDS\n", "- 0% death\n", "\n", "what will be the % in each health state in 1 year?" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Image: matrix of health-state transition probabilities] (Source: [Concepts of Markov Chains](https://www.youtube.com/watch?v=0Il-y_WLTo4))" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "#### Answer" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[ 0.765 ],\n", "       [ 0.1525],\n", "       [ 0.0645],\n", "       [ 0.018 ]])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Exercise: Use Numpy to compute the answer to the above\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Matrix-Matrix Products" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Image: matrix-matrix product exercise] (Source: [Several Simple Real-world Applications of Linear Algebra Tools](https://www.mff.cuni.cz/veda/konference/wds/proc/pdf06/WDS06_106_m8_Ulrychova.pdf))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Answer" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 50. ,  49. ],\n", "       [ 58.5,  61. ],\n", "       [ 43.5,  43.5]])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Exercise: Use Numpy to compute the answer to the above\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## End" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A Tensor is a *multi-dimensional matrix containing elements of a single data type*: a group of data, all with the same type (e.g., a Tensor could store a 4 x 4 x 6 array of 32-bit signed integers)." 
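] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the definition concrete, the following sketch builds a 4 x 4 x 6 tensor of 32-bit signed integers like the one just described. The variable name `t` is arbitrary, and the snippet assumes a PyTorch version that accepts the `dtype` keyword in `torch.zeros` (0.4 or later), which may be newer than the version used to run the rest of this notebook." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"# A 4 x 4 x 6 tensor in which every element is a 32-bit signed integer\n",
"t = torch.zeros(4, 4, 6, dtype=torch.int32)\n",
"t.shape, t.dtype"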
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.1" } }, "nbformat": 4, "nbformat_minor": 2 }