{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
"- Author: Sebastian Raschka\n",
"- GitHub Repository: https://github.com/rasbt/deeplearning-models"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The watermark extension is already loaded. To reload it, use:\n",
" %reload_ext watermark\n",
"Author: Sebastian Raschka\n",
"\n",
"Python implementation: CPython\n",
"Python version : 3.8.8\n",
"IPython version : 7.30.1\n",
"\n",
"torch: 1.10.1\n",
"numpy: 1.22.2\n",
"\n"
]
}
],
"source": [
"%load_ext watermark\n",
"%watermark -a 'Sebastian Raschka' -v -p torch,numpy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Zoo -- Using PyTorch Dataset Loading Utilities for Custom Datasets (Images from Quickdraw)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook provides an example of how to load an image dataset, stored as individual PNG files, using PyTorch's data loading utilities. For a more in-depth discussion, please see the official\n",
"\n",
"- [Data Loading and Processing Tutorial](http://pytorch.org/tutorials/beginner/data_loading_tutorial.html)\n",
"- [torch.utils.data](http://pytorch.org/docs/master/data.html) API documentation\n",
"\n",
"In this example, we are using the Quickdraw dataset consisting of hand-drawn objects, which is available at https://quickdraw.withgoogle.com.\n",
"\n",
"To execute the following examples, you need to download the \".npy\" files (bitmaps in NumPy format). You don't need to download all 345 categories; a subset you are interested in is enough. The groups/subsets can be downloaded individually from https://console.cloud.google.com/storage/browser/quickdraw_dataset/full/numpy_bitmap\n",
"\n",
"Unfortunately, Google Cloud Storage currently does not support selecting and downloading multiple groups at once. Thus, the most convenient way to download all groups is the `gsutil` tool (https://cloud.google.com/storage/docs/gsutil_install). After installing it, you can run\n",
"\n",
"    mkdir quickdraw-npy\n",
"    gsutil -m cp gs://quickdraw_dataset/full/numpy_bitmap/*.npy quickdraw-npy\n",
"\n",
"Note that if you download the whole dataset, this will take up 37 GB of storage space.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import os\n",
"\n",
"import torch\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import transforms\n",
"\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After downloading the dataset to a local directory, `quickdraw-npy`, the next step is to select certain groups we are interested in analyzing. Let's say we are interested in the following groups defined in the `label_dict` in the next code cell:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"label_dict = {\n",
"    \"lollipop\": 0,\n",
"    \"binoculars\": 1,\n",
"    \"mouse\": 2,\n",
"    \"basket\": 3,\n",
"    \"penguin\": 4,\n",
"    \"washing machine\": 5,\n",
"    \"canoe\": 6,\n",
"    \"eyeglasses\": 7,\n",
"    \"beach\": 8,\n",
"    \"screwdriver\": 9,\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dictionary values represent class labels that we could use, for example, in a classification task."
]
},
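{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we later need to translate integer class labels (for example, model predictions) back into group names, the mapping can be inverted. The following is a minimal sketch; the names `toy_label_dict` and `inverse_label_dict` are made up for illustration, and a shortened dictionary is used so that the cell runs on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Shortened stand-in for the `label_dict` defined above (illustration only)\n",
"toy_label_dict = {'lollipop': 0, 'binoculars': 1, 'mouse': 2}\n",
"\n",
"# Swap keys and values: integer label -> group name\n",
"inverse_label_dict = {v: k for k, v in toy_label_dict.items()}\n",
"\n",
"print(inverse_label_dict[1])"
]
},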
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Conversion to PNG files"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we are going to convert the groups we are interested in (specified in the dictionary above) to individual PNG files using a helper function (note that this might take a while):"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# load utilities from ../helper.py\n",
"import sys\n",
"sys.path.insert(0, '..')\n",
"from helper import quickdraw_npy_to_imagefile\n",
"\n",
"\n",
"quickdraw_npy_to_imagefile(inpath='quickdraw-npy',\n",
"                           outpath='quickdraw-png_set1',\n",
"                           subset=label_dict.keys())"
]
},
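{
"cell_type": "markdown",
"metadata": {},
"source": [
"The helper `quickdraw_npy_to_imagefile` lives in `../helper.py` and is not shown in this notebook. Purely as an illustration, a conversion of this kind could look like the sketch below. It assumes that each group is stored as `<group>.npy`, that each file holds an array of shape `(num_images, 784)` with `uint8` pixel values, and that the images are written to `<outpath>/<group>/<group>_<index>.png`, as suggested by the paths printed later in this notebook. The function name `npy_to_png_sketch` and its `max_images` parameter are hypothetical; the actual helper may work differently. It could be called, for example, as `npy_to_png_sketch('quickdraw-npy', 'quickdraw-png_sketch', label_dict.keys(), max_images=100)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import numpy as np\n",
"from PIL import Image\n",
"\n",
"\n",
"def npy_to_png_sketch(inpath, outpath, subset, max_images=None):\n",
"    \"\"\"Hypothetical sketch; NOT the actual helper from ../helper.py.\"\"\"\n",
"    for name in subset:\n",
"        # Assumption: one file per group, named '<group>.npy'\n",
"        arr = np.load(os.path.join(inpath, name + '.npy'))\n",
"        if max_images is not None:\n",
"            arr = arr[:max_images]\n",
"\n",
"        group_dir = os.path.join(outpath, name)\n",
"        os.makedirs(group_dir, exist_ok=True)\n",
"\n",
"        for i, row in enumerate(arr):\n",
"            # Assumption: each row is a flattened 28x28 grayscale bitmap\n",
"            img = Image.fromarray(row.reshape(28, 28).astype(np.uint8))\n",
"            img.save(os.path.join(group_dir, '%s_%06d.png' % (name, i)))"
]
},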
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Preprocessing into train/valid/test subsets and creating label files"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For convenience, let's create a CSV file mapping file names to class labels. First, let's collect the files and labels."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Num paths: 1515745\n",
"Num labels: 1515745\n"
]
}
],
"source": [
"paths, labels = [], []\n",
"\n",
"main_dir = 'quickdraw-png_set1/'\n",
"\n",
"for d in os.listdir(main_dir):\n",
"    subdir = os.path.join(main_dir, d)\n",
"    if not os.path.isdir(subdir):\n",
"        continue\n",
"    for f in os.listdir(subdir):\n",
"        path = os.path.join(d, f)\n",
"        paths.append(path)\n",
"        labels.append(label_dict[d])\n",
"\n",
"print('Num paths:', len(paths))\n",
"print('Num labels:', len(labels))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we shuffle the dataset and assign 70% of the dataset for training, 10% for validation, and 20% for testing."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from mlxtend.preprocessing import shuffle_arrays_unison\n",
"\n",
"\n",
"paths2, labels2 = shuffle_arrays_unison(arrays=[np.array(paths), np.array(labels)], random_seed=3)\n",
"\n",
"\n",
"cut1 = int(len(paths)*0.7)\n",
"cut2 = int(len(paths)*0.8)\n",
"\n",
"paths_train, labels_train = paths2[:cut1], labels2[:cut1]\n",
"paths_valid, labels_valid = paths2[cut1:cut2], labels2[cut1:cut2]\n",
"paths_test, labels_test = paths2[cut2:], labels2[cut2:]"
]
},
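{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shuffle-and-split step above relies on `mlxtend`. As a hedged alternative, the same 70/10/20 split can be sketched with NumPy alone by indexing both arrays with one shared random permutation. The toy arrays and the `toy_` names below are made up for illustration, and the resulting order will differ from the one produced by `shuffle_arrays_unison`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Toy stand-ins for the real `paths` and `labels` lists\n",
"toy_paths = np.array(['img_%02d.png' % i for i in range(10)])\n",
"toy_labels = np.arange(10) % 3\n",
"\n",
"# One shared permutation keeps paths and labels aligned\n",
"rng = np.random.default_rng(3)\n",
"perm = rng.permutation(len(toy_paths))\n",
"toy_paths, toy_labels = toy_paths[perm], toy_labels[perm]\n",
"\n",
"# Same cut points as above: 70% train, 10% valid, 20% test\n",
"toy_cut1 = int(len(toy_paths) * 0.7)\n",
"toy_cut2 = int(len(toy_paths) * 0.8)\n",
"\n",
"toy_train, toy_valid, toy_test = np.split(toy_paths, [toy_cut1, toy_cut2])\n",
"print(len(toy_train), len(toy_valid), len(toy_test))"
]
},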
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, let us create one CSV file per subset that maps the file paths to the class labels:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Label</th>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>Path</th>\n",
"      <th></th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>penguin/penguin_112865.png</th>\n",
"      <td>4</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>binoculars/binoculars_040058.png</th>\n",
"      <td>1</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>eyeglasses/eyeglasses_208525.png</th>\n",
"      <td>7</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>basket/basket_050889.png</th>\n",
"      <td>3</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>lollipop/lollipop_037650.png</th>\n",
"      <td>0</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"                                  Label\n",
"Path                                   \n",
"penguin/penguin_112865.png            4\n",
"binoculars/binoculars_040058.png      1\n",
"eyeglasses/eyeglasses_208525.png      7\n",
"basket/basket_050889.png              3\n",
"lollipop/lollipop_037650.png          0"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.DataFrame(\n",
"    {'Path': paths_train,\n",
"     'Label': labels_train,\n",
"    })\n",
"\n",
"df = df.set_index('Path')\n",
"df.to_csv('quickdraw_png_set1_train.csv')\n",
"\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Label</th>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>Path</th>\n",
"      <th></th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>basket/basket_046788.png</th>\n",
"      <td>3</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>basket/basket_074810.png</th>\n",
"      <td>3</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>binoculars/binoculars_077567.png</th>\n",
"      <td>1</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>screwdriver/screwdriver_007042.png</th>\n",
"      <td>9</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>binoculars/binoculars_037327.png</th>\n",
"      <td>1</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"                                    Label\n",
"Path                                     \n",
"basket/basket_046788.png                3\n",
"basket/basket_074810.png                3\n",
"binoculars/binoculars_077567.png        1\n",
"screwdriver/screwdriver_007042.png      9\n",
"binoculars/binoculars_037327.png        1"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.DataFrame(\n",
"    {'Path': paths_valid,\n",
"     'Label': labels_valid,\n",
"    })\n",
"\n",
"df = df.set_index('Path')\n",
"df.to_csv('quickdraw_png_set1_valid.csv')\n",
"\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Label</th>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>Path</th>\n",
"      <th></th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>eyeglasses/eyeglasses_051135.png</th>\n",
"      <td>7</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>mouse/mouse_059179.png</th>\n",
"      <td>2</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>basket/basket_046866.png</th>\n",
"      <td>3</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>penguin/penguin_043578.png</th>\n",
"      <td>4</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>screwdriver/screwdriver_081750.png</th>\n",
"      <td>9</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"                                    Label\n",
"Path                                     \n",
"eyeglasses/eyeglasses_051135.png        7\n",
"mouse/mouse_059179.png                  2\n",
"basket/basket_046866.png                3\n",
"penguin/penguin_043578.png              4\n",
"screwdriver/screwdriver_081750.png      9"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.DataFrame(\n",
"    {'Path': paths_test,\n",
"     'Label': labels_test,\n",
"    })\n",
"\n",
"df = df.set_index('Path')\n",
"df.to_csv('quickdraw_png_set1_test.csv')\n",
"\n",
"df.head()"
]
},
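{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three cells above repeat the same steps for each subset. As a minimal sketch, the same pattern can be written as a single loop over a dictionary of splits. The toy arrays, the scratch directory, and the `toy_*.csv` file names below are made up for illustration so that the cell runs on its own and does not overwrite the files created above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Toy stand-ins for the (paths, labels) pairs created above\n",
"toy_splits = {\n",
"    'train': (np.array(['a/a_0.png', 'b/b_0.png']), np.array([0, 1])),\n",
"    'valid': (np.array(['a/a_1.png']), np.array([0])),\n",
"    'test': (np.array(['b/b_1.png']), np.array([1])),\n",
"}\n",
"\n",
"out_dir = tempfile.mkdtemp()  # scratch directory for the toy CSV files\n",
"\n",
"for name, (split_paths, split_labels) in toy_splits.items():\n",
"    df_split = pd.DataFrame({'Path': split_paths, 'Label': split_labels})\n",
"    csv_path = os.path.join(out_dir, 'toy_%s.csv' % name)\n",
"    df_split.set_index('Path').to_csv(csv_path)\n",
"    # Read the file back to confirm the Path -> Label layout\n",
"    print(name, pd.read_csv(csv_path, index_col=0).shape)"
]
},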
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's open one of the images to make sure they look ok:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(28, 28)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQfklEQVR4nO3df4xV5Z3H8c9XyoCiRFkGllCVohDxB0vhBoluRK1bhcQoJNViaFghC/4imhiyin9UQ2JkFWujUKWCpcJSSgRFo7slKDH1D8JVQJmdgIhjSyEwKLEYjTDDd/+YazrinOcM59xf4/N+JZM7c7/3ueeby3w4d+5zznnM3QXg+++0WjcAoDoIOxAJwg5EgrADkSDsQCR+UM2NDRw40IcNG1bNTQJRaWlp0eHDh62rWq6wm9kNkn4tqZek5939sdDjhw0bpmKxmGeTAAIKhUJiLfPbeDPrJWmxpEmSLpY0zcwuzvp8ACorz9/s4yXtcfe97n5M0h8k3VSetgCUW56wD5X0104/7yvd9y1mNtvMimZWbG1tzbE5AHnkCXtXHwJ859hbd1/q7gV3LzQ2NubYHIA88oR9n6RzO/38Q0n787UDoFLyhH2rpBFm9iMza5D0c0kbytMWgHLLPPXm7m1mdo+k/1XH1Ntyd28qW2dVdvTo0WD9rLPOqlInQGXkmmd399clvV6mXgBUEIfLApEg7EAkCDsQCcIORIKwA5Eg7EAkqno+e15tbW2JtSVLlgTHPv/888H6rl27gvW33norsXbFFVcEx6LnSTvuoqkpfEjJxReHTwDt37//KfeUF3t2IBKEHYgEYQciQdiBSBB2IBKEHYhEj5p6e/XVVxNr9957b3DsVVddFayff/75wfq0adMSa9u2bQuOHTBgQLCO6ps6dWqw/vLLLwfraQuiTpgwIVgPTeX27ds3ODYr9uxAJAg7EAnCDkSCsAORIOxAJAg7EAnCDkSiR82z79ixI7HW0NAQHLtx48ZgPe0U1/HjxyfWZs2aFRy7bt26YN2syxV2kdPWrVsTa+vXrw+OveOOO4L1Sy65JFhPO+5j5syZibVVq1YFx2b9fWHPDkSCsAORIOxAJAg7EAnCDkSCsAORIOxAJHrUPPvu3bsTayNGjAiOTZuHv+yyy4L1p556KrGWNif7zDPPBOtz584N1pHN4sWLE2tnn312cOyiRYuC9TPOOCNYb29vD9bvu+++xNrkyZODY6dPnx6sJ8kVdjNrkXRUUrukNncv5Hk+AJVTjj37Ne5+uAzPA6CC+JsdiETesLukP5nZu2Y2u6sHmNlsMyuaWbG1tTXn5gBklTfsV7r7WEmTJN1tZt+5qqO7L3X3grsXGhsbc24OQFa5wu7u+0u3hyStl5R8ahiAmsocdjPrZ2ZnffO9pJ9K2lmuxgCUV55P4wdLWl86t/YHkv7b3f+nLF0lCC2jmzZvmtecOXMSa5s2bQqOnTdvXrB+/fXXB+sjR44M1mN1+HB4EmjNmjWJtTvvvDM4Nm0ePU3a+eyhefaWlpZc206SOezuvlfSv5SxFwAVxNQbEAnCDkSCsAORIOxAJAg7EIkedYpr6BK6X331VRU7+bZnn302WA8tNS1Jy5YtC9YXLlx4yj3FYOXKlcH6119/nVi76667yt1O3WPPDkSCsAORIOxAJAg7EAnCDkSCsAORIOxAJHrUPHto2eRHHnkkOPaLL74I1s8888xMPUnSgAEDgvUbb7wxWF+9enWwzjx71955551gPbSs8oUXXljuduoee3YgEoQdiARhByJB2IFIEHYgEoQdiARhByLRo+bZx4wZk1hra2sLjt2zZ0/m585r9OjRwfratWsrtu3vs2KxGKxPnDixSp30DOzZgUgQdiAShB2IBGEHIkHYgUgQdiAShB2IRI+aZ+/Vq1fmse5exk5QDUeOHAnW05Y2Di2LHKPUPbuZLTezQ2a2s9N9A8xso5l9WLo9p7JtAsirO2/jfy
fphpPue0DSJncfIWlT6WcAdSw17O7+tqTPTrr7JkkrSt+vkHRzedsCUG5ZP6Ab7O4HJKl0OyjpgWY228yKZlZsbW3NuDkAeVX803h3X+ruBXcvNDY2VnpzABJkDftBMxsiSaXbQ+VrCUAlZA37BkkzSt/PkPRKedoBUCmp8+xmtlrS1ZIGmtk+Sb+U9JikP5rZLEl/kfSzSjaJODU1NeUan3Ydgdikht3dpyWUflLmXgBUEIfLApEg7EAkCDsQCcIORIKwA5HoUae4Ii7Hjh3LNf70008vUyffD+zZgUgQdiAShB2IBGEHIkHYgUgQdiAShB2IBPPsQAUcP34889jevXuXsZN/YM8ORIKwA5Eg7EAkCDsQCcIORIKwA5Eg7EAkopln/+ijj4L1N998M1jfvXt35m1v27Yt81hJmjNnTq7xldSnT59gfcKECYm1a6+9Njj24MGDmXr6xsCBA3ONz6O5uTnz2JEjR5axk39gzw5EgrADkSDsQCQIOxAJwg5EgrADkSDsQCSimWe/7bbbgvW2trZg/bzzzkusnXZa+P/M/fv3B+tp5y9v3LgxWK+lo0ePButPP/105uceNGhQsJ72uu/cuTPzc/fv3z9YT7Nhw4bMY8eNG5dr20lS9+xmttzMDpnZzk73PWxmfzOz7aWvyRXpDkDZdOdt/O8k3dDF/b9y9zGlr9fL2xaAcksNu7u/LemzKvQCoILyfEB3j5m9X3qbf07Sg8xstpkVzazY2tqaY3MA8sga9t9IukDSGEkHJC1KeqC7L3X3grsXGhsbM24OQF6Zwu7uB9293d1PSPqtpPHlbQtAuWUKu5kN6fTjFEnJcxwA6kLqPLuZrZZ0taSBZrZP0i8lXW1mYyS5pBZJZTnh+tFHHw3WFyxYkPm50+YuBw8eHKw3NDQk1j7//PPg2I8//jhYX7NmTbB+yy23BOv1bPPmzYm1SZMmBcceOXIkWD9x4kSwPmXKlMRa6N9TkqZPnx6sT5s2LVhfuHBh5vGhYzrySA27u3fV1bIK9AKggjhcFogEYQciQdiBSBB2IBKEHYiEuXvVNlYoFLxYLCbWZ86cGRz/wgsvJNbGjh2bua9KGz16dLC+fPnyYN3MytlOVbW3tyfWnnzyyeDYdevWBetbtmwJ1qv5u32yfv36Beu7du1KrA0dOjTzdguFgorFYpe/MOzZgUgQdiAShB2IBGEHIkHYgUgQdiAShB2IRF1dSnr48OGZx4bm76WePVfdk/Xq1SuxNm/evODYtHraZc527NgRrIekLbM9f/78YD10eq2Uby49K/bsQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5Eoq7m2UeNGpV57KpVq4L1tEsDo+dJW2Houuuuy/zcaWM//fTTYP3xxx8P1h966KHE2kUXXRQcmxV7diAShB2IBGEHIkHYgUgQdiAShB2IBGEHIlFX8+xTp04N1q+55prE2v333x8cO3LkyGB9/PjxwTrQ2a233hqspy3ZHFrGu2bz7GZ2rpm9ZWbNZtZkZveW7h9gZhvN7MPS7TkV6RBAWXTnbXybpPvdfZSkCZLuNrOLJT0gaZO7j5C0qfQzgDqVGnZ3P+Du75W+PyqpWdJQSTdJWlF62ApJN1eoRwBlcEof0JnZMEk/lrRF0mB3PyB1/IcgaVDCmNlmVjSzYto1wwBUTrfDbmZnSnpJ0n3u/vfujnP3pe5ecPdC2okLACqnW2E3s97qCPoqd/9mac2DZjakVB8i6VBlWgRQDqlTb9ZxDeZlkprdvfMauxskzZD0WOn2lbzNpF3u+bnnnkuspZ2SePnllwfraafAzp07N7HGtF18jh07lmv8aadV/xCX7syzXynpF5I+MLPtpfvmqyPkfzSzWZL+IulnFekQQFmkht3d/ywpaZf7k/K2A6BSOFwWiARhByJB2IFIEHYgEoQdiERdneKaZsSIEYm15ubm4NgnnngiWF+0aFGwvnLlysTapZdeGhw7c+bMYD1tjp
8jD+vPa6+9Fqw3NDQE6+PGjStnO93Cnh2IBGEHIkHYgUgQdiAShB2IBGEHIkHYgUiYu1dtY4VCwYvFYtW2dyq+/PLLYP2ll15KrC1fvjw4dvPmzcF62pzs7bffHqwvWLAgscYcfTZvvPFGsJ52bMTEiROD9XXr1gXrWRUKBRWLxS7PUmXPDkSCsAORIOxAJAg7EAnCDkSCsAORIOxAJJhnr4K9e/cG64sXL85V79u3b2LtggsuCI5Nu375qFGjgvW0a+anHUNQSUeOHEmsNTU1Bce++OKLwfrYsWOD9bVr1wbrw4cPD9azYp4dAGEHYkHYgUgQdiAShB2IBGEHIkHYgUikzrOb2bmSfi/pnyWdkLTU3X9tZg9L+g9JraWHznf310PPFes8e15p8/Sha+IfPnw4OPb48ePB+rZt24L1Tz75JFivV3369AnWH3zwwWB9/vz5wXrv3r1PuadyCM2zd2eRiDZJ97v7e2Z2lqR3zWxjqfYrdw+vvgCgLnRnffYDkg6Uvj9qZs2Shla6MQDldUp/s5vZMEk/lrSldNc9Zva+mS03s3MSxsw2s6KZFVtbW7t6CIAq6HbYzexMSS9Jus/d/y7pN5IukDRGHXv+LhdLc/el7l5w9wLXQwNqp1thN7Pe6gj6KndfJ0nuftDd2939hKTfSgqfEQGgplLDbmYmaZmkZnd/stP9Qzo9bIqkneVvD0C5dOfT+Csl/ULSB2a2vXTffEnTzGyMJJfUImlOBfqD0k+HXLJkSZU6+a5Dhw4F6+3t7Ym10Km55RA6vbZfv34V3XY96s6n8X+W1NW8XXBOHUB94Qg6IBKEHYgEYQciQdiBSBB2IBKEHYhEd+bZgUSDBg2qdQvoJvbsQCQIOxAJwg5EgrADkSDsQCQIOxAJwg5EoqpLNptZq6TO1x4eKCl8rePaqdfe6rUvid6yKmdv57t7l9d/q2rYv7Nxs6K7F2rWQEC99lavfUn0llW1euNtPBAJwg5EotZhX1rj7YfUa2/12pdEb1lVpbea/s0OoHpqvWcHUCWEHYhETcJuZjeY2S4z22NmD9SihyRm1mJmH5jZdjOr6frSpTX0DpnZzk73DTCzjWb2Yem2yzX2atTbw2b2t9Jrt93MJteot3PN7C0zazazJjO7t3R/TV+7QF9Ved2q/je7mfWStFvSv0naJ2mrpGnu/n9VbSSBmbVIKrh7zQ/AMLOrJH0h6ffufmnpvv+S9Jm7P1b6j/Icd//POuntYUlf1HoZ79JqRUM6LzMu6WZJ/64avnaBvm5RFV63WuzZx0va4+573f2YpD9IuqkGfdQ9d39b0mcn3X2TpBWl71eo45el6hJ6qwvufsDd3yt9f1TSN8uM1/S1C/RVFbUI+1BJf+308z7V13rvLulPZvaumc2udTNdGOzuB6SOXx5J9XZdqNRlvKvppGXG6+a1y7L8eV61CHtXS0nV0/zfle4+VtIkSXeX3q6ie7q1jHe1dLHMeF3Iuvx5XrUI+z5J53b6+YeS9tegjy65+/7S7SFJ61V/S1Ef/GYF3dJteGXFKqqnZby7WmZcdfDa1XL581qEfaukEWb2IzNrkPRzSRtq0Md3mFm/0gcnMrN+kn6q+luKeoOkGaXvZ0h6pYa9fEu9LOOdtMy4avza1Xz5c3ev+pekyer4RP4jSQ/VooeEvoZL2lH6aqp1b5JWq+Nt3XF1vCOaJemfJG2S9GHpdkAd9faipA8kva+OYA2pUW//qo4/Dd+XtL30NbnWr12gr6q8bhwuC0SCI+iASBB2IBKEHYgEYQciQdiBSBB2IBKEHYjE/wP8VOMwG0xOiAAAAABJRU5ErkJggg==\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"main_dir = 'quickdraw-png_set1/'\n",
"\n",
"img = Image.open(os.path.join(main_dir, df.index[1]))\n",
"img = np.asarray(img, dtype=np.uint8)\n",
"print(img.shape)\n",
"plt.imshow(np.array(img), cmap='binary');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Implementing a Custom Dataset Class"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we implement a custom `Dataset` for reading the images. The `__getitem__` method will\n",
"\n",
"1. read a single image from disk based on an `index` (more on batching later)\n",
"2. perform a custom image transformation (if a `transform` argument is provided in the `__init__` constructor)\n",
"3. return a single image and its corresponding label"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"class QuickdrawDataset(Dataset):\n",
"    \"\"\"Custom Dataset for loading Quickdraw images\"\"\"\n",
"\n",
"    def __init__(self, txt_path, img_dir, transform=None):\n",
"\n",
"        df = pd.read_csv(txt_path, sep=\",\", index_col=0)\n",
"        self.img_dir = img_dir\n",
"        self.txt_path = txt_path\n",
"        self.img_names = df.index.values\n",
"        self.y = df['Label'].values\n",
"        self.transform = transform\n",
"\n",
"    def __getitem__(self, index):\n",
"        img = Image.open(os.path.join(self.img_dir,\n",
"                                      self.img_names[index]))\n",
"\n",
"        if self.transform is not None:\n",
"            img = self.transform(img)\n",
"\n",
"        label = self.y[index]\n",
"        return img, label\n",
"\n",
"    def __len__(self):\n",
"        return self.y.shape[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have created our custom Dataset class, let us add some custom transformations via the `transforms` utilities from `torchvision`. We\n",
"\n",
"1. normalize the images (here: dividing by 255)\n",
"2. convert the image arrays into PyTorch tensors\n",
"\n",
"Then, we initialize a Dataset instance for the training images using the 'quickdraw_png_set1_train.csv' label file (we omit the validation and test sets, but the same concepts apply).\n",
"\n",
"Finally, we initialize a `DataLoader` that allows us to read from the dataset."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# Note that transforms.ToTensor()\n",
"# already divides pixels by 255. internally\n",
"\n",
"custom_transform = transforms.Compose([#transforms.Lambda(lambda x: x/255.),\n",
"                                       transforms.ToTensor()])\n",
"\n",
"train_dataset = QuickdrawDataset(txt_path='quickdraw_png_set1_train.csv',\n",
"                                 img_dir='quickdraw-png_set1/',\n",
"                                 transform=custom_transform)\n",
"\n",
"train_loader = DataLoader(dataset=train_dataset,\n",
"                          batch_size=128,\n",
"                          shuffle=True,\n",
"                          num_workers=4)"
]
},
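{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small, self-contained illustration of how a `DataLoader` groups samples into batches, the sketch below wraps random tensors in a `TensorDataset`, which serves only as a stand-in for the custom dataset (the `toy_` names are made up). With 10 samples and `batch_size=4`, the last batch holds only the remaining 2 samples unless `drop_last=True` is passed to the `DataLoader`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"# Toy data: 10 fake grayscale images (1 x 28 x 28) with integer labels\n",
"toy_features = torch.rand(10, 1, 28, 28)\n",
"toy_targets = torch.arange(10)\n",
"\n",
"toy_loader = DataLoader(dataset=TensorDataset(toy_features, toy_targets),\n",
"                        batch_size=4,\n",
"                        shuffle=False,\n",
"                        num_workers=0)\n",
"\n",
"# Each iteration yields one batch of features and the matching labels\n",
"for toy_idx, (xb, yb) in enumerate(toy_loader):\n",
"    print(toy_idx, tuple(xb.shape), yb.tolist())"
]
},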
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's it. We can now iterate over an epoch using `train_loader` as an iterator and use the features and labels from the training dataset for model training:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Iterating Through the Custom Dataset"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1 | Batch index: 0 | Batch size: 128\n",
"Epoch: 2 | Batch index: 0 | Batch size: 128\n"
]
}
],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"torch.manual_seed(0)\n",
"\n",
"num_epochs = 2\n",
"for epoch in range(num_epochs):\n",
"\n",
"    for batch_idx, (x, y) in enumerate(train_loader):\n",
"\n",
"        print('Epoch:', epoch+1, end='')\n",
"        print(' | Batch index:', batch_idx, end='')\n",
"        print(' | Batch size:', y.size()[0])\n",
"\n",
"        x = x.to(device)\n",
"        y = y.to(device)\n",
"        break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just to make sure that the batches are being loaded correctly, let's print out the dimensions of the most recently loaded batch:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([128, 1, 28, 28])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, each batch consists of 128 images, just as specified. One thing to keep in mind, however, is that\n",
"PyTorch uses a different image layout (which is more efficient when working with CUDA); here, the image axes are \"num_images x channels x height x width\" (NCHW) instead of \"num_images x height x width x channels\" (NHWC)."
]
},
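{
"cell_type": "markdown",
"metadata": {},
"source": [
"The difference between the two layouts can be illustrated on a dummy batch. This is a minimal sketch; `batch_nchw` and `batch_nhwc` are hypothetical names, and `permute` only rearranges the axes without changing the pixel values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Dummy batch in PyTorch's layout: N x C x H x W\n",
"batch_nchw = torch.zeros(128, 1, 28, 28)\n",
"\n",
"# Move the channel axis to the end: N x H x W x C\n",
"batch_nhwc = batch_nchw.permute(0, 2, 3, 1)\n",
"\n",
"print(batch_nchw.shape)\n",
"print(batch_nhwc.shape)"
]
},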
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To visually check that the images coming out of the data loader are intact, let's move the channel axis of one image to the end (height x width x channels) and move the tensor to the CPU so that we can visualize the image via `imshow`:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([28, 28, 1])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"one_image = x[0].permute(1, 2, 0)\n",
"one_image.shape"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQxklEQVR4nO3df2xVZZoH8O9XnBGpw49CkQpIZcQoAivY4CaSiWTYiRIUiQ4OwRGjWfBXdKLRNe4fozEmqDtjJrhO0hlgmKVqMDMKJsalkjFmEjNSoEuLuOCS2kHQtpFfgqLAs3/0OFboec/lnnPvueX5fpLm3t7nvj2PV749t/c957w0M4jIme+svBsQkfJQ2EWcUNhFnFDYRZxQ2EWcOLucGxsxYoTV1dWVc5MirrS3t6O7u5t91VKFneS1AH4DYACA35vZ0tDz6+rq0NzcnGaTIhJQX18fWyv6bTzJAQD+E8B1ACYCWEByYrE/T0RKK83f7NMBfGhmu8zsKwAvA5ibTVsikrU0YR8N4O+9vt8dPfYdJBeTbCbZ3NXVlWJzIpJGmrD39SHAKcfemlmDmdWbWX1NTU2KzYlIGmnCvhvA2F7fjwGwJ107IlIqacK+EcAEkheR/D6AnwFYl01bIpK1oqfezOwYyfsA/Dd6pt5WmNm2zDoTkUylmmc3szcAvJFRLyJSQjpcVsQJhV3ECYVdxAmFXcQJhV3ECYVdxImyns/en4WuwvvQQw8Fx7744ovB+qBBg4L1oUOHButVVVWxtSVLlgTH3nrrrcG6nDm0ZxdxQmEXcUJhF3FCYRdxQmEXcUJhF3FCU28FamxsjK0999xzwbE333xzsF5dXR2s79u3L1h/++23Y2tr1qwJjtXUmx/as4s4obCLOKGwizihsIs4obCLOKGwizihsIs4oXn2Am3cuDG2ljRP/sorr2TdznfMnj07tpY0Ry9+aM8u4oTCLuKEwi7ihMIu4oTCLuKEwi7ihMIu4oSbefYtW7YE6wcPHgzWW1tbY2uTJk0qqqeshOb5Ozo6ytiJVLJUYSfZDuAQgOMAjplZfRZNiUj2stizzzSz7gx+joiUkP5mF3EibdgNwHqSm0gu7usJJBeTbCbZ3NXVlXJzIlKstGG/2symAbgOwL0kf3TyE8yswczqzay+pqYm5eZEpFipwm5me6LbTgCvApieRVMikr2iw06yiuQPvrkP4CcA2rJqTESylebT+PMBvErym5/zopm9mUlXRdi6dWuwPm3atJJtO3oNYg0fPjxYHzJkSLCetGRzZ2dnbK27OzxRknTd+MGDBwfraXpPGpu07TT1tD972LBhwXrS/7OkfzOlUHTYzWwXgH/KsBcRKSFNvYk4obCLOKGwizihsIs4obCLOHHGnOJ66aWXpho/b968YP21116Lrc2ZMyc4dsKECcH6oUOHgvX9+/cH68ePH4+t7dmzJzj2vffeC9a/+OKLYD3p1OCk+plq4cKFwfrq1avL1Mm3tGcXcUJhF3FCYRdxQmEXcUJhF3FCYRdxQmEXceKMmWc/evRoqvEzZ84M1pcuXRpbGz9+fHDs2Wfn9zIfO3YsWM+zt6TlpNPO4R84cCC2lnRsQ9LPXr58ebC+fv36YD0P2rOLOKGwizihsIs4obCLOKGwizihsIs4obCLOHHGzLMfOXIk1fiBAwcG65dcckmqn5+XPOfRkyRdjjmpnqfDhw8H601NTcF6e3t7bK2urq6IjpJpzy7ihMIu4oTCLuKEwi7ihMIu4oTCLuKEwi7iROVOwp6mDz74INX4/jqPntbnn38erCdds76qqipYDy19PGDAgODYSjZ9+vRU40PX689tnp3kCpKdJNt6PVZNsonkzui2co9+EBEAhb2N/wOAa0967FEAG8xsAoAN0fciUsESw25m7wD47KSH5wJYFd1fBeDGbNsSkawV+wHd+Wa2FwCi25FxTyS5mGQzyeaurq4iNyciaZX803gzazCzej
Orr6mpKfXmRCRGsWH/lGQtAES3ndm1JCKlUGzY1wFYFN1fBGBtNu2ISKkkzrOTfAnANQBGkNwN4JcAlgJYQ/JOAB0AflrKJguxZcuWVOOnTJmSUSfZO3HiRLC+dm3879oXXnghOPatt94qqqdCzZo1K7a2cuXK4NgxY8Zk3U5mLrvssmA96fiDzZs3x9bmz59fVE9JEsNuZgtiSj/OuBcRKSEdLivihMIu4oTCLuKEwi7ihMIu4sQZc4pr0tTbRRddFKznednipOWBb7nllmD9zTffjK1dfPHFwbFPPfVUsF5dXR2sJy1N/Prrr8fWxo0bFxyb9N+9bNmyYH348OHBehpJp+decMEFwXpnZ/mPQ9OeXcQJhV3ECYVdxAmFXcQJhV3ECYVdxAmFXcQJN/PsU6dOLVMnp+ro6AjW58yZE6zv2LEjWF+xYkVsbdGiRbE1ADjrrHS/7++6665g/eOPP46tNTQ0BMc+88wzwfquXbuC9XfffTe2RjI4Nq2k4xM+++zkyzqWnvbsIk4o7CJOKOwiTijsIk4o7CJOKOwiTijsIk70q3n2o0ePxtaSlmwu1eV5AcDMgvWZM2cG60nnsydd7nnGjBnBep5Gjx4dW3viiSeCY5OWLr7jjjuC9ba2ttja5MmTg2PT0jy7iORGYRdxQmEXcUJhF3FCYRdxQmEXcUJhF3GiX82zt7a2xta+/vrr4Nhp06Zl3c4/tLS0BOtJ5103NjYG65U8j15Ko0aNSjX+2LFjGXVy+pLm2ZOucVAKiXt2kitIdpJs6/XY4yQ/JtkSfc0ubZsiklYhb+P/AODaPh5/zsyuiL7eyLYtEclaYtjN7B0A5T+2T0QyleYDuvtIbo3e5sculEZyMclmks1dXV0pNiciaRQb9t8C+CGAKwDsBfCruCeaWYOZ1ZtZfU1NTZGbE5G0igq7mX1qZsfN7ASA3wGYnm1bIpK1osJOsrbXt/MAxJ9LKCIVIXGeneRLAK4BMILkbgC/BHANySsAGIB2AEtK1+K3kq4NH1LK68YnzakOGTIkWE9aZ/ymm24K1s8555xgvb/68ssvU40fOHBgRp2cvko8nz0x7Ga2oI+Hl5egFxEpIR0uK+KEwi7ihMIu4oTCLuKEwi7iRL86xTU09TZy5Mjg2Nra2mA9jXHjxgXrq1evDtZvuOGGYP22224L1leuXBlbGzRoUHBsJevPU2/DhsUeQQ5Al5IWkRJS2EWcUNhFnFDYRZxQ2EWcUNhFnFDYRZw4Y+bZS3mp6LTmzJkTrD///PPB+gMPPBCshy6x/fLLLwfHTpkyJVjPU2iJ7kLkeepv0jx76L/tyJEjwbHFHjuhPbuIEwq7iBMKu4gTCruIEwq7iBMKu4gTCruIExU1z378+PFgPTSffP/992fdTtncc889wXrSXPjChQtja1dddVVw7JNPPhmsJ/VWyvPl+/P57EmXkg5JOtdd8+wiEqSwizihsIs4obCLOKGwizihsIs4obCLOFFR8+w7duwI1g8fPhxbK+WSzHmbMWNGsN7S0hJbW7IkvJr2ww8/HKw//fTTwfrtt98erI8dOza2ljQP3tTUFKwnOffcc1ONT6OU8+xjxowp6ucm7tlJjiX5F5LbSW4j+UD0eDXJJpI7o9vw2foikqtC3sYfA/CQmV0G4J8B3EtyIoBHAWwwswkANkTfi0iFSgy7me01s83R/UMAtgMYDWAugFXR01YBuLFEPYpIBk7rAzqSdQCmAvgbgPPNbC/Q8wsBQJ+LrZFcTLKZZHNXV1fKdkWkWAWHneR5AP4E4BdmdrDQcWbWYGb1ZlZfU1NTTI8ikoGCwk7ye+gJeqOZ/Tl6+FOStVG9FkBnaVoUkSwkTr2RJIDlALab2a97ldYBWARgaXS7Nm0zoUtFJ0k6DTRpOiPNVEneQpctXrNmTXDspk2bgvVnn302WF+2bFmwnvZy0CEPPvhgsN5fp9727duXYSffKmSe/WoAPwfQSrIleuwx9IR8Dc
k7AXQA+GlJOhSRTCSG3cz+CoAx5R9n246IlIoOlxVxQmEXcUJhF3FCYRdxQmEXcaKiTnE9++xwO5dffnlsbd26dcGxjzzySLA+cmSfR/sWtO1JkyYFxybVJ06cmGr80KFDg/WQK6+8MlhPWvI5Sei05K+++io4dsCAAcH64MGDi+qpHEp5imuxtGcXcUJhF3FCYRdxQmEXcUJhF3FCYRdxQmEXcaKi5tnnz59fdL27uzs49rzzzgvWQ8tBA8D7778fW2tsbAyOLdW86TdClxYOHR8AAJMnTw7W0x4DMH78+Nha0lx0z6UU+ifNs4tIbhR2EScUdhEnFHYRJxR2EScUdhEnFHYRJypqnj2NESNGBOt33313mTo51SeffBKst7W1Bevbtm0rup50/EBDQ0OwfvBgwYv/ZC7pPP2keuh6+sOHDy96LJA8j15VVRWsh+zfv7/osSHas4s4obCLOKGwizihsIs4obCLOKGwizihsIs4Ucj67GMB/BHAKAAnADSY2W9IPg7gXwF0RU99zMzeKFWj/dmoUaNS1WfNmpVlO6elo6MjWA+d5w8AH330UWwtaR3ypHrSed+hetLP3rlzZ7CeNP7AgQPB+tSpU2Nr119/fXBssQo5qOYYgIfMbDPJHwDYRLIpqj1nZv9Rks5EJFOFrM++F8De6P4hktsBjC51YyKSrdP6m51kHYCpAP4WPXQfya0kV5Ds8/hCkotJNpNs7urq6uspIlIGBYed5HkA/gTgF2Z2EMBvAfwQwBXo2fP/qq9xZtZgZvVmVl9TU5O+YxEpSkFhJ/k99AS90cz+DABm9qmZHTezEwB+B2B66doUkbQSw86eS3wuB7DdzH7d6/HaXk+bByB86paI5KqQT+OvBvBzAK0kW6LHHgOwgOQVAAxAO4AlJehPcnbhhRemqkvlKOTT+L8C6OsC3ppTF+lHdASdiBMKu4gTCruIEwq7iBMKu4gTCruIEwq7iBMKu4gTCruIEwq7iBMKu4gTCruIEwq7iBMKu4gTNLPybYzsAtD72sIjAHSXrYHTU6m9VWpfgHorVpa9jTOzPq//Vtawn7JxstnM6nNrIKBSe6vUvgD1Vqxy9aa38SJOKOwiTuQd9oactx9Sqb1Val+AeitWWXrL9W92ESmfvPfsIlImCruIE7mEneS1JP+X5IckH82jhzgk20m2kmwh2ZxzLytIdpJs6/VYNckmkjuj2z7X2Mupt8dJfhy9di0kZ+fU21iSfyG5neQ2kg9Ej+f62gX6KsvrVva/2UkOALADwL8A2A1gI4AFZhZe6LtMSLYDqDez3A/AIPkjAJ8D+KOZTYoeewbAZ2a2NPpFOczM/q1CenscwOd5L+MdrVZU23uZcQA3ArgdOb52gb7mowyvWx579ukAPjSzXWb2FYCXAczNoY+KZ2bvAPjspIfnAlgV3V+Fnn8sZRfTW0Uws71mtjm6fwjAN8uM5/raBfoqizzCPhrA33t9vxuVtd67AVhPchPJxXk304fzzWwv0POPB8DInPs5WeIy3uV00jLjFfPaFbP8eVp5hL2vpaQqaf7vajObBuA6APdGb1elMAUt410ufSwzXhGKXf48rTzCvhvA2F7fjwGwJ4c++mRme6LbTgCvovKWov70mxV0o9vOnPv5h0paxruvZcZRAa9dnsuf5xH2jQAmkLyI5PcB/AzAuhz6OAXJquiDE5CsAvATVN5S1OsALIruLwKwNsdevqNSlvGOW2YcOb92uS9/bmZl/wIwGz2fyP8fgH/Po4eYvsYD+J/oa1vevQF4CT1v675GzzuiOwEMB7ABwM7otrqCevsvAK0AtqInWLU59TYDPX8abgXQEn3Nzvu1C/RVltdNh8uKOKEj6EScUNhFnFDYRZxQ2EWcUNhFnFDYRZxQ2EWc+H9MFRFikCzOSwAAAABJRU5ErkJggg==\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# note that imshow also works fine with scaled\n",
"# images in [0, 1] range.\n",
"plt.imshow(one_image.to(torch.device('cpu')).squeeze(), cmap='binary');"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch : 1.10.1\n",
"PIL : 9.0.1\n",
"sys : 3.8.8 | packaged by conda-forge | (default, Feb 20 2021, 16:22:27) \n",
"[GCC 9.3.0]\n",
"torchvision: 0.11.2\n",
"pandas : 1.2.5\n",
"numpy : 1.22.2\n",
"matplotlib : 3.3.4\n",
"\n"
]
}
],
"source": [
"%watermark -iv"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}