{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The watermark extension is already loaded. To reload it, use:\n", " %reload_ext watermark\n", "Author: Sebastian Raschka\n", "\n", "Python implementation: CPython\n", "Python version : 3.8.8\n", "IPython version : 7.30.1\n", "\n", "torch: 1.10.1\n", "numpy: 1.22.2\n", "\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch,numpy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Zoo -- Using PyTorch Dataset Loading Utilities for Custom Datasets (Images from Quickdraw)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook provides an example for how to load an image dataset, stored as individual PNG files, using PyTorch's data loading utilities. For a more in-depth discussion, please see the official\n", "\n", "- [Data Loading and Processing Tutorial](http://pytorch.org/tutorials/beginner/data_loading_tutorial.html)\n", "- [torch.utils.data](http://pytorch.org/docs/master/data.html) API documentation\n", "\n", "In this example, we are using the Quickdraw dataset consisting of handdrawn objects, which is available at https://quickdraw.withgoogle.com. \n", "\n", "To execute the following examples, you need to download the \".npy\" (bitmap files in NumPy). You don't need to download all of the 345 categories but only a subset you are interested in. The groups/subsets can be individually downloaded from https://console.cloud.google.com/storage/browser/quickdraw_dataset/full/numpy_bitmap\n", "\n", "Unfortunately, the Google cloud storage currently does not support selecting and downloading multiple groups at once. Thus, in order to download all groups most coneniently, we need to use their `gsutil` (https://cloud.google.com/storage/docs/gsutil_install) tool. If you want to install that, you can then use \n", "\n", " mkdir quickdraw-npy\n", " gsutil -m cp gs://quickdraw_dataset/full/numpy_bitmap/*.npy quickdraw-npy\n", "\n", "Note that if you download the whole dataset, this will take up 37 Gb of storage space.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import os\n", "\n", "import torch\n", "from torch.utils.data import Dataset\n", "from torch.utils.data import DataLoader\n", "from torchvision import transforms\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "from PIL import Image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After downloading the dataset to a local directory, `quickdraw-npy`, the next step is to select certain groups we are interested in analyzing. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's say we are interested in the following groups defined in the `label_dict` in the next code cell:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "label_dict = {\n", " \"lollipop\": 0,\n", " \"binoculars\": 1,\n", " \"mouse\": 2,\n", " \"basket\": 3,\n", " \"penguin\": 4,\n", " \"washing machine\": 5,\n", " \"canoe\": 6,\n", " \"eyeglasses\": 7,\n", " \"beach\": 8,\n", " \"screwdriver\": 9,\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dictionary values represent the class labels that we could use for a classification task, for example." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Conversion to PNG files" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we are going to convert the groups we are interested in (specified in the dictionary above) to individual PNG files using a helper function (note that this might take a while):" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# load utilities from ../helper.py\n", "import sys\n", "sys.path.insert(0, '..')\n", "from helper import quickdraw_npy_to_imagefile\n", "\n", "quickdraw_npy_to_imagefile(inpath='quickdraw-npy',\n", " outpath='quickdraw-png_set1',\n", " subset=label_dict.keys())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preprocessing into train/valid/test subsets and creating label files" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For convenience, let's create a CSV file mapping file names to class labels. First, let's collect the files and labels." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Num paths: 1515745\n", "Num labels: 1515745\n" ] } ], "source": [ "paths, labels = [], []\n", "\n", "main_dir = 'quickdraw-png_set1/'\n", "\n", "for d in os.listdir(main_dir):\n", " subdir = os.path.join(main_dir, d)\n", " if not os.path.isdir(subdir):\n", " continue\n", " for f in os.listdir(subdir):\n", " path = os.path.join(d, f)\n", " paths.append(path)\n", " labels.append(label_dict[d])\n", " \n", "print('Num paths:', len(paths))\n", "print('Num labels:', len(labels))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we shuffle the dataset and assign 70% of it for training, 10% for validation, and 20% for testing." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from mlxtend.preprocessing import shuffle_arrays_unison\n", "\n", "\n", "paths2, labels2 = shuffle_arrays_unison(arrays=[np.array(paths), np.array(labels)], random_seed=3)\n", "\n", "\n", "cut1 = int(len(paths)*0.7)\n", "cut2 = int(len(paths)*0.8)\n", "\n", "paths_train, labels_train = paths2[:cut1], labels2[:cut1]\n", "paths_valid, labels_valid = paths2[cut1:cut2], labels2[cut1:cut2]\n", "paths_test, labels_test = paths2[cut2:], labels2[cut2:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let us create a CSV file that maps the file paths to the class labels:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Label
Path
penguin/penguin_112865.png4
binoculars/binoculars_040058.png1
eyeglasses/eyeglasses_208525.png7
basket/basket_050889.png3
lollipop/lollipop_037650.png0
\n", "
" ], "text/plain": [ " Label\n", "Path \n", "penguin/penguin_112865.png 4\n", "binoculars/binoculars_040058.png 1\n", "eyeglasses/eyeglasses_208525.png 7\n", "basket/basket_050889.png 3\n", "lollipop/lollipop_037650.png 0" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.DataFrame(\n", " {'Path': paths_train,\n", " 'Label': labels_train,\n", " })\n", "\n", "df = df.set_index('Path')\n", "df.to_csv('quickdraw_png_set1_train.csv')\n", "\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Label
Path
basket/basket_046788.png3
basket/basket_074810.png3
binoculars/binoculars_077567.png1
screwdriver/screwdriver_007042.png9
binoculars/binoculars_037327.png1
\n", "
" ], "text/plain": [ " Label\n", "Path \n", "basket/basket_046788.png 3\n", "basket/basket_074810.png 3\n", "binoculars/binoculars_077567.png 1\n", "screwdriver/screwdriver_007042.png 9\n", "binoculars/binoculars_037327.png 1" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.DataFrame(\n", " {'Path': paths_valid,\n", " 'Label': labels_valid,\n", " })\n", "\n", "df = df.set_index('Path')\n", "df.to_csv('quickdraw_png_set1_valid.csv')\n", "\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Label
Path
eyeglasses/eyeglasses_051135.png7
mouse/mouse_059179.png2
basket/basket_046866.png3
penguin/penguin_043578.png4
screwdriver/screwdriver_081750.png9
\n", "
" ], "text/plain": [ " Label\n", "Path \n", "eyeglasses/eyeglasses_051135.png 7\n", "mouse/mouse_059179.png 2\n", "basket/basket_046866.png 3\n", "penguin/penguin_043578.png 4\n", "screwdriver/screwdriver_081750.png 9" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.DataFrame(\n", " {'Path': paths_test,\n", " 'Label': labels_test,\n", " })\n", "\n", "df = df.set_index('Path')\n", "df.to_csv('quickdraw_png_set1_test.csv')\n", "\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's open one of the images to make sure they look ok:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(28, 28)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQfklEQVR4nO3df4xV5Z3H8c9XyoCiRFkGllCVohDxB0vhBoluRK1bhcQoJNViaFghC/4imhiyin9UQ2JkFWujUKWCpcJSSgRFo7slKDH1D8JVQJmdgIhjSyEwKLEYjTDDd/+YazrinOcM59xf4/N+JZM7c7/3ueeby3w4d+5zznnM3QXg+++0WjcAoDoIOxAJwg5EgrADkSDsQCR+UM2NDRw40IcNG1bNTQJRaWlp0eHDh62rWq6wm9kNkn4tqZek5939sdDjhw0bpmKxmGeTAAIKhUJiLfPbeDPrJWmxpEmSLpY0zcwuzvp8ACorz9/s4yXtcfe97n5M0h8k3VSetgCUW56wD5X0104/7yvd9y1mNtvMimZWbG1tzbE5AHnkCXtXHwJ859hbd1/q7gV3LzQ2NubYHIA88oR9n6RzO/38Q0n787UDoFLyhH2rpBFm9iMza5D0c0kbytMWgHLLPPXm7m1mdo+k/1XH1Ntyd28qW2dVdvTo0WD9rLPOqlInQGXkmmd399clvV6mXgBUEIfLApEg7EAkCDsQCcIORIKwA5Eg7EAkqno+e15tbW2JtSVLlgTHPv/888H6rl27gvW33norsXbFFVcEx6LnSTvuoqkpfEjJxReHTwDt37//KfeUF3t2IBKEHYgEYQciQdiBSBB2IBKEHYhEj5p6e/XVVxNr9957b3DsVVddFayff/75wfq0adMSa9u2bQuOHTBgQLCO6ps6dWqw/vLLLwfraQuiTpgwIVgPTeX27ds3ODYr9uxAJAg7EAnCDkSCsAORIOxAJAg7EAnCDkSiR82z79ixI7HW0NAQHLtx48ZgPe0U1/HjxyfWZs2aFRy7bt26YN2syxV2kdPWrVsTa+vXrw+OveOOO4L1Sy65JFhPO+5j5syZibVVq1YFx2b9fWHPDkSCsAORIOxAJAg7EAnCDkSCsAORIOxAJHrUPPvu3bsTayNGjAiOTZuHv+yyy4L1p556KrGWNif7zDPPBOtz584N1pHN4sWLE2tnn312cOyiRYuC9TPOOCNYb29vD9bvu+++xNrkyZODY6dPnx6sJ8kVdjNrkXRUUrukNncv5Hk+AJVTjj37Ne5+uAzPA6CC+JsdiETesLukP5nZu2Y2u6sHmNlsMyuaWbG1tTXn5gBklTfsV7r7WEmTJN1tZt+5qqO7L3X3grsXGhsbc24OQFa5wu7u+0u3hyStl5R8ahiAmsocdjPrZ2ZnffO9pJ9K2lmuxgCUV55P4wdLWl86t/YHkv7b3f+nLF0lCC2jmzZvmtecOXMSa5s2bQqOnTdvXrB+/fXXB+sjR44M1mN1+HB4EmjNmjWJtTvvvDM4Nm0ePU3a+eyhefaWlpZc206SOezuvlfSv5SxFwAVxNQbEAnCDkSCsAORIOxAJAg7EIkedYpr6BK6X331VRU7+bZnn302WA8tNS1Jy5YtC9YXLlx4yj3FYOXKlcH6119/nVi76667yt1O3WPPDkSCsAORIOxAJAg7EAnCDkSCsAORIOxAJHrUPHto2eRHHnkkOPaLL74I1s8888xMPUnSgAEDgvUbb7wxWF+9enWwzjx71955551gPbSs8oUXXljuduoee3YgEoQdiARhByJB2IFIEHYgEoQdiARhByLRo+bZx4wZk1hra2sLjt2zZ0/m585r9OjRwfratWsrtu3vs2KxGKxPnDixSp30DOzZgUgQdiAShB2IBGEHIkHYgUgQdiAShB2IRI+aZ+/Vq1fmse5exk5QDUeOHAnW05Y2Di2LHKPUPbuZLTezQ2a2s9N9A8xso5l9WLo9p7JtAsirO2/jfyfphpPue0DSJncfIWlT6WcAdSw17O7+tqTPTrr7JkkrSt+vkHRzedsCUG5ZP6Ab7O4HJKl0OyjpgWY228yKZlZsbW3NuDkAeVX803h3X+ruBXcvNDY2VnpzABJkDftBMxsiSaXbQ+VrCUAlZA37BkkzSt/PkPRKedoBUCmp8+xmtlrS1ZIGmtk+Sb+U9JikP5rZLEl/kfSzSjaJODU1NeUan3Ydgdikht3dpyWUflLmXgBUEIfLApEg7EAkCDsQCcIORIKwA5HoUae4Ii7Hjh3LNf70008vUyffD+zZgUgQdiAShB2IBGEHIkHYgUgQdiAShB2IBPPsQAUcP34889jevXuXsZN/YM8ORIKwA5Eg7EAkCDsQCcIORIKwA5Eg7EAkopln/+ijj4L1N998M1jfvXt35m1v27Yt81hJmjNnTq7xldSnT59gfcKECYm1a6+9Njj24MGDmXr6xsCBA3ONz6O5uTnz2JEjR5axk39gzw5EgrADkSDsQCQIOxAJwg5EgrADkSDsQCSimWe/7bbbgvW2trZg/bzzzkusnXZa+P/M/fv3B+tp5y9v3LgxWK+lo0ePButPP/105uceNGhQsJ72uu/cuTPzc/fv3z9YT7Nhw4bMY8eNG5dr20lS9+xmttzMDpnZzk73PWxmfzOz7aWvyRXpDkDZdOdt/O8k3dDF/b9y9zGlr9fL2xaAcksNu7u/LemzKvQCoILyfEB3j5m9X3qbf07Sg8xstpkVzazY2tqaY3MA8sga9t9I
ukDSGEkHJC1KeqC7L3X3grsXGhsbM24OQF6Zwu7uB9293d1PSPqtpPHlbQtAuWUKu5kN6fTjFEnJcxwA6kLqPLuZrZZ0taSBZrZP0i8lXW1mYyS5pBZJZTnh+tFHHw3WFyxYkPm50+YuBw8eHKw3NDQk1j7//PPg2I8//jhYX7NmTbB+yy23BOv1bPPmzYm1SZMmBcceOXIkWD9x4kSwPmXKlMRa6N9TkqZPnx6sT5s2LVhfuHBh5vGhYzrySA27u3fV1bIK9AKggjhcFogEYQciQdiBSBB2IBKEHYiEuXvVNlYoFLxYLCbWZ86cGRz/wgsvJNbGjh2bua9KGz16dLC+fPnyYN3MytlOVbW3tyfWnnzyyeDYdevWBetbtmwJ1qv5u32yfv36Beu7du1KrA0dOjTzdguFgorFYpe/MOzZgUgQdiAShB2IBGEHIkHYgUgQdiAShB2IRF1dSnr48OGZx4bm76WePVfdk/Xq1SuxNm/evODYtHraZc527NgRrIekLbM9f/78YD10eq2Uby49K/bsQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5Eoq7m2UeNGpV57KpVq4L1tEsDo+dJW2Houuuuy/zcaWM//fTTYP3xxx8P1h966KHE2kUXXRQcmxV7diAShB2IBGEHIkHYgUgQdiAShB2IBGEHIlFX8+xTp04N1q+55prE2v333x8cO3LkyGB9/PjxwTrQ2a233hqspy3ZHFrGu2bz7GZ2rpm9ZWbNZtZkZveW7h9gZhvN7MPS7TkV6RBAWXTnbXybpPvdfZSkCZLuNrOLJT0gaZO7j5C0qfQzgDqVGnZ3P+Du75W+PyqpWdJQSTdJWlF62ApJN1eoRwBlcEof0JnZMEk/lrRF0mB3PyB1/IcgaVDCmNlmVjSzYto1wwBUTrfDbmZnSnpJ0n3u/vfujnP3pe5ecPdC2okLACqnW2E3s97qCPoqd/9mac2DZjakVB8i6VBlWgRQDqlTb9ZxDeZlkprdvfMauxskzZD0WOn2lbzNpF3u+bnnnkuspZ2SePnllwfraafAzp07N7HGtF18jh07lmv8aadV/xCX7syzXynpF5I+MLPtpfvmqyPkfzSzWZL+IulnFekQQFmkht3d/ywpaZf7k/K2A6BSOFwWiARhByJB2IFIEHYgEoQdiERdneKaZsSIEYm15ubm4NgnnngiWF+0aFGwvnLlysTapZdeGhw7c+bMYD1tjp8jD+vPa6+9Fqw3NDQE6+PGjStnO93Cnh2IBGEHIkHYgUgQdiAShB2IBGEHIkHYgUiYu1dtY4VCwYvFYtW2dyq+/PLLYP2ll15KrC1fvjw4dvPmzcF62pzs7bffHqwvWLAgscYcfTZvvPFGsJ52bMTEiROD9XXr1gXrWRUKBRWLxS7PUmXPDkSCsAORIOxAJAg7EAnCDkSCsAORIOxAJJhnr4K9e/cG64sXL85V79u3b2LtggsuCI5Nu375qFGjgvW0a+anHUNQSUeOHEmsNTU1Bce++OKLwfrYsWOD9bVr1wbrw4cPD9azYp4dAGEHYkHYgUgQdiAShB2IBGEHIkHYgUikzrOb2bmSfi/pnyWdkLTU3X9tZg9L+g9JraWHznf310PPFes8e15p8/Sha+IfPnw4OPb48ePB+rZt24L1Tz75JFivV3369AnWH3zwwWB9/vz5wXrv3r1PuadyCM2zd2eRiDZJ97v7e2Z2lqR3zWxjqfYrdw+vvgCgLnRnffYDkg6Uvj9qZs2Shla6MQDldUp/s5vZMEk/lrSldNc9Zva+mS03s3MSxsw2s6KZFVtbW7t6CIAq6HbYzexMSS9Jus/d/y7pN5IukDRGHXv+LhdLc/el7l5w9wLXQwNqp1thN7Pe6gj6KndfJ0nuftDd2939hKTfSgqfEQGgplLDbmYmaZmkZnd/stP9Qzo9bIqkneVvD0C5dOfT+Csl/ULSB2a2vXTffEnTzGyMJJfUImlOBfqD0k+HXLJkSZU6+a5Dhw4F6+3t7Ym10Km55RA6vbZfv34V3XY96s6n8X+W1NW8XXBOHUB94Qg6IBKEHYgEYQciQdiBSBB2IBKEHYhEd+bZgUSDBg2qdQvoJvbsQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5EoqpLNptZq6TO1x4eKCl8rePaqdfe6rUvid6yKmdv57t7l9d/q2rYv7Nxs6K7F2rWQEC99lavfUn0llW1euNtPBAJwg5EotZhX1rj7YfUa2/12pdEb1lVpbea/s0OoHpqvWcHUCWEHYhETcJuZjeY2S4z22NmD9SihyRm1mJmH5jZdjOr6frSpTX0DpnZzk73DTCzjWb2Yem2yzX2atTbw2b2t9Jrt93MJteot3PN7C0zazazJjO7t3R/TV+7QF9Ved2q/je7mfWStFvSv0naJ2mrpGnu/n9VbSSBmbVIKrh7zQ/AMLOrJH0h6ffufmnpvv+S9Jm7P1b6j/Icd//POuntYUlf1HoZ79JqRUM6LzMu6WZJ/64avnaBvm5RFV63WuzZx0va4+573f2YpD9IuqkGfdQ9d39b0mcn3X2TpBWl71eo45el6hJ6qwvufsDd3yt9f1TSN8uM1/S1C/RVFbUI+1BJf+308z7V13rvLulPZvaumc2udTNdGOzuB6SOXx5J9XZdqNRlvKvppGXG6+a1y7L8eV61CHtXS0nV0/zfle4+VtIkSXeX3q6ie7q1jHe1dLHMeF3Iuvx5XrUI+z5J53b6+YeS9tegjy65+/7S7SFJ61V/S1Ef/GYF3dJteGXFKqqnZby7WmZcdfDa1XL581qEfaukEWb2IzNrkPRzSRtq0Md3mFm/0gcnMrN+kn6q+luKeoOkGaXvZ0h6pYa9fEu9LOOdtMy4avza1Xz5c3ev+pekyer4RP4jSQ/VooeEvoZL2lH6aqp1b5JWq+Nt3XF1vCOaJemfJG2S9GHpdkAd9faipA8kva+OYA2pUW//qo4/Dd+XtL30NbnWr12gr6q8bhwuC0SCI+iASBB2IBKEHYgEYQciQdiBSBB2IBKEHYjE/wP8VOMwG0xOiAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "main_dir = 'quickdraw-png_set1/'\n", "\n", "img = Image.open(os.path.join(main_dir, df.index[1]))\n", "img = np.asarray(img, dtype=np.uint8)\n", "print(img.shape)\n", "plt.imshow(np.array(img), cmap='binary');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Implementing a Custom Dataset Class" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we implement a custom `Dataset` for reading the images. The `__getitem__` method will\n", "\n", "1. read a single image from disk based on an `index` (more on batching later)\n", "2. perform a custom image transformation (if a `transform` argument is provided in the `__init__` construtor)\n", "3. return a single image and it's corresponding label" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "class QuickdrawDataset(Dataset):\n", " \"\"\"Custom Dataset for loading Quickdraw images\"\"\"\n", "\n", " def __init__(self, txt_path, img_dir, transform=None):\n", " \n", " df = pd.read_csv(txt_path, sep=\",\", index_col=0)\n", " self.img_dir = img_dir\n", " self.txt_path = txt_path\n", " self.img_names = df.index.values\n", " self.y = df['Label'].values\n", " self.transform = transform\n", "\n", " def __getitem__(self, index):\n", " img = Image.open(os.path.join(self.img_dir,\n", " self.img_names[index]))\n", " \n", " if self.transform is not None:\n", " img = self.transform(img)\n", " \n", " label = self.y[index]\n", " return img, label\n", "\n", " def __len__(self):\n", " return self.y.shape[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have created our custom Dataset class, let us add some custom transformations via the `transforms` utilities from `torchvision`, we\n", "\n", "1. normalize the images (here: dividing by 255)\n", "2. converting the image arrays into PyTorch tensors\n", "\n", "Then, we initialize a Dataset instance for the training images using the 'quickdraw_png_set1_train.csv' label file (we omit the test set, but the same concepts apply).\n", "\n", "Finally, we initialize a `DataLoader` that allows us to read from the dataset." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# Note that transforms.ToTensor()\n", "# already divides pixels by 255. 
internally\n", "\n", "custom_transform = transforms.Compose([#transforms.Lambda(lambda x: x/255.),\n", " transforms.ToTensor()])\n", "\n", "train_dataset = QuickdrawDataset(txt_path='quickdraw_png_set1_train.csv',\n", " img_dir='quickdraw-png_set1/',\n", " transform=custom_transform)\n", "\n", "train_loader = DataLoader(dataset=train_dataset,\n", " batch_size=128,\n", " shuffle=True,\n", " num_workers=4) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's it, now we can iterate over an epoch using the train_loader as an iterator and use the features and labels from the training dataset for model training:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Iterating Through the Custom Dataset" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1 | Batch index: 0 | Batch size: 128\n", "Epoch: 2 | Batch index: 0 | Batch size: 128\n" ] } ], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "torch.manual_seed(0)\n", "\n", "num_epochs = 2\n", "for epoch in range(num_epochs):\n", "\n", " for batch_idx, (x, y) in enumerate(train_loader):\n", " \n", " print('Epoch:', epoch+1, end='')\n", " print(' | Batch index:', batch_idx, end='')\n", " print(' | Batch size:', y.size()[0])\n", " \n", " x = x.to(device)\n", " y = y.to(device)\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Just to make sure that the batches are being loaded correctly, let's print out the dimensions of the last batch:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 1, 28, 28])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, each batch consists of 128 images, just as specified. 
However, one thing to keep in mind though is that\n", "PyTorch uses a different image layout (which is more efficient when working with CUDA); here, the image axes are \"num_images x channels x height x width\" (NCHW) instead of \"num_images height x width x channels\" (NHWC):" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To visually check that the images that coming of the data loader are intact, let's swap the axes to NHWC and convert an image from a Torch Tensor to a NumPy array so that we can visualize the image via `imshow`:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([28, 28, 1])" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "one_image = x[0].permute(1, 2, 0)\n", "one_image.shape" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQxklEQVR4nO3df2xVZZoH8O9XnBGpw49CkQpIZcQoAivY4CaSiWTYiRIUiQ4OwRGjWfBXdKLRNe4fozEmqDtjJrhO0hlgmKVqMDMKJsalkjFmEjNSoEuLuOCS2kHQtpFfgqLAs3/0OFboec/lnnPvueX5fpLm3t7nvj2PV749t/c957w0M4jIme+svBsQkfJQ2EWcUNhFnFDYRZxQ2EWcOLucGxsxYoTV1dWVc5MirrS3t6O7u5t91VKFneS1AH4DYACA35vZ0tDz6+rq0NzcnGaTIhJQX18fWyv6bTzJAQD+E8B1ACYCWEByYrE/T0RKK83f7NMBfGhmu8zsKwAvA5ibTVsikrU0YR8N4O+9vt8dPfYdJBeTbCbZ3NXVlWJzIpJGmrD39SHAKcfemlmDmdWbWX1NTU2KzYlIGmnCvhvA2F7fjwGwJ107IlIqacK+EcAEkheR/D6AnwFYl01bIpK1oqfezOwYyfsA/Dd6pt5WmNm2zDoTkUylmmc3szcAvJFRLyJSQjpcVsQJhV3ECYVdxAmFXcQJhV3ECYVdxImyns/en4WuwvvQQw8Fx7744ovB+qBBg4L1oUOHButVVVWxtSVLlgTH3nrrrcG6nDm0ZxdxQmEXcUJhF3FCYRdxQmEXcUJhF3FCU28FamxsjK0999xzwbE333xzsF5dXR2s79u3L1h/++23Y2tr1qwJjtXUmx/as4s4obCLOKGwizihsIs4obCLOKGwizihsIs4oXn2Am3cuDG2ljRP/sorr2TdznfMnj07tpY0Ry9+aM8u4oTCLuKEwi7ihMIu4oTCLuKEwi7ihMIu4oSbefYtW7YE6wcPHgzWW1tbY2uTJk0qqqeshOb5Ozo6ytiJVLJUYSfZDuAQgOMAjplZfRZNiUj2stizzzSz7gx+joiUkP5mF3EibdgNwHqSm0gu7usJJBeTbCbZ3NXVlXJzIlKstGG/2symAbgOwL0kf3TyE8yswczqzay+pqYm5eZEpFipwm5me6LbTgCvApieRVMikr2iw06yiuQPvrkP4CcA2rJqTESylebT+PMBvErym5/zopm9mUlXRdi6dWuwPm3atJJtO3oNYg0fPjxYHzJkSLCetGRzZ2dnbK27OzxRknTd+MGDBwfraXpPGpu07TT1tD972LBhwXrS/7OkfzOlUHTYzWwXgH/KsBcRKSFNvYk4obCLOKGwizihsIs4obCLOHHGnOJ66aWXpho/b968YP21116Lrc2ZMyc4dsKECcH6oUOHgvX9+/cH68ePH4+t7dmzJzj2vffeC9a/+OKLYD3p1OCk+plq4cKFwfrq1avL1Mm3tGcXcUJhF3FCYRdxQmEXcUJhF3FCYRdxQmEXceKMmWc/evRoqvEzZ84M1pcuXRpbGz9+fHDs2Wfn9zIfO3YsWM+zt6TlpNPO4R84cCC2lnRsQ9LPXr58ebC+fv36YD0P2rOLOKGwizihsIs4obCLOKGwizihsIs4obCLOHHGzLMfOXIk1fiBAwcG65dcckmqn5+XPOfRkyRdjjmpnqfDhw8H601NTcF6e3t7bK2urq6IjpJpzy7ihMIu4oTCLuKEwi7ihMIu4oTCLuKEwi7iROVOwp6mDz74INX4/jqPntbnn38erCdds76qqipYDy19PGDAgODYSjZ9+vRU40PX689tnp3kCpKdJNt6PVZNsonkzui2co9+EBEAhb2N/wOAa0967FEAG8xsAoAN0fciUsESw25m7wD47KSH5wJYFd1fBeDGbNsSkawV+wHd+Wa2FwCi25FxTyS5mGQzyeaurq4iNyciaZX803gzazCzejOrr6mpKfXmRCRGsWH/lGQtAES3ndm1JCKlUGzY1wFYFN1fBGBtNu2ISKkkzrOTfAnANQBGkNwN4JcAlgJYQ/JOAB0AflrKJguxZcuWVOOnTJmSUSfZO3HiRLC+dm3879oXXnghOPatt94qqqdCzZo1K7a2cuXK4NgxY8Zk3U5mLrvssmA96fiDzZs3x9bmz59fVE9JEsNuZgtiSj/OuBcRKSEdLivihMIu4oTCLuKEwi7ihMIu4sQZc4pr0tTbRRddFKznednipOWBb7nllmD9zTffjK1dfPHFwbFPPfVUsF5dXR2sJy1N/Prrr8fWxo0bFxyb9N+9bNmyYH348OHBehpJp+decMEFwXpnZ/mPQ9OeXcQJhV3ECYVdxAmFXcQJhV3ECYVdxAmFXcQJN/PsU6dOLVMnp+ro6AjW58yZE6zv2LEjWF+xYkVsbdGiRbE1ADjrrHS/7++6665g/eOPP46tNTQ0BMc+88wzwfquXbuC9XfffTe2RjI4Nq2k4xM+++zkyzqWnvbsIk4o7CJOKOwiTijsIk4o7CJOKOwiTijsIk70q3n2o0ePxtaSlmwu1eV5AcDMgvWZM2cG60nnsydd7nnGjBnBep5Gjx4dW3viiSeC
Y5OWLr7jjjuC9ba2ttja5MmTg2PT0jy7iORGYRdxQmEXcUJhF3FCYRdxQmEXcUJhF3GiX82zt7a2xta+/vrr4Nhp06Zl3c4/tLS0BOtJ5103NjYG65U8j15Ko0aNSjX+2LFjGXVy+pLm2ZOucVAKiXt2kitIdpJs6/XY4yQ/JtkSfc0ubZsiklYhb+P/AODaPh5/zsyuiL7eyLYtEclaYtjN7B0A5T+2T0QyleYDuvtIbo3e5sculEZyMclmks1dXV0pNiciaRQb9t8C+CGAKwDsBfCruCeaWYOZ1ZtZfU1NTZGbE5G0igq7mX1qZsfN7ASA3wGYnm1bIpK1osJOsrbXt/MAxJ9LKCIVIXGeneRLAK4BMILkbgC/BHANySsAGIB2AEtK1+K3kq4NH1LK68YnzakOGTIkWE9aZ/ymm24K1s8555xgvb/68ssvU40fOHBgRp2cvko8nz0x7Ga2oI+Hl5egFxEpIR0uK+KEwi7ihMIu4oTCLuKEwi7iRL86xTU09TZy5Mjg2Nra2mA9jXHjxgXrq1evDtZvuOGGYP22224L1leuXBlbGzRoUHBsJevPU2/DhsUeQQ5Al5IWkRJS2EWcUNhFnFDYRZxQ2EWcUNhFnFDYRZw4Y+bZS3mp6LTmzJkTrD///PPB+gMPPBCshy6x/fLLLwfHTpkyJVjPU2iJ7kLkeepv0jx76L/tyJEjwbHFHjuhPbuIEwq7iBMKu4gTCruIEwq7iBMKu4gTCruIExU1z378+PFgPTSffP/992fdTtncc889wXrSXPjChQtja1dddVVw7JNPPhmsJ/VWyvPl+/P57EmXkg5JOtdd8+wiEqSwizihsIs4obCLOKGwizihsIs4obCLOFFR8+w7duwI1g8fPhxbK+WSzHmbMWNGsN7S0hJbW7IkvJr2ww8/HKw//fTTwfrtt98erI8dOza2ljQP3tTUFKwnOffcc1ONT6OU8+xjxowp6ucm7tlJjiX5F5LbSW4j+UD0eDXJJpI7o9vw2foikqtC3sYfA/CQmV0G4J8B3EtyIoBHAWwwswkANkTfi0iFSgy7me01s83R/UMAtgMYDWAugFXR01YBuLFEPYpIBk7rAzqSdQCmAvgbgPPNbC/Q8wsBQJ+LrZFcTLKZZHNXV1fKdkWkWAWHneR5AP4E4BdmdrDQcWbWYGb1ZlZfU1NTTI8ikoGCwk7ye+gJeqOZ/Tl6+FOStVG9FkBnaVoUkSwkTr2RJIDlALab2a97ldYBWARgaXS7Nm0zoUtFJ0k6DTRpOiPNVEneQpctXrNmTXDspk2bgvVnn302WF+2bFmwnvZy0CEPPvhgsN5fp9727duXYSffKmSe/WoAPwfQSrIleuwx9IR8Dck7AXQA+GlJOhSRTCSG3cz+CoAx5R9n246IlIoOlxVxQmEXcUJhF3FCYRdxQmEXcaKiTnE9++xwO5dffnlsbd26dcGxjzzySLA+cmSfR/sWtO1JkyYFxybVJ06cmGr80KFDg/WQK6+8MlhPWvI5Sei05K+++io4dsCAAcH64MGDi+qpHEp5imuxtGcXcUJhF3FCYRdxQmEXcUJhF3FCYRdxQmEXcaKi5tnnz59fdL27uzs49rzzzgvWQ8tBA8D7778fW2tsbAyOLdW86TdClxYOHR8AAJMnTw7W0x4DMH78+Nha0lx0z6UU+ifNs4tIbhR2EScUdhEnFHYRJxR2EScUdhEnFHYRJypqnj2NESNGBOt33313mTo51SeffBKst7W1Bevbtm0rup50/EBDQ0OwfvBgwYv/ZC7pPP2keuh6+sOHDy96LJA8j15VVRWsh+zfv7/osSHas4s4obCLOKGwizihsIs4obCLOKGwizihsIs4Ucj67GMB/BHAKAAnADSY2W9IPg7gXwF0RU99zMzeKFWj/dmoUaNS1WfNmpVlO6elo6MjWA+d5w8AH330UWwtaR3ypHrSed+hetLP3rlzZ7CeNP7AgQPB+tSpU2Nr119/fXBssQo5qOYYgIfMbDPJHwDYRLIpqj1nZv9Rks5EJFOFrM++F8De6P4hktsBjC51YyKSrdP6m51kHYCpAP4WPXQfya0kV5Ds8/hCkotJNpNs7urq6uspIlIGBYed5HkA/gTgF2Z2EMBvAfwQwBXo2fP/qq9xZtZgZvVmVl9TU5O+YxEpSkFhJ/k99AS90cz+DABm9qmZHTezEwB+B2B66doUkbQSw86eS3wuB7DdzH7d6/HaXk+bByB86paI5KqQT+OvBvBzAK0kW6LHHgOwgOQVAAxAO4AlJehPcnbhhRemqkvlKOTT+L8C6OsC3ppTF+lHdASdiBMKu4gTCruIEwq7iBMKu4gTCruIEwq7iBMKu4gTCruIEwq7iBMKu4gTCruIEwq7iBMKu4gTNLPybYzsAtD72sIjAHSXrYHTU6m9VWpfgHorVpa9jTOzPq//Vtawn7JxstnM6nNrIKBSe6vUvgD1Vqxy9aa38SJOKOwiTuQd9oactx9Sqb1Val+AeitWWXrL9W92ESmfvPfsIlImCruIE7mEneS1JP+X5IckH82jhzgk20m2kmwh2ZxzLytIdpJs6/VYNckmkjuj2z7X2Mupt8dJfhy9di0kZ+fU21iSfyG5neQ2kg9Ej+f62gX6KsvrVva/2UkOALADwL8A2A1gI4AFZhZe6LtMSLYDqDez3A/AIPkjAJ8D+KOZTYoeewbAZ2a2NPpFOczM/q1CenscwOd5L+MdrVZU23uZcQA3ArgdOb52gb7mowyvWx579ukAPjSzXWb2FYCXAczNoY+KZ2bvAPjspIfnAlgV3V+Fnn8sZRfTW0Uws71mtjm6fwjAN8uM5/raBfoqizzCPhrA33t9vxuVtd67AVhPchPJxXk304fzzWwv0POPB8DInPs5WeIy3uV00jLjFfPaFbP8eVp5hL2vpaQqaf7vajObBuA6APdGb1elMAUt410ufSwzXhGKXf48rTzCvhvA2F7fjwGwJ4c++mRme6LbTgCvovKWov70mxV0o9vOnPv5h0paxruvZcZRAa9dnsuf5xH2jQAmkLyI5PcB/AzAuhz6OAXJquiDE5CsAvATVN5S1OsALIruLwKwNsdevqNSlvGOW2YcOb92uS9/bmZl/wIwGz2fyP8fgH/Po4eYvsYD+J/oa1vevQF4CT1v675GzzuiOwEMB7ABwM7otrqCevsvAK0AtqInWLU59TYDPX8abgXQEn3Nzvu1C/RVltdNh8uKOKEj6EScUNhFnFDYRZxQ2EWcUNhFnFDYRZxQ2EWc+H9MFRFikCzOSwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# note that imshow also works fine with scaled\n", "# images in [0, 1] range.\n", "plt.imshow(one_image.to(torch.device('cpu')).squeeze(), cmap='binary');" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch : 1.10.1\n", "PIL : 9.0.1\n", "sys : 3.8.8 | packaged by conda-forge | (default, Feb 20 2021, 16:22:27) \n", "[GCC 9.3.0]\n", "torchvision: 0.11.2\n", "pandas : 1.2.5\n", "numpy : 1.22.2\n", "matplotlib : 3.3.4\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }