{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start by creating a simple linear model for digit recognition." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Accessing the data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll download the famous [MNIST dataset](http://yann.lecun.com/exdb/mnist/) of handwritten digits." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "url = 'https://s3.amazonaws.com/fast-ai-imageclas/mnist_png.tgz'\n", "tgz = 'mnist_png.tgz'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use `urlsave` from [fastcore](https://fastcore.fast.ai/) to download the dataset. However, we don't want to re-download it if we've downloaded it before. We can use Python's `Path` class to make it more convenient to work with the filesystem." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "from fastcore.all import urlsave\n", "\n", "cfghome = Path.home()/'.fastai'\n", "archive = cfghome/'archive'\n", "if not (archive/tgz).exists():\n", " archive.mkdir(exist_ok=True, parents=True)\n", " urlsave(url, archive)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can follow similar logic for decompressing the archive:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from fastcore.all import untar_dir\n", "data = cfghome/'data'\n", "dest = data/'mnist_png'\n", "if not dest.exists():\n", " data.mkdir(exist_ok=True, parents=True)\n", " untar_dir(archive/tgz, dest)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#TODO: use FastDownload when available" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset has separate folders for the training set and the test set (which we will use for validation). We use the `ls` method to quickly see the contents of a folder in a cross-platform way. Setting `Path.BASE_PATH=dest` causes the `dest` component of paths to not be displayed." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#2) [Path('testing'),Path('training')]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Path.BASE_PATH=dest\n", "dest.ls()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The training set is split into folders according to the labels:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#10) [Path('training/0'),Path('training/1'),Path('training/2'),Path('training/3'),Path('training/4'),Path('training/5'),Path('training/6'),Path('training/7'),Path('training/8'),Path('training/9')]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainpath = dest/'training'\n", "testpath = dest/'testing'\n", "trainpath.ls()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With `glob` we can get a list of all of the files in the training and test sets, using a [list comprehension](https://docs.python.org/3/tutorial/datastructures.html#list-comprehensions):" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Path('training/0/1.png')" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from glob import glob\n", "def getfiles(p): return [Path(o) for o in sorted(glob(f'{p}/**/*.png', recursive=True))]\n", "\n", "trainfiles = getfiles(trainpath)\n", "testfiles = getfiles(testpath)\n", "fname = trainfiles[0]\n", "fname" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see from the path that this is the number \"8\", and is in the training set. We can open the image file using [torchvision](https://pytorch.org/vision/stable/index.html):" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from torchvision.io import read_image\n", "img = read_image(str(fname))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This returns a `Tensor` (a multidimensional array). We can see its size:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 28, 28])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This shows that it is a single image of size 28x28. 
We can view it using matplotlib (note that matplotlib expects a 28x28 tensor, so we index into the first dimension to get that):" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOv0lEQVR4nO3df6zV9X3H8deLuysqioFaKKV2VIVa5laot1hnW2xNDbpkaFLbksUy50KTVofVbTVuSU2XLK6xde2K7WilYn9gmqiVNM5KGZmztdQLUkHRYikowmCCm7/xXu57f9yvy1Xv93MO53zPD+7n+Uhuzrnf9/mc7zsHXvd7zvmc7/k4IgRg7BvX6QYAtAdhBzJB2IFMEHYgE4QdyMTvtXNnR3l8HK0J7dwlkJVX9KJejYMerdZU2G0vkPQ1ST2SvhMR16duf7Qm6Eyf28wuASSsj7WltYafxtvukbRM0vmSZktaZHt2o/cHoLWaec0+T9ITEbE9Il6VdJukhdW0BaBqzYR9uqSnRvy+q9j2OraX2O633T+gg03sDkAzmgn7aG8CvOmztxGxPCL6IqKvV+Ob2B2AZjQT9l2SThrx+zsk7W6uHQCt0kzYH5Q00/a7bB8l6VOSVlfTFoCqNTz1FhGDti+X9FMNT72tiIhHKusMQKWammePiLsl3V1RLwBaiI/LApkg7EAmCDuQCcIOZIKwA5kg7EAmCDuQCcIOZIKwA5kg7EAmCDuQCcIOZIKwA5kg7EAmCDuQCcIOZIKwA5kg7EAmCDuQCcIOZIKwA5lo65LNGHsGP3pGsr7ns+VLfv36rJXJse99YHGy/vZlRyXrPes2Juu54cgOZIKwA5kg7EAmCDuQCcIOZIKwA5kg7EAmmGdH0tD8ucn611d8I1k/tbf8v9hQjX0/dNZ3k/XH+w4l638z4wM19pCXpsJue4ek5yUdkjQYEX1VNAWgelUc2T8SEc9UcD8AWojX7EAmmg17SLrX9gbbS0a7ge0ltvtt9w+o/HPSAFqr2afxZ0fEbttTJK2x/VhE3DfyBhGxXNJySZroydHk/gA0qKkje0TsLi73SbpT0rwqmgJQvYbDbnuC7eNfuy7pPElbqmoMQLWaeRo/VdKdtl+7nx9GxD2VdIW2GTgvPVv6tzd9L1mf1Zs+p3woMZu+fWAgOfZ/h8Yn63PTZR08//2ltWPWbU6OHXrllfSdH4EaDntEbJf03gp7AdBCTL0BmSDsQCYIO5AJwg5kgrADmeAU1zGgZ+LE0tqLHz4tOfbzN/4wWf/IMS/U2Hvjx4tbnv3jZH3tTWcl6z+/7uvJ+prvfKu0Nvv7lyfHnvyFB5L1IxFHdiAThB3IBGEHMkHYgUwQdiAThB3IBGEHMsE8+xiw69bppbUH37+sjZ0cni9NeTBZv+e49Dz8pTvOS9ZXzvhZaW3i7P3JsWMRR3YgE4QdyARhBzJB2IFMEHYgE4QdyARhBzLBPPsRYPCjZyTrq+aUL5s8Tumveq7l0p3nJuv9P3tPsr75svLe1r18dHLslP6Xk/Unnk2fq9/7j+tKa+OcHDomcWQHMkHYgUwQdiAThB3IBGEHMkHYgUwQdiATjoi27WyiJ8eZTs/b5mho/txk/Z9X3pSsn9rb+Mcl/vSxi5L1no+/mKwf+JN3J+v7Ty+f0J617Knk2MGndiXrtfzk6Q2ltT2H0nP4f7H4r5L1nnUbG+qp1dbHWj0XB0Z90Gse2W2vsL3P9pYR2ybbXmN7W3E5qcqGAVSvnqfxt0ha8IZt10haGxEzJa0tfgfQxWqGPSLuk3TgDZsXSlpZXF8p6cJq2wJQtUbfoJsaEXskqbicUnZD20ts99vuH9DBBncHoFktfzc+IpZHRF9E9PVqfKt3B6BEo2Hfa3uaJBWX+6prCUArNBr21ZIWF9cXS7qrmnYAtErNCVrbqySdI+lE27skfVHS9ZJ+ZPsySU9KuriVTR7pfMYfJOvPXJWe853Vmz4nfUPirZB/f2F2cuz+205K1t/ybHqd8hO+/8t0PVEbTI5srak96ZeU+698KVmfUn6qfNeqGfaIWFRS4tMxwBGEj8sCmSDsQCYIO5AJwg5kgrADmeCrpCsw7thjk/XBLz+XrP/ytDuS9d8NvpqsX3Xt1aW1Sf/5ZHLslAnpz0MdSlbHrnnTdibrO9rTRqU4sgOZIOxAJgg7kAnCDmSCsAOZIOxAJgg7kAnm2Svw8vz0Kaw/PS39VdC1/OXSzyfrx/+4/DTTTp5Giu7CkR3IBGEHMkHYgUwQdiAThB3IBGEHMkHYgUwwz16BP/qHTcn6uBp/Uy/dmf6i3mN+/KvDbQmSet1TWhuosVJ5j9u3lHm7cGQHMkHYgUwQdiAThB3IBGEHMkHYgUwQdiATzLPX6X8uOau09vdTb0iOHVKNJZfvTS+r/E79IlnH6Aai/FvvhzSUHHvP1vS/yUxtbKinTqp5ZLe9wvY+21tGbLvO9tO2NxU/F7S2TQDNqudp/C2SFoyy/caImFP83F1tWwCqVjPsEXGfpANt6AVACzXzBt3lth8unuZPKruR7SW2+233D+hgE7sD0IxGw/5NSadImiNpj6SvlN0wIpZHRF9E9PVqfIO7A9CshsIeEXsj4lBEDEn6tqR51bYFoGoNhd32tBG/XiRpS9ltAXSHmvPstldJOkfSibZ3SfqipHNsz5EUGl6q+jOta7E7DB5TXjthXHoe/YFX0i9fTr51d3rfyerYVWvd+8duOL3GPWworfzZ9vOTI09b+rtk/Uhct75m2CNi0Sibb25BLwBaiI/LApkg7EAmCDuQCcIOZIKwA5ngFNc22H/ouGR9cPuO9jTSZWpNrT1+/R8m648t/Eay/m8vnVBa273s1OTY458tXwb7SMWRHcgEYQcyQdiBTBB2IBOEHcgEYQcyQdiBTDDP3gZ//fOLk/VZiVMxj3RD8+eW1vZd9XJy7Na+9Dz6uZs/maxPWLC9tHa8xt48ei0c2YFMEHYgE4QdyARhBzJB2IFMEHYgE4QdyATz7PVyeWlcjb+ZX/vgqmR9mWY10lFX2Pml8qWsJen2T3+1tDarN/0V3O/71eJk/e0XPZqs4/U4sgOZIOxAJgg7kAnCDmSCsAOZIOxAJgg7kAnm2esV5aUhDSWHzj9mf7J+5S1nJOunfDd9/73/9Xxpbe/8tybHTv7krmT9ineuTdbPPzZ9Lv7qF6eW1j69eUFy7In/OiFZx+GpeWS3fZLtdba32n7E9tJi+2Tba2xvKy4ntb5dAI2q52n8oKSrI+I9kj4g6XO2Z0u6RtLaiJgpaW3xO4AuVTPsEbEnIjYW15+
XtFXSdEkLJa0sbrZS0oUt6hFABQ7rDTrbMyTNlbRe0tSI2CMN/0GQNKVkzBLb/bb7B3SwyXYBNKrusNs+TtLtkq6MiOfqHRcRyyOiLyL6ejW+kR4BVKCusNvu1XDQfxARdxSb99qeVtSnSdrXmhYBVKHm1JttS7pZ0taIGHm+4mpJiyVdX1ze1ZIOx4CjnX6Yt37sW8n6/R86OlnfdvBtpbVLT9iRHNuspbs/lKzf84s5pbWZS/P7OudOqmee/WxJl0jabHtTse1aDYf8R7Yvk/SkpPSXowPoqJphj4j7Vf7VDedW2w6AVuHjskAmCDuQCcIOZIKwA5kg7EAmHJE4d7NiEz05zvSR+QZ+z6xTSmuzVu1Mjv2ntz3Q1L5rfVV1rVNsUx46mL7vRf+xJFmfdenYXW76SLQ+1uq5ODDq7BlHdiAThB3IBGEHMkHYgUwQdiAThB3IBGEHMsFXSdfp0G9+W1rbdvGM5NjZV1yRrD/6iX9ppKW6nHb3Z5P1d9/0UrI+6yHm0ccKjuxAJgg7kAnCDmSCsAOZIOxAJgg7kAnCDmSC89mBMYTz2QEQdiAXhB3IBGEHMkHYgUwQdiAThB3IRM2w2z7J9jrbW20/Yntpsf0620/b3lT8XND6dgE0qp4vrxiUdHVEbLR9vKQNttcUtRsj4obWtQegKvWsz75H0p7i+vO2t0qa3urGAFTrsF6z254haa6k9cWmy20/bHuF7UklY5bY7rfdP6CDzXULoGF1h932cZJul3RlRDwn6ZuSTpE0R8NH/q+MNi4ilkdEX0T09Wp88x0DaEhdYbfdq+Gg/yAi7pCkiNgbEYciYkjStyXNa12bAJpVz7vxlnSzpK0R8dUR26eNuNlFkrZU3x6AqtTzbvzZki6RtNn2pmLbtZIW2Z4jKSTtkPSZFvQHoCL1vBt/v6TRzo+9u/p2ALQKn6ADMkHYgUwQdiAThB3IBGEHMkHYgUwQdiAThB3IBGEHMkHYgUwQdiAThB3IBGEHMkHYgUy0dclm2/8taeeITSdKeqZtDRyebu2tW/uS6K1RVfb2+xHx1tEKbQ37m3Zu90dEX8caSOjW3rq1L4neGtWu3ngaD2SCsAOZ6HTYl3d4/ynd2lu39iXRW6Pa0ltHX7MDaJ9OH9kBtAlhBzLRkbDbXmD7cdtP2L6mEz2Usb3D9uZiGer+DveywvY+21tGbJtse43tbcXlqGvsdai3rljGO7HMeEcfu04vf9721+y2eyT9RtLHJO2S9KCkRRHxaFsbKWF7h6S+iOj4BzBsf1jSC5JujYjTi21flnQgIq4v/lBOiogvdElv10l6odPLeBerFU0bucy4pAsl/bk6+Ngl+vqE2vC4deLIPk/SExGxPSJelXSbpIUd6KPrRcR9kg68YfNCSSuL6ys1/J+l7Up66woRsSciNhbXn5f02jLjHX3sEn21RSfCPl3SUyN+36XuWu89JN1re4PtJZ1uZhRTI2KPNPyfR9KUDvfzRjWX8W6nNywz3jWPXSPLnzerE2EfbSmpbpr/Ozsi3ifpfEmfK56uoj51LePdLqMsM94VGl3+vFmdCPsuSSeN+P0dknZ3oI9RRcTu4nKfpDvVfUtR731tBd3icl+H+/l/3bSM92jLjKsLHrtOLn/eibA/KGmm7XfZPkrSpySt7kAfb2J7QvHGiWxPkHSeum8p6tWSFhfXF0u6q4O9vE63LONdtsy4OvzYdXz584ho+4+kCzT8jvxvJf1dJ3oo6etkSb8ufh7pdG+SVmn4ad2Ahp8RXSbpLZLWStpWXE7uot6+J2mzpIc1HKxpHertgxp+afiwpE3FzwWdfuwSfbXlcePjskAm+AQdkAnCDmSCsAOZIOxAJgg7kAnCDmSCsAOZ+D/cBlFxmLMWWwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.imshow(img[0]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating variables" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a linear model, we'll need independent variables (the pixel values) and a dependent variable (the labels).\n", "\n", "We need to flatten the rank-3, 1x28x28 shaped tensor into a vector (i.e a length 28 rank-1 tensor). Reminder: *rank* refers to the number of dimensions in a tensor, and is equal to the length of a tensor's shape vector.\n", "\n", "We can use `view(-1)` for this purpose:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([784])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def x_func(p): return read_image(str(p)).view(-1)\n", "\n", "x_func(fname).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our labels need to be numbers or tensors of numbers. For MNIST, we can use `int` to create integers from the parent folder names:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def y_func(p): return int(p.parent.name)\n", " \n", "y_func(fname)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For any index, we can now return a tuple of input and label:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def xy_func(i, is_train=True):\n", " fnames = trainfiles if is_train else testfiles\n", " fname = fnames[i]\n", " return x_func(fname),y_func(fname)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([784]), 0)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x,y = xy_func(0)\n", "x.shape,y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Linear model forward pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating a Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Anything which supports `len` and returns a tuple of numbers (or tensors of numbers) when indexed can be used for defining independent and dependent variables in PyTorch. In addition, PyTorch expects datasets to inherit from the [Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) abstract class, although that doesn't actually provide any functionality.\n", "\n", "We probably don't want to create a list of all the image tensors and labels, because that will take a lot of time to create, and will take up a lot of memory. Instead, we want something that supports `len` and indexing, but that only actually reads the image when we access it.\n", "\n", "We can do that by creating a class which supports special [dunder](https://www.geeksforgeeks.org/dunder-magic-methods-python/) methods. All dunder methods start and end in `__`. Python supports many dunder methods. The ones we'll need are `__len__`, which is used by `len`, and `__getitem__`, which is used when indexing. We'll also need to use `__init__` which is used for initializing a new object. 
Here's our class:" ] }, { "cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import Dataset\n", "\n", "class MnistDataset(Dataset):\n", " def __init__(self, paths): self.paths = paths\n", " def __len__(self): return len(self.paths)\n", " \n", " def __getitem__(self, i):\n", " path = self.paths[i]\n", " return x_func(path), y_func(path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's now create our datasets and test them out:" ] }, { "cell_type": "code", "execution_count": 96, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(60000, 10000)" ] }, "execution_count": 96, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_ds,test_ds = MnistDataset(trainfiles),MnistDataset(testfiles)\n", "len(train_ds),len(test_ds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The length is working correctly. Let's also test the indexer:" ] }, { "cell_type": "code", "execution_count": 98, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([784]), 0)" ] }, "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x,y = train_ds[0]\n", "x.shape,y" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }