{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.1\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- Using PyTorch Dataset Loading Utilities for Custom Datasets (Images from Quickdraw)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook provides an example for how to load an image dataset, stored as individual PNG files, using PyTorch's data loading utilities. For a more in-depth discussion, please see the official\n",
    "\n",
    "- [Data Loading and Processing Tutorial](http://pytorch.org/tutorials/beginner/data_loading_tutorial.html)\n",
    "- [torch.utils.data](http://pytorch.org/docs/master/data.html) API documentation\n",
    "\n",
    "In this example, we are using the Quickdraw dataset consisting of handdrawn objects, which is available at https://quickdraw.withgoogle.com. \n",
    "\n",
    "To execute the following examples, you need to download the \".npy\" (bitmap files in NumPy). You don't need to download all of the 345 categories but only a subset you are interested in. The groups/subsets can be individually downloaded from https://console.cloud.google.com/storage/browser/quickdraw_dataset/full/numpy_bitmap\n",
    "\n",
    "Unfortunately, the Google cloud storage currently does not support selecting and downloading multiple groups at once. Thus, in order to download all groups most coneniently, we need to use their `gsutil` (https://cloud.google.com/storage/docs/gsutil_install) tool. If you want to install that, you can then use \n",
    "\n",
    "    mkdir quickdraw-npy\n",
    "    gsutil -m cp gs://quickdraw_dataset/full/numpy_bitmap/*.npy quickdraw-npy\n",
    "\n",
    "Note that if you download the whole dataset, this will take up 37 Gb of storage space.\n"
   ]
  },
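  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check after downloading, we can inspect one of the \".npy\" files directly. The sketch below is a minimal example and assumes that a file named `lollipop.npy` exists in the `quickdraw-npy` directory; each bitmap file stores one drawing per row as a flattened 28x28 grayscale image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# minimal sanity check; assumes 'lollipop.npy' was downloaded\n",
    "arr = np.load('quickdraw-npy/lollipop.npy')\n",
    "print(arr.shape)  # expected: (num_drawings, 784)\n",
    "print(arr[0].reshape(28, 28).shape)  # one drawing as a 28x28 bitmap"
   ]
  },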
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After downloading the dataset to a local directory, `quickdraw-npy`, the next step is to select certain groups we are interested in analyzing. Let's say we are interested in the following groups defined in the `label_dict` in the next code cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_dict = {\n",
    "         \"lollipop\": 0,\n",
    "         \"binoculars\": 1,\n",
    "         \"mouse\": 2,\n",
    "         \"basket\": 3,\n",
    "         \"penguin\": 4,\n",
    "         \"washing machine\": 5,\n",
    "         \"canoe\": 6,\n",
    "         \"eyeglasses\": 7,\n",
    "         \"beach\": 8,\n",
    "         \"screwdriver\": 9,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dictionary values shall represent class labels that we could use for a classification task, for example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Conversion to PNG files"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we are going to convert the groups we are interested in (specified in the dictionary above) to individual PNG files using a helper function (note that this might take a while):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load utilities from ../helper.py\n",
    "import sys\n",
    "sys.path.insert(0, '..') \n",
    "from helper import quickdraw_npy_to_imagefile\n",
    "\n",
    "    \n",
    "quickdraw_npy_to_imagefile(inpath='quickdraw-npy',\n",
    "                           outpath='quickdraw-png_set1',\n",
    "                           subset=label_dict.keys())"
   ]
  },
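  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The implementation of `quickdraw_npy_to_imagefile` lives in `../helper.py` and is not shown here. As a rough, hypothetical sketch (the actual helper may differ in details such as file naming), the conversion boils down to reshaping each flattened row into a 28x28 bitmap and saving it as a PNG in a per-category subdirectory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "def npy_to_png_sketch(inpath, outpath, subset):\n",
    "    # hypothetical re-implementation sketch, not the actual helper\n",
    "    for category in subset:\n",
    "        arr = np.load(os.path.join(inpath, '%s.npy' % category))\n",
    "        outdir = os.path.join(outpath, category)\n",
    "        if not os.path.exists(outdir):\n",
    "            os.makedirs(outdir)\n",
    "        for i, row in enumerate(arr):\n",
    "            # each row is one flattened 28x28 grayscale drawing\n",
    "            img = Image.fromarray(row.reshape(28, 28).astype(np.uint8))\n",
    "            img.save(os.path.join(outdir, '%s_%06d.png' % (category, i)))"
   ]
  },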
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing into train/valid/test subsets and creating a label files"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For convenience, let's create a CSV file mapping file names to class labels. First, let's collect the files and labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num paths: 1515745\n",
      "Num labels: 1515745\n"
     ]
    }
   ],
   "source": [
    "paths, labels = [], []\n",
    "\n",
    "main_dir = 'quickdraw-png_set1/'\n",
    "\n",
    "for d in os.listdir(main_dir):\n",
    "    subdir = os.path.join(main_dir, d)\n",
    "    if not os.path.isdir(subdir):\n",
    "        continue\n",
    "    for f in os.listdir(subdir):\n",
    "        path = os.path.join(d, f)\n",
    "        paths.append(path)\n",
    "        labels.append(label_dict[d])\n",
    "        \n",
    "print('Num paths:', len(paths))\n",
    "print('Num labels:', len(labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we shuffle the dataset and assign 70% of the dataset for training, 10% for validation, and 20% for testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mlxtend.preprocessing import shuffle_arrays_unison\n",
    "\n",
    "\n",
    "paths2, labels2 = shuffle_arrays_unison(arrays=[np.array(paths), np.array(labels)], random_seed=3)\n",
    "\n",
    "\n",
    "cut1 = int(len(paths)*0.7)\n",
    "cut2 = int(len(paths)*0.8)\n",
    "\n",
    "paths_train, labels_train = paths2[:cut1], labels2[:cut1]\n",
    "paths_valid, labels_valid = paths2[cut1:cut2], labels2[cut1:cut2]\n",
    "paths_test, labels_test = paths2[cut2:], labels2[cut2:]"
   ]
  },
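  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `mlxtend` is not available, a unison shuffle can also be done with plain NumPy. The sketch below draws a single random permutation and applies it to both arrays so that paths and labels stay aligned (note that the resulting order will differ from `shuffle_arrays_unison` even with the same seed):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NumPy-only alternative to mlxtend's shuffle_arrays_unison:\n",
    "# one shared permutation keeps paths and labels aligned\n",
    "rng = np.random.RandomState(3)\n",
    "perm = rng.permutation(len(paths))\n",
    "paths2 = np.array(paths)[perm]\n",
    "labels2 = np.array(labels)[perm]"
   ]
  },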
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let us create a CSV file that maps the file paths to the class labels (here only shown for the training set for simplicity):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Label</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Path</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>penguin/penguin_182463.png</th>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mouse/mouse_139942.png</th>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>screwdriver/screwdriver_066105.png</th>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>beach/beach_026711.png</th>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>eyeglasses/eyeglasses_035833.png</th>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                    Label\n",
       "Path                                     \n",
       "penguin/penguin_182463.png              4\n",
       "mouse/mouse_139942.png                  2\n",
       "screwdriver/screwdriver_066105.png      9\n",
       "beach/beach_026711.png                  8\n",
       "eyeglasses/eyeglasses_035833.png        7"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.DataFrame(\n",
    "    {'Path': paths_train,\n",
    "     'Label': labels_train,\n",
    "    })\n",
    "\n",
    "df = df.set_index('Path')\n",
    "df.to_csv('quickdraw_png_set1_train.csv')\n",
    "\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's open one of the images to make sure they look ok:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(28, 28)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAENNJREFUeJzt3XuMVHWaxvHnFRtBWoKEFjsK2ypE14DbLA2KrgYcwUtMdIwDwx/C6jhMjCQ7ZmK8hQzRbCBmZ2YlrhpmIDRmBpnooP5h2CFGo2MWoSEo4GXHIDuDtN1NuMhF7u/+0YVptet3ijqn6hT9+34S09311Kl6UzMPp6t/p84xdxeA+JyV9wAA8kH5gUhRfiBSlB+IFOUHIkX5gUhRfiBSlB+IFOUHInV2NZ9s2LBh3tTUVM2nBKKyfft27dq1y0q5b6rym9ktkp6R1E/S79x9Yej+TU1NamtrS/OUAAJaWlpKvm/Zv/abWT9J/yXpVklXSpppZleW+3gAqivNe/6Jkj5z923uflTSS5LuyGYsAJWWpvwXSfp7j593FG77FjObY2ZtZtbW1dWV4ukAZClN+Xv7o8L3Ph/s7ovdvcXdWxoaGlI8HYAspSn/Dkkjevx8saSd6cYBUC1pyr9e0mgzu8TM+kv6saTXsxkLQKWVvdTn7sfNbK6k/1b3Ut9Sd9+a2WR9yAcffBDM33///WC+b9++YH7WWcX/Db/yyvACTNLSEG/V+q5U6/zu/oakNzKaBUAVcXgvECnKD0SK8gORovxApCg/ECnKD0Sqqp/n76uS1ulvuummYH7gwIEsx8lU0vkXko4TCOVJ244fPz6YDxkyJJgjjD0/ECnKD0SK8gORovxApCg/ECnKD0SKpb4SbdmypWh26623Bre9+OKLg/nq1auD+fDhw4O5+/dOoPSNTZs2BbdNOptyUr5+/fpg/sorrxTNQnOXYvTo0cE8tJQ4YcKEsreVpMbGxmA+atSoYF4L2PMDkaL8QKQoPxApyg9EivIDkaL8QKQoPxAp1vkLtm3bFsynTZtWNBs8eHBw2zVr1gTzpOMA0pg0aVKqPK3Qx5U3btwY3DbtMQihfMWKFcFt03rppZeC+YwZMyr6/KVgzw9EivIDkaL8QKQoPxApyg9EivIDkaL8QKRSrfOb2XZJ+yWdkHTc3cMfgs5Re3t7MJ86dWowP3nyZNEsz3X8WldfX180u+GGG4LbJuVfffVVMJ8/f37R7Nlnnw1ue/bZ4Wp8/fXXwXzgwIHBvBZkcZDPFHfflcHjAKgifu0HIpW2/C7pz2a2wczmZDEQgOpI+2v/de6+08wukLTGzD5x93d63qHwj8IcSRo5cmTKpwOQlVR7fnffWfjaKWmVpIm93Gexu7e4e0tDQ0OapwOQobLLb2aDzOy8U99Lmiap+CluAdSUNL/2D5e0ysxOPc4f3D18DmoANaPs8rv7Nkn/lOEsqezZsyeY33zzzcF89+7dwfytt94qmiWdPx69Szpvf2trazB/7LHHgvmuXcVXoB944IHgtvfdd18wHzduXDA/ePBgMK8FLPUBkaL8QKQoPxApyg9EivIDkaL8QKT6zKm7b7vttmAeusS2FP7oqSRNmTKlaFY41qGoIUOGBPO0QrPX1dWleuz+/fsH8+bm5mA+fvz4otmSJUuC265duzaYT548OZgvWrSoaDZ27NjgtocOHQrmSTo6OlJtXw3s+YFIUX4gUpQfiBTlByJF+YFIUX4gUpQfiFSfWef/8ssvg3nSWYTuvffesp87aU34yJEjZT+2JB0/fjyY79+/P9XjhyR9VHrZsmXB/IUXXiiaJf1vsnLlymA+ffr0YJ7GueeeG8wHDRoUzDs7O7McpyLY8wORovxApCg/ECnKD0SK8gORovxApCg/EKk+s86fdPrspMs5L1y4MMtxojFr1qxg/uKLLxbN5s6dG9y2kuv4aSUdo9DV1VWlScrHnh+IFOUHIkX5gUhRfiBSlB+IFOUHIkX5gUglrvOb2VJJt0vqdPcxhduGSlopqUnSdknT3T38we+cJX0+O0+HDx8O5i+//HIw37FjR9Es6dz211xzTTBPsnz58mD+9ttvF82++OKLVM+dp+HDhwfzvvJ5/mWSbvnObY9KetPdR0t6s/AzgDNIYvnd/R1Ju79z8x2SWgvft0q6M+O5AFRYue/5h7t7uyQVvl6Q3UgAqqHif/Azszlm1mZmbWfC8c5ALMotf4eZNUpS4WvRv264+2J3b3H3lqQPQwConnLL/7qk2YXvZ0t6LZtxAFRLYvnNbIWk/5F0uZntMLOfSFooaaqZ/VXS1MLPAM4giev87j6zSPSDjGfps55++ulgvmDBgmC+d+/eLMf5ltbW1mCe9Hn9JJdccknR7PPPP0/12HlKWufv6Oio0iTl4wg/IFKUH4gU5QciRfmBSFF+IFKUH4hUnzl1d55WrVoVzB955JFUj3/PPfcE8+eff75odvfddwe3ffDBB4N50vZJH5UOLfW99957wW1r2bBhw4L5hx9+WKVJyseeH4gU5QciRfmBSFF+IFKUH4gU5QciRfmBSLHOn4E1a9YE8wsvvDCYJ53C+qyzyv83+oknngjm119/fTBfvXp1ML/rrruC+aWXXlo0W7FiRXDbY8eOBfO6urpgXkl79oTPVF9fX1+lScrHnh+IFOUHIkX5gUhRfiBSlB+IFOUHIkX5gUixzp+BgwcPBvPBgwcH8zTr+EmuvfbaYJ50DMKrr74azJPW+ceOHVs0O3r0aHDbrVu3BvPm5uZgXknr1q0L5tOmTavSJOVjzw9EivIDkaL8QKQoPxApyg9EivIDkaL8QKQS1/nNbKmk2yV1uvuYwm3zJf1UUlfhbo+7+xuVGrLWNTU1BfOVK1cG8xMnTgTzfv36ne5I30g6hiB0Xn1J6urqCuZJWlpayt62ra0tmFdynb+9vT2YJ52D4eqrr85ynIooZc+/TNItvdz+G3dvLvwXbfGBM1Vi+d39HUm7qzALgCpK855/rpl9aGZLzez8zCYCUBXllv95SZdJapbULulXxe5oZnPMrM3M2tK+fwSQnbLK7+4d7n7C3U9K+q2kiYH7Lnb3FndvaWhoKHdOABkrq/xm1tjjxx9K2pLNOACqpZSlvhWSJksaZmY7JP1S0mQza5bkkrZL+lkFZwRQAYnld/eZvdy8pAKznLEmTJgQzI8cORLMH3rooWC+aNGi056pVJ988kkwnz17dqrHHzFiRNEs6VwCGzZsCOb3339/WTOVYu3atam2nzix6DvhmsERfkCkKD8QKcoPRIryA5Gi/ECkKD8QKU7dnYHbb789mM+bNy+YP/XUU8E8adlpxowZRbPzzjsvuG3SpabHjx8fzNNIeuykj/RWUtKpuQcMGBDMx4wZk+U4FcGeH4gU5QciRfmBSFF+IFKUH4gU5QciRfmBSLHOXwVPPvlkMB85cmQwX7ZsWTB/+OGHi2buHtz2iiuuCOahYwjSSjqt94IFC4J50kelzznnnNOe6ZS
kdf5x48YF87q6urKfu1rY8wORovxApCg/ECnKD0SK8gORovxApCg/ECnW+WtA0imok/LQ5aR37twZ3DbpEt2VXK9O+jz/0aNHg/nmzZuDeZrLg19++eXBvJKXB68W9vxApCg/ECnKD0SK8gORovxApCg/ECnKD0QqcZ3fzEZIWi7pQkknJS1292fMbKiklZKaJG2XNN3dwyeBR0U0NjaWleXtqquuSrX9Rx99FMzTrPM/99xzZW97pihlz39c0i/c/R8lXSPpQTO7UtKjkt5099GS3iz8DOAMkVh+d293942F7/dL+ljSRZLukNRauFurpDsrNSSA7J3We34za5I0TtL7koa7e7vU/Q+EpAuyHg5A5ZRcfjOrl/SKpJ+7+1ensd0cM2szs7aurq5yZgRQASWV38zq1F3837v7nwo3d5hZYyFvlNTZ27buvtjdW9y9paGhIYuZAWQgsfxmZpKWSPrY3X/dI3pd0uzC97MlvZb9eAAqpZSP9F4n6R5Jm81sU+G2xyUtlPRHM/uJpL9J+lFlRkRfdfbZ6T5RnnRacoQlvvru/hdJViT+QbbjAKgWjvADIkX5gUhRfiBSlB+IFOUHIkX5gUj1mVN3DxgwIJh3dvZ6ACJydPz48VTb9+vXL6NJ4sSeH4gU5QciRfmBSFF+IFKUH4gU5QciRfmBSPWZdf5Ro0YF83fffTeY79lT/lnHzz///LK3jdmxY8dSbd+/f/+MJokTe34gUpQfiBTlByJF+YFIUX4gUpQfiBTlByLVZ9b5x4wZE8z37t0bzIcOHZrlOKcl6VwESXlI2mMQkp77sssuC+aTJk0qmqW9fHhdXV2q7WPHnh+IFOUHIkX5gUhRfiBSlB+IFOUHIkX5gUglrvOb2QhJyyVdKOmkpMXu/oyZzZf0U0ldhbs+7u5vVGrQJDNnzgzmSdeCP3z4cDA/cuRI0ezQoUPBbZMkHYOQdB36Ss62b9++YP7pp58G83nz5hXNTp48WdZMp7DOn04pB/kcl/QLd99oZudJ2mBmawrZb9z9Pyo3HoBKSSy/u7dLai98v9/MPpZ0UaUHA1BZp/We38yaJI2T9H7hprlm9qGZLTWzXo8jNbM5ZtZmZm1dXV293QVADkouv5nVS3pF0s/d/StJz0u6TFKzun8z+FVv27n7YndvcfeWhoaGDEYGkIWSym9mdeou/u/d/U+S5O4d7n7C3U9K+q2kiZUbE0DWEstvZiZpiaSP3f3XPW7v+ZGsH0rakv14ACqllL/2XyfpHkmbzWxT4bbHJc00s2ZJLmm7pJ9VZMISDRw4MJjPmjWrSpOgpwMHDhTN1q1bF9x269atwXzKlCllzYRupfy1/y+SrJcotzV9AOlxhB8QKcoPRIryA5Gi/ECkKD8QKcoPRKrPnLobtam+vr5oduONNwa3TcqRDnt+IFKUH4gU5QciRfmBSFF+IFKUH4gU5QciZUmnhc70ycy6JP1fj5uGSdpVtQFOT63OVqtzScxWrixn+wd3L+l8eVUt//ee3KzN3VtyGyCgVmer1bkkZitXXrPxaz8QKcoPRCrv8i/O+flDanW2Wp1LYrZy5TJbru/5AeQn7z0/gJzkUn4zu8XMPjWzz8zs0TxmKMbMtpvZZjPbZGZtOc+y1Mw6zWxLj9uGmtkaM/tr4Wuvl0nLabb5ZvZF4bXbZGa35TTbCDN7y8w+NrOtZvZvhdtzfe0Cc+XyulX9134z6yfpfyVNlbRD0npJM939o6oOUoSZbZfU4u65rwmb2Q2SDkha7u5jCrc9LWm3uy8s/MN5vrs/UiOzzZd0IO8rNxcuKNPY88rSku6U9K/K8bULzDVdObxueez5J0r6zN23uftRSS9JuiOHOWqeu78jafd3br5DUmvh+1Z1/5+n6orMVhPcvd3dNxa+3y/p1JWlc33tAnPlIo/yXyTp7z1+3qHauuS3S/qzmW0wszl5D9OL4YXLpp+6fPoFOc/zXYlXbq6m71xZumZeu3KueJ21PMrf29V/amnJ4Tp3/2dJt0p6sPDrLUpT0pWbq6WXK0vXhHKveJ21PMq/Q9KIHj9fLGlnDnP0yt13Fr52Slql2rv6cMepi6QWvnbmPM83aunKzb1dWVo18NrV0hWv8yj/ekmjzewSM+sv6ceSXs9hju8xs0GFP8TIzAZJmqbau/rw65JmF76fLem1HGf5llq5cnOxK0sr59eu1q54nctBPoWljP+U1E/SUnf/96oP0Qszu1Tde3up+8zGf8hzNjNbIWmyuj/11SHpl5JelfRHSSMl/U3Sj9y96n94KzLbZHX/6vrNlZtPvceu8mz/IuldSZslnSzc/Li631/n9toF5pqpHF43jvADIsURfkCkKD8QKcoPRIryA5Gi/ECkKD8QKcoPRIryA5H6fzLsuOSnjNb7AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "main_dir = 'quickdraw-png_set1/'\n",
    "\n",
    "img = Image.open(os.path.join(main_dir, df.index[99]))\n",
    "img = np.asarray(img, dtype=np.uint8)\n",
    "print(img.shape)\n",
    "plt.imshow(np.array(img), cmap='binary');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementing a Custom Dataset Class"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we implement a custom `Dataset` for reading the images. The `__getitem__` method will\n",
    "\n",
    "1. read a single image from disk based on an `index` (more on batching later)\n",
    "2. perform a custom image transformation (if a `transform` argument is provided in the `__init__` construtor)\n",
    "3. return a single image and it's corresponding label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class QuickdrawDataset(Dataset):\n",
    "    \"\"\"Custom Dataset for loading Quickdraw images\"\"\"\n",
    "\n",
    "    def __init__(self, txt_path, img_dir, transform=None):\n",
    "    \n",
    "        df = pd.read_csv(txt_path, sep=\",\", index_col=0)\n",
    "        self.img_dir = img_dir\n",
    "        self.txt_path = txt_path\n",
    "        self.img_names = df.index.values\n",
    "        self.y = df['Label'].values\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img = Image.open(os.path.join(self.img_dir,\n",
    "                                      self.img_names[index]))\n",
    "        \n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        \n",
    "        label = self.y[index]\n",
    "        return img, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.y.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have created our custom Dataset class, let us add some custom transformations via the `transforms` utilities from `torchvision`, we\n",
    "\n",
    "1. normalize the images (here: dividing by 255)\n",
    "2. converting the image arrays into PyTorch tensors\n",
    "\n",
    "Then, we initialize a Dataset instance for the training images using the 'quickdraw_png_set1_train.csv' label file (we omit the test set, but the same concepts apply).\n",
    "\n",
    "Finally, we initialize a `DataLoader` that allows us to read from the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note that transforms.ToTensor()\n",
    "# already divides pixels by 255. internally\n",
    "\n",
    "custom_transform = transforms.Compose([#transforms.Lambda(lambda x: x/255.),\n",
    "                                       transforms.ToTensor()])\n",
    "\n",
    "train_dataset = QuickdrawDataset(txt_path='quickdraw_png_set1_train.csv',\n",
    "                                 img_dir='quickdraw-png_set1/',\n",
    "                                 transform=custom_transform)\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset,\n",
    "                          batch_size=128,\n",
    "                          shuffle=True,\n",
    "                          num_workers=4) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it, now we can iterate over an epoch using the train_loader as an iterator and use the features and labels from the training dataset for model training:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Iterating Through the Custom Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | Batch index: 0 | Batch size: 128\n",
      "Epoch: 2 | Batch index: 0 | Batch size: 128\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.manual_seed(0)\n",
    "\n",
    "num_epochs = 2\n",
    "for epoch in range(num_epochs):\n",
    "\n",
    "    for batch_idx, (x, y) in enumerate(train_loader):\n",
    "        \n",
    "        print('Epoch:', epoch+1, end='')\n",
    "        print(' | Batch index:', batch_idx, end='')\n",
    "        print(' | Batch size:', y.size()[0])\n",
    "        \n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        break"
   ]
  },
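  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Above, we break out of the loop after the first batch for demonstration purposes. In a real training run, the loop body would feed each batch through a model instead. The following is a minimal, hypothetical sketch; the simple linear classifier and optimizer below are placeholders and not part of this notebook's actual workflow:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# placeholder model: a single linear layer mapping 784 pixels to 10 classes\n",
    "model = nn.Linear(28*28, 10).to(device)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n",
    "\n",
    "for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "    features = features.view(features.size(0), -1).to(device)  # flatten 1x28x28 -> 784\n",
    "    targets = targets.to(device)\n",
    "    logits = model(features)\n",
    "    loss = F.cross_entropy(logits, targets)\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    break  # remove this line to train on the full epoch"
   ]
  },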
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just to make sure that the batches are being loaded correctly, let's print out the dimensions of the last batch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 1, 28, 28])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, each batch consists of 128 images, just as specified. However, one thing to keep in mind though is that\n",
    "PyTorch uses a different image layout (which is more efficient when working with CUDA); here, the image axes are \"num_images x channels x height x width\" (NCHW) instead of \"num_images height x width x channels\" (NHWC):"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To visually check that the images that coming of the data loader are intact, let's swap the axes to NHWC and convert an image from a Torch Tensor to a NumPy array so that we can visualize the image via `imshow`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([28, 28, 1])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "one_image = x[0].permute(1, 2, 0)\n",
    "one_image.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADuZJREFUeJzt3X+MVfWZx/HPs0JDBEWEwRIYHbfRtYhZXK/EqFlZfzR2YxRNRElENv6AP6rZRqOroxE0mBCzteuv1OBCQNNqSVpxEo2WmE3EUBovaipdcCFmtAMjM0hVMCJBnv1jDs1U53zveH+dO/O8X4mZe89zv3Mer37m3Hu/95yvubsAxPN3RTcAoBiEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUGOaubMpU6Z4R0dHM3cJhNLd3a29e/facB5bU/jN7HJJj0k6RtJ/u/uK1OM7OjpULpdr2SWAhFKpNOzHVv2y38yOkfSUpB9LmilpgZnNrPb3AWiuWt7zz5G0090/cPdDkl6QdFV92gLQaLWEf7qkPw+635Nt+xtmttjMymZW7u/vr2F3AOqplvAP9aHCt84PdveV7l5y91JbW1sNuwNQT7WEv0dS+6D7MyTtrq0dAM1SS/jfknSamZ1qZt+TdL2krvq0BaDRqp7qc/fDZnabpNc0MNW32t3/VLfO0BR79+5N1l999dVk/eOPP07Wzz///KpqaLya5vnd/RVJr9SpFwBNxNd7gaAIPxAU4QeCIvxAUIQfCIrwA0E19Xz+qHp7e5P1F154IVlfv359sn7mmWfm1nbt2pUcW2ke/9ChQ8l6La677rpk/ZFHHknWTz755Hq2Ew5HfiAowg8ERfiBoAg/EBThB4Ii/EBQTPXVweHDh5P1iy++OFnfvn17TftPjZ86dWpy7F133ZWs33DDDcl6e3t7sv7UU0/l1h5++OHk2K6u9OUh7r333mT9/vvvz62ZDevq1qMaR34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIp5/jp4+umnk/VK8/irVq1K1hctWpSsd3Z25tY2bdqUHLt8+fJkvVZ33313bu3GG2+seqwkPfDAA8n6rFmzcmtXX311cmwEHPmBoAg/EBThB4Ii/EBQhB8IivADQRF+IKia5vnNrFvSfklfSzrs7qV6NDXSfPjhh8n6pEmTkvWbbrqppv3v27cvt7Zjx47k2Ern+0+ePDlZ37x5c7I+bty43Nr48eOTY9esWZOsVzrff+PGjbk15vnr8yWff3H39CLvAFoOL/uBoGoNv0v6nZltMbPF9WgIQHPU+rL/AnffbWZTJW0ws+3u/sbgB2R/FBZLLK8EtJKajvzuvjv72SfpRUlzhnjMSncvuXupra2tlt0BqKOqw29m483suKO3Jf1I0tZ6NQagsWp52X+SpBezSyCPkfQrd08v+QqgZVQdfnf/QNI/1rGXEWvatGnJ+qeffpqsV5orP++885L1sWPH5tY++eST5NgzzjgjWZ83b16yfuyxxybrqXUBnnvuueTYStdJ+PLLL5P1iRMnJuvRMdUHBEX4gaAIPxAU4QeCIvxAUIQfCIpLd9fBkiVLkvXHH388WZ87d26yfscddyTrR44cya1VWj584cKFyXqly2dXcuedd+bW3nnnneTY+fPn17TvSqdSR8eRHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCYp6/DipdgnrLli3J+n333Zesr1ixIlmfMGFCbi273kKuSqfVVjql9/TTT0/W29vbc2vr169Pjj3llFOS9f379yfrZ511VrIeHUd+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwjK3L1pOyuVSl4ul5u2v9Gi0nnvF110UW6t0lz4mDHpr3qkrhUgSddcc02yfvzxx+fWXn755eTYvr6+qn+3JO3dm794dKV/75GqVCqpXC6nv9yR4cgPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0FVnOw0s9WSrpDU5+6zsm0nSvq1pA5J3ZLmu/tfGtdmbJWW0f7qq69ya5XOx9+0aVOy3t/fn6y/9tpryfrkyZNza+eee25y7MaNG5P166+/PlkfrXP59TKcI/8aSZd/Y9s9kl5399MkvZ7dBzCCVAy/u78had83Nl8laW12e62k9OEFQMup9j3/Se7eK0nZz6n1awlAMzT8Az8zW2xmZTMrV3r/CKB5qg3/HjObJknZz9wzMNx9pbuX3L3U1tZW5e4A1Fu14e+StCi7vUjSS/VpB0CzVAy/mT0v6feS/sHMeszsZkkrJF1mZjskXZbdBzCCVJwIdfcFOaVL6twLcmzevDlZP3ToUG7t9ttvT45ds2ZNsv7ggw8m608++WSy3tvbm1s7ePBgcuxnn32WrF977bXJOtL4hh8QFOEHgiL8QFCEHwiK8ANBEX4gKM55HAG6u7urHjtz5sxkfeLEicn6o48+mqzfcsstyfq6detyax999FFy7OzZs5P1Sy5htrkWHPmBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjm+UeAL774ouqx48ePr2Mn31bpewTLli1r6P5RPY78QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU8/wjgLsX3QJGIY78QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxBUxXl+M1st6QpJfe4+K9u2TNKtkvqzh3W6+yuNanK0O3DgQLJ+3HHHVf27d+7cmayfffbZVf9ujGzDOfKvkXT5ENt/7u6zs38IPjDCVAy/u78haV8TegHQRLW857/NzP5oZqvNbFLdOgLQFNWG/xeSfiBptqReST/Le6CZLTazspmV+/v78x4GoMmqCr+773H3r939iKRnJM1JPHalu5fcvdTW1lZtnwDqrKrwm9m0QXevlrS1Pu0AaJbhTPU9L2mupClm1iNpqaS5ZjZbkkvqlrSkgT0CaICK4Xf3BUNsXtWAXsLq7OxM1p955pmqf/f777+frDPPHxff8AOCIvxAUIQfCIrwA0ERfiAowg8ExaW7W8DcuXOT9SeeeKLq311pqg9xceQHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaCY528BV155ZbI+ffr0ZH3Xrl25Neb5kYcjPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ExTx/CxgzJv2f4eabb07WH3roodzatm3bquoJox9HfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IquI8v5m1S3pW0vclHZG00t0fM7MTJf1aUoekbknz3f0vjWs1rltvvTVZX758eW5t+/btybEHDx5M1seNG5esY+QazpH/sKQ73f2Hks6T9BMzmynpHkmvu/tpkl7P7gMYISqG39173f3t7PZ+SdskTZd0laS12cPWSprXqCYB1N93es9vZh2Szpb0B0knuXuvNPAHQtLUejcHoHGGHX4zmyDpN5J+6u6ff4dxi82sbGbl/v7+anoE0ADDCr+ZjdVA8H/p7r/NNu8xs2lZfZqkvqHGuvtKdy+5e6mtra0ePQOog4rhNzO
TtErSNnd/dFCpS9Ki7PYiSS/Vvz0AjTKcU3ovkLRQ0ntm9m62rVPSCknrzOxmSR9JurYxLWLGjBnJ+hVXXJFb6+rqSo598803k/VLL700WcfIVTH87v6mJMspX1LfdgA0C9/wA4Ii/EBQhB8IivADQRF+ICjCDwTFpbtHgdQpv5Xm+Tds2JCsM88/enHkB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgmOcfBS688MKqx27evLmOnWAk4cgPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0Exzz8KnHDCCbm1U089NTn288+HvfIaRhmO/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QVMV5fjNrl/SspO9LOiJppbs/ZmbLJN0qqT97aKe7v9KoRlGdc845J1nfunVrkzpBqxnOl3wOS7rT3d82s+MkbTGzoys9/Nzd/7Nx7QFolIrhd/deSb3Z7f1mtk3S9EY3BqCxvtN7fjPrkHS2pD9km24zsz+a2Wozm5QzZrGZlc2s3N/fP9RDABRg2OE3swmSfiPpp+7+uaRfSPqBpNkaeGXws6HGuftKdy+5e6mtra0OLQOoh2GF38zGaiD4v3T330qSu+9x96/d/YikZyTNaVybAOqtYvjNzCStkrTN3R8dtH3aoIddLYmPjYERZDif9l8gaaGk98zs3Wxbp6QFZjZbkkvqlrSkIR2iJkuXLk3We3p6mtQJWs1wPu1/U5INUWJOHxjB+IYfEBThB4Ii/EBQhB8IivADQRF+ICgu3T3KzZo1q6Y6Ri+O/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QlLl783Zm1i/pw0Gbpkja27QGvptW7a1V+5LorVr17O0Udx/W9fKaGv5v7dys7O6lwhpIaNXeWrUvid6qVVRvvOwHgiL8QFBFh39lwftPadXeWrUvid6qVUhvhb7nB1Ccoo/8AApSSPjN7HIze9/MdprZPUX0kMfMus3sPTN718zKBfey2sz6zGzroG0nmtkGM9uR/RxymbSCeltmZruy5+5dM/vXgnprN7P/MbNtZvYnM/v3bHuhz12ir0Ket6a/7DezYyT9n6TLJPVIekvSAnf/36Y2ksPMuiWV3L3wOWEz+2dJByQ96+6zsm2PSNrn7iuyP5yT3P0/WqS3ZZIOFL1yc7agzLTBK0tLmifp31Tgc5foa74KeN6KOPLPkbTT3T9w90OSXpB0VQF9tDx3f0PSvm9svkrS2uz2Wg38z9N0Ob21BHfvdfe3s9v7JR1dWbrQ5y7RVyGKCP90SX8edL9HrbXkt0v6nZltMbPFRTczhJOyZdOPLp8+teB+vqniys3N9I2VpVvmuatmxet6KyL8Q63+00pTDhe4+z9J+rGkn2QvbzE8w1q5uVmGWFm6JVS74nW9FRH+Hkntg+7PkLS7gD6G5O67s599kl5U660+vOfoIqnZz76C+/mrVlq5eaiVpdUCz10rrXhdRPjfknSamZ1qZt+TdL2krgL6+BYzG599ECMzGy/pR2q91Ye7JC3Kbi+S9FKBvfyNVlm5OW9laRX83LXaiteFfMknm8r4L0nHSFrt7g83vYkhmNnfa+BoLw1c2fhXRfZmZs9LmquBs772SFoqab2kdZJOlvSRpGvdvekfvOX0NlcDL13/unLz0ffYTe7tQkkbJb0n6Ui2uVMD768Le+4SfS1QAc8b3/ADguIbfkBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgvp/hck5eusCwwAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# note that imshow also works fine with scaled\n",
    "# images in [0, 1] range.\n",
    "plt.imshow(one_image.to(torch.device('cpu')).squeeze(), cmap='binary');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.15.4\n",
      "pandas      0.23.4\n",
      "torchvision 0.2.1\n",
      "torch       1.0.0\n",
      "PIL.Image   5.3.0\n",
      "matplotlib  3.0.2\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}