{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.1\n", "IPython 7.2.0\n", "\n", "torch 1.0.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Zoo -- Using PyTorch Dataset Loading Utilities for Custom Datasets (Images from Quickdraw)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook provides an example for how to load an image dataset, stored as individual PNG files, using PyTorch's data loading utilities. For a more in-depth discussion, please see the official\n", "\n", "- [Data Loading and Processing Tutorial](http://pytorch.org/tutorials/beginner/data_loading_tutorial.html)\n", "- [torch.utils.data](http://pytorch.org/docs/master/data.html) API documentation\n", "\n", "In this example, we are using the Quickdraw dataset consisting of handdrawn objects, which is available at https://quickdraw.withgoogle.com. \n", "\n", "To execute the following examples, you need to download the \".npy\" (bitmap files in NumPy). You don't need to download all of the 345 categories but only a subset you are interested in. The groups/subsets can be individually downloaded from https://console.cloud.google.com/storage/browser/quickdraw_dataset/full/numpy_bitmap\n", "\n", "Unfortunately, the Google cloud storage currently does not support selecting and downloading multiple groups at once. Thus, in order to download all groups most coneniently, we need to use their `gsutil` (https://cloud.google.com/storage/docs/gsutil_install) tool. If you want to install that, you can then use \n", "\n", " mkdir quickdraw-npy\n", " gsutil -m cp gs://quickdraw_dataset/full/numpy_bitmap/*.npy quickdraw-npy\n", "\n", "Note that if you download the whole dataset, this will take up 37 Gb of storage space.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import os\n", "\n", "import torch\n", "from torch.utils.data import Dataset\n", "from torch.utils.data import DataLoader\n", "from torchvision import transforms\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "from PIL import Image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After downloading the dataset to a local directory, `quickdraw-npy`, the next step is to select certain groups we are interested in analyzing. 
, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's say we are interested in the following groups defined in the `label_dict` in the next code cell:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "label_dict = {\n", "    \"lollipop\": 0,\n", "    \"binoculars\": 1,\n", "    \"mouse\": 2,\n", "    \"basket\": 3,\n", "    \"penguin\": 4,\n", "    \"washing machine\": 5,\n", "    \"canoe\": 6,\n", "    \"eyeglasses\": 7,\n", "    \"beach\": 8,\n", "    \"screwdriver\": 9,\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dictionary values represent the class labels that we could use for a classification task, for example." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Conversion to PNG files" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we are going to convert the groups we are interested in (specified in the dictionary above) to individual PNG files using a helper function (note that this might take a while):" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# load utilities from ../helper.py\n", "import sys\n", "sys.path.insert(0, '..')\n", "from helper import quickdraw_npy_to_imagefile\n", "\n", "\n", "quickdraw_npy_to_imagefile(inpath='quickdraw-npy',\n", "                           outpath='quickdraw-png_set1',\n", "                           subset=label_dict.keys())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preprocessing into train/valid/test subsets and creating a label file" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For convenience, let's create a CSV file mapping the file names to the class labels. First, let's collect the files and labels." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Num paths: 1515745\n", "Num labels: 1515745\n" ] } ], "source": [ "paths, labels = [], []\n", "\n", "main_dir = 'quickdraw-png_set1/'\n", "\n", "for d in os.listdir(main_dir):\n", "    subdir = os.path.join(main_dir, d)\n", "    if not os.path.isdir(subdir):\n", "        continue\n", "    for f in os.listdir(subdir):\n", "        path = os.path.join(d, f)\n", "        paths.append(path)\n", "        labels.append(label_dict[d])\n", "\n", "print('Num paths:', len(paths))\n", "print('Num labels:', len(labels))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we shuffle the dataset and assign 70% of the images to the training set, 10% to the validation set, and 20% to the test set."
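 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you prefer to avoid the extra `mlxtend` dependency used below, a plain-NumPy sketch of the same unison shuffle (with a fixed seed for reproducibility) could look like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Alternative unison shuffle with plain NumPy (sketch only; not used below);\n", "# a fixed seed keeps the train/valid/test split reproducible\n", "rng = np.random.RandomState(3)\n", "shuffle_idx = rng.permutation(len(paths))\n", "paths_alt = np.array(paths)[shuffle_idx]\n", "labels_alt = np.array(labels)[shuffle_idx]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next cell accomplishes the same shuffle via mlxtend's `shuffle_arrays_unison` and then splits the shuffled arrays at the 70% and 80% marks:"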
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from mlxtend.preprocessing import shuffle_arrays_unison\n", "\n", "\n", "paths2, labels2 = shuffle_arrays_unison(arrays=[np.array(paths), np.array(labels)], random_seed=3)\n", "\n", "\n", "cut1 = int(len(paths)*0.7)\n", "cut2 = int(len(paths)*0.8)\n", "\n", "paths_train, labels_train = paths2[:cut1], labels2[:cut1]\n", "paths_valid, labels_valid = paths2[cut1:cut2], labels2[cut1:cut2]\n", "paths_test, labels_test = paths2[cut2:], labels2[cut2:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let us create a CSV file that maps the file paths to the class labels (here only shown for the training set for simplicity):" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Label</th>\n", " </tr>\n", " <tr>\n", " <th>Path</th>\n", " <th></th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>penguin/penguin_182463.png</th>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>mouse/mouse_139942.png</th>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>screwdriver/screwdriver_066105.png</th>\n", " <td>9</td>\n", " </tr>\n", " <tr>\n", " <th>beach/beach_026711.png</th>\n", " <td>8</td>\n", " </tr>\n", " <tr>\n", " <th>eyeglasses/eyeglasses_035833.png</th>\n", " <td>7</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Label\n", "Path \n", "penguin/penguin_182463.png 4\n", "mouse/mouse_139942.png 2\n", "screwdriver/screwdriver_066105.png 9\n", "beach/beach_026711.png 8\n", "eyeglasses/eyeglasses_035833.png 7" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.DataFrame(\n", " {'Path': paths_train,\n", " 'Label': labels_train,\n", " })\n", "\n", "df = df.set_index('Path')\n", "df.to_csv('quickdraw_png_set1_train.csv')\n", "\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let's open one of the images to make sure they look ok:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(28, 28)\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAENNJREFUeJzt3XuMVHWaxvHnFRtBWoKEFjsK2ypE14DbLA2KrgYcwUtMdIwDwx/C6jhMjCQ7ZmK8hQzRbCBmZ2YlrhpmIDRmBpnooP5h2CFGo2MWoSEo4GXHIDuDtN1NuMhF7u/+0YVptet3ijqn6hT9+34S09311Kl6UzMPp6t/p84xdxeA+JyV9wAA8kH5gUhRfiBSlB+IFOUHIkX5gUhRfiBSlB+IFOUHInV2NZ9s2LBh3tTUVM2nBKKyfft27dq1y0q5b6rym9ktkp6R1E/S79x9Yej+TU1NamtrS/OUAAJaWlpKvm/Zv/abWT9J/yXpVklXSpppZleW+3gAqivNe/6Jkj5z923uflTSS5LuyGYsAJWWpvwXSfp7j593FG77FjObY2ZtZtbW1dWV4ukAZClN+Xv7o8L3Ph/s7ovdvcXdWxoaGlI8HYAspSn/Dkkjevx8saSd6cYBUC1pyr9e0mgzu8TM+kv6saTXsxkLQKWVvdTn7sfNbK6k/1b3Ut9Sd9+a2WR9yAcffBDM33///WC+b9++YH7WWcX/Db/yyvACTNLSEG/V+q5U6/zu/oakNzKaBUAVcXgvECnKD0SK8gORovxApCg/ECnKD0Sqqp/n76uS1ulvuummYH7gwIEsx8lU0vkXko4TCOVJ244fPz6YDxkyJJgjjD0/ECnKD0SK8gORovxApCg/ECnKD0SKpb4SbdmypWh26623Bre9+OKLg/nq1auD+fDhw4O5+/dOoPSNTZs2BbdNOptyUr5+/fpg/sorrxTNQnOXYvTo0cE8tJQ4YcKEsreVpMbGxmA+atSoYF4L2PMDkaL8QKQoPxApyg9EivIDkaL8QKQoPxAp1vkLtm3bFsynTZtWNBs8eHBw2zVr1gTzpOMA0pg0aVKqPK3Qx5U3btwY3DbtMQihfMWKFcFt03rppZeC+YwZMyr6/KVgzw9EivIDkaL8QKQoPxApyg9EivIDkaL8QKRSrfOb2XZJ+yWdkHTc3cMfgs5Re3t7MJ86dWowP3nyZNEsz3X8WldfX180u+GGG4LbJuVfffVVMJ8/f37R7Nlnnw1ue/bZ4Wp8/fXXwXzgwIHBvBZkcZDPFHfflcHjAKgifu0HIpW2/C7pz2a2wczmZDEQgOpI+2v/de6+08wukLTGzD5x93d63qHwj8IcSRo5cmTKpwOQlVR7fnffWfjaKWmVpIm93Gexu7e4e0tDQ0OapwOQobLLb2aDzOy8U99Lmiap+CluAdSUNL/2D5e0ysxOPc4f3D18DmoANaPs8rv7Nkn/lOEsqezZsyeY33zzzcF89+7dwfytt94qmiWdPx69Szpvf2trazB/7LHHgvmuXcVXoB944IHgtvfdd18wHzduXDA/ePBgMK8FLPUBkaL8QKQoPxApyg9EivIDkaL8QKT6zKm7b7vttmAeusS2FP7oqSRNmTKlaFY41qGoIUOGBPO0QrPX1dWleuz+/fsH8+bm5mA+fvz4otmSJUuC265duzaYT548OZgvWrSoaDZ27NjgtocOHQrmSTo6OlJtXw3s+YFIUX4gUpQfiBTlByJF+YFIUX4gUpQfiFSfWef/8ssvg3nSWYTuvffesp87aU34yJEjZT+2JB0/fjyY79+/P9XjhyR9VHrZsmXB/IUXXiiaJf1vsnLlymA+ffr0YJ7GueeeG8wHDRoUzDs7O7McpyLY8wORovxApCg/ECnKD0SK8gORovxApCg/EKk+s86fdPrspMs5L1y4MMtxojFr1qxg/uKLLxbN5s6dG9y2kuv4aSUdo9DV1VWlScrHnh+IFOUHIkX5gUhRfiBSlB+IFOUHIkX5gUglrvOb2VJJt0vqdPcxhduGSlopqUnSdknT3T38we+cJX0+O0+HDx8O5i+//HIw37FjR9Es6dz211xzTTBPsnz58mD+9ttvF82++OKLVM+dp+HDhwfzvvJ5/mWSbvnObY9KetPdR0t6s/AzgDNIYvnd/R1Ju79z8x2SWgvft0q6M+O5AFRYue/5h7t7uyQVvl6Q3UgAqqHif/Azszlm1mZmbWfC8c5ALMotf4eZNUpS4WvRv264+2J3b3H3lqQPQwConnLL/7qk2YXvZ0t6LZtxAFRLYvnNbIWk/5F0uZntMLOfSFooaaqZ/VXS1MLPAM4giev87j6zSPSDjGfps55++ulgvmDBgmC+d+/eLMf5ltbW1mCe9Hn9JJdccknR7PPPP0/12HlKWufv6Oio0iTl4wg/IFKUH4gU5QciRfmBSFF+IFKUH4hUnzl1d55WrVoVzB955JFUj3/PPfcE8+eff75odvfddwe3ffDBB4N50vZJH5UOLfW99957wW1r2bBhw4L5hx9+WKVJyseeH4gU5QciRfmBSFF+IFKUH4gU5QciRfmBSLHOn4E1a9YE8wsvvDCYJ53C+qyzyv83+oknngjm119/fTBfvXp1ML/rrruC+aWXXlo0W7FiRXDbY8eOBfO6urpgXkl79oTPVF9fX1+lScrHnh+IFOUHIkX5gUhRfiBSlB+IFOUHIkX5gUixzp+BgwcPBvPBgwcH8zTr+EmuvfbaYJ50DMKrr74azJPW+ceOHVs0O3r0aHDbrVu3BvPm5uZgXknr1q0L5tOmTavSJOVjzw9EivIDkaL8QKQoPxApyg9EivIDkaL8QKQS1/nNbKmk2yV1uvuYwm3zJf1UUlfhbo+7+xuVGrLWNTU1BfOVK1cG8xMnTgTzfv36ne5I30g6hiB0Xn1J6urqCuZJWlpayt62ra0tmFdynb+9vT2YJ52D4eqrr85ynIooZc+/TNItvdz+G3dvLvwXbfGBM1Vi+d39HUm7qzALgCpK855/rpl9aGZLzez8zCYCUBXllv95SZdJapbULulXxe5oZnPMrM3M2tK+fwSQnbLK7+4d7n7C3U9K+q2kiYH7Lnb3FndvaWhoKHdOABkrq/xm1tjjxx9K2pLNOACqpZSlvhWSJksaZmY7JP1S0mQza5bkkrZL+lkFZwRQAYnld/eZvdy8pAKznLEmTJgQzI8cORLMH3rooWC+aNGi056pVJ988kkwnz17dqrHHzFiRNEs6VwCGzZsCOb3339/WTOVYu3atam2nzix6DvhmsERfkCkKD8QKcoPRIryA5Gi/ECkKD8QKU7dnYHbb789mM+bNy+YP/XUU8E8adlpxowZRbPzzjsvuG3SpabHjx8fzNNIeuykj/RWUtKpuQcMGBDMx4wZk+U4FcGeH4gU5QciRfmBSFF+IFKUH4gU5QciRfmBSLHOXwVPPvlkMB85cmQwX7ZsWTB/+OGHi2buHtz2iiuuCOahYwjSSjqt94IFC4J50kelzznnnNOe6ZSkdf5x48YF87q6urKfu1
rY8wORovxApCg/ECnKD0SK8gORovxApCg/ECnW+WtA0imok/LQ5aR37twZ3DbpEt2VXK9O+jz/0aNHg/nmzZuDeZrLg19++eXBvJKXB68W9vxApCg/ECnKD0SK8gORovxApCg/EKn/B3bkueSnjNb7AAAAAElFTkSuQmCC\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "main_dir = 'quickdraw-png_set1/'\n", "\n", "img = Image.open(os.path.join(main_dir, df.index[99]))\n", "img = np.asarray(img, dtype=np.uint8)\n", "print(img.shape)\n", "plt.imshow(img, cmap='binary');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Implementing a Custom Dataset Class" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we implement a custom `Dataset` class for reading the images. Its `__getitem__` method will\n", "\n", "1. read a single image from disk based on an `index` (more on batching later)\n", "2. perform a custom image transformation (if a `transform` argument is provided in the `__init__` constructor)\n", "3. return a single image and its corresponding label" ] }
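, { "cell_type": "markdown", "metadata": {}, "source": [ "As a mental model for point 1, the `DataLoader` we set up later builds mini-batches simply by calling `__getitem__` once per index and collating the results. The sketch below is a simplified illustration of that behavior, not PyTorch's actual implementation (which adds shuffling, worker processes, and a general collate function); it assumes the dataset returns image tensors:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Simplified illustration of how a DataLoader consumes a Dataset\n", "# (assumes dataset[i] returns an (image_tensor, label) pair)\n", "def naive_batches(dataset, batch_size):\n", "    for start in range(0, len(dataset), batch_size):\n", "        stop = min(start + batch_size, len(dataset))\n", "        items = [dataset[i] for i in range(start, stop)]\n", "        images = torch.stack([img for img, _ in items])\n", "        labels = torch.tensor([label for _, label in items])\n", "        yield images, labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following class implements this `Dataset` interface for our Quickdraw PNG files:" ] }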
, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "class QuickdrawDataset(Dataset):\n", "    \"\"\"Custom Dataset for loading Quickdraw images\"\"\"\n", "\n", "    def __init__(self, txt_path, img_dir, transform=None):\n", "\n", "        df = pd.read_csv(txt_path, sep=\",\", index_col=0)\n", "        self.img_dir = img_dir\n", "        self.txt_path = txt_path\n", "        self.img_names = df.index.values\n", "        self.y = df['Label'].values\n", "        self.transform = transform\n", "\n", "    def __getitem__(self, index):\n", "        img = Image.open(os.path.join(self.img_dir,\n", "                                      self.img_names[index]))\n", "\n", "        if self.transform is not None:\n", "            img = self.transform(img)\n", "\n", "        label = self.y[index]\n", "        return img, label\n", "\n", "    def __len__(self):\n", "        return self.y.shape[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have created our custom Dataset class, let us add some transformations via the `transforms` utilities from `torchvision`. We\n", "\n", "1. normalize the images (here: dividing by 255)\n", "2. convert the image arrays into PyTorch tensors\n", "\n", "Then, we initialize a Dataset instance for the training images using the 'quickdraw_png_set1_train.csv' label file (we omit the test set, but the same concepts apply).\n", "\n", "Finally, we initialize a `DataLoader` that allows us to read from the dataset." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Note that transforms.ToTensor()\n", "# already divides pixels by 255. internally\n", "\n", "custom_transform = transforms.Compose([#transforms.Lambda(lambda x: x/255.),\n", "                                       transforms.ToTensor()])\n", "\n", "train_dataset = QuickdrawDataset(txt_path='quickdraw_png_set1_train.csv',\n", "                                 img_dir='quickdraw-png_set1/',\n", "                                 transform=custom_transform)\n", "\n", "train_loader = DataLoader(dataset=train_dataset,\n", "                          batch_size=128,\n", "                          shuffle=True,\n", "                          num_workers=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's it! Now we can iterate over an epoch using the `train_loader` as an iterator and use the features and labels from the training dataset for model training:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Iterating Through the Custom Dataset" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1 | Batch index: 0 | Batch size: 128\n", "Epoch: 2 | Batch index: 0 | Batch size: 128\n" ] } ], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "torch.manual_seed(0)\n", "\n", "num_epochs = 2\n", "for epoch in range(num_epochs):\n", "\n", "    for batch_idx, (x, y) in enumerate(train_loader):\n", "\n", "        print('Epoch:', epoch+1, end='')\n", "        print(' | Batch index:', batch_idx, end='')\n", "        print(' | Batch size:', y.size()[0])\n", "\n", "        x = x.to(device)\n", "        y = y.to(device)\n", "        break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Just to make sure that the batches are being loaded correctly, let's print out the dimensions of the last batch:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 1, 28, 28])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, each batch consists of 128 
images, just as specified. However, one thing to keep in mind though is that\n", "PyTorch uses a different image layout (which is more efficient when working with CUDA); here, the image axes are \"num_images x channels x height x width\" (NCHW) instead of \"num_images height x width x channels\" (NHWC):" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To visually check that the images that coming of the data loader are intact, let's swap the axes to NHWC and convert an image from a Torch Tensor to a NumPy array so that we can visualize the image via `imshow`:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([28, 28, 1])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "one_image = x[0].permute(1, 2, 0)\n", "one_image.shape" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADuZJREFUeJzt3X+MVfWZx/HPs0JDBEWEwRIYHbfRtYhZXK/EqFlZfzR2YxRNRElENv6AP6rZRqOroxE0mBCzteuv1OBCQNNqSVpxEo2WmE3EUBovaipdcCFmtAMjM0hVMCJBnv1jDs1U53zveH+dO/O8X4mZe89zv3Mer37m3Hu/95yvubsAxPN3RTcAoBiEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUGOaubMpU6Z4R0dHM3cJhNLd3a29e/facB5bU/jN7HJJj0k6RtJ/u/uK1OM7OjpULpdr2SWAhFKpNOzHVv2y38yOkfSUpB9LmilpgZnNrPb3AWiuWt7zz5G0090/cPdDkl6QdFV92gLQaLWEf7qkPw+635Nt+xtmttjMymZW7u/vr2F3AOqplvAP9aHCt84PdveV7l5y91JbW1sNuwNQT7WEv0dS+6D7MyTtrq0dAM1SS/jfknSamZ1qZt+TdL2krvq0BaDRqp7qc/fDZnabpNc0MNW32t3/VLfO0BR79+5N1l999dVk/eOPP07Wzz///KpqaLya5vnd/RVJr9SpFwBNxNd7gaAIPxAU4QeCIvxAUIQfCIrwA0E19Xz+qHp7e5P1F154IVlfv359sn7mmWfm1nbt2pUcW2ke/9ChQ8l6La677rpk/ZFHHknWTz755Hq2Ew5HfiAowg8ERfiBoAg/EBThB4Ii/EBQTPXVweHDh5P1iy++OFnfvn17TftPjZ86dWpy7F133ZWs33DDDcl6e3t7sv7UU0/l1h5++OHk2K6u9OUh7r333mT9/vvvz62ZDevq1qMaR34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIp5/jp4+umnk/VK8/irVq1K1hctWpSsd3Z25tY2bdqUHLt8+fJkvVZ33313bu3GG2+seqwkPfDAA8n6rFmzcmtXX311cmwEHPmBoAg/EBThB4Ii/EBQhB8IivADQRF+IKia5vnNrFvSfklfSzrs7qV6NDXSfPjhh8n6pEmTkvWbbrqppv3v27cvt7Zjx47k2Ern+0+ePDlZ37x5c7I+bty43Nr48eOTY9esWZOsVzrff+PGjbk15vnr8yWff3H39CLvAFoOL/uBoGoNv0v6nZltMbPF9WgIQHPU+rL/AnffbWZTJW0ws+3u/sbgB2R/FBZLLK8EtJKajvzuvjv72SfpRUlzhnjMSncvuXupra2tlt0BqKOqw29m483suKO3Jf1I0tZ6NQagsWp52X+SpBezSyCPkfQrd08v+QqgZVQdfnf/QNI/1rGXEWvatGnJ+qeffpqsV5orP++885L1sWPH5tY++eST5NgzzjgjWZ83b16yfuyxxybrqXUBnnvuueTYStdJ+PLLL5P1iRMnJuvRMdUHBEX4gaAIPxAU4QeCIvxAUIQfCIpLd9fBkiVLkvXHH388WZ87d26yfscddyTrR44cya1VWj584cKFyXqly2dXcuedd+bW3nnnneTY+fPn17TvSqdSR8eRHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCYp6/DipdgnrLli3J+n333Zesr1ixIlmfMGFCbi273kKuSqfVVjql9/TTT0/W29vbc2vr169Pjj3llFOS9f379yfrZ511VrIeHUd+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwjK3L1pOyuVSl4ul5u2v9Gi0nnvF110UW6t0lz4mDHpr3qkrhUgSddcc02yfvzxx+fWXn755eTYvr6+qn+3JO3dm794dKV/75GqVCqpXC6nv9yR4cgPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0FVnOw0s9WSrpDU5+6zsm0nSvq1pA5J3ZLmu/tfGtdmbJWW0f7qq69ya5XOx9+0aVOy3t/fn6y/9tpryfrkyZNza+eee25y7MaNG5P166+/PlkfrXP59TKcI/8aSZd/Y9s9kl5399MkvZ7dBzCCVAy/u78had83Nl8laW12e62k9OEFQMup9j3/Se7eK0nZz6n1awlAMzT8Az8zW2xmZTMrV3r/CKB5qg3/HjObJknZz9wzMNx9pbuX3L3U1tZW5e4A1Fu14e+StCi7vUjSS/VpB0CzVAy/mT0v6feS/sHMeszsZkkrJF1mZjskXZbdBzCCVJwIdfcFOaVL6twLcmzevDlZP3ToUG7t9ttvT45ds2ZNsv7ggw8m608++WSy3tvbm1s7ePBgcuxnn32WrF977bXJOtL4hh8QFOEHgiL8QFCEHwiK8ANBEX4gKM55HAG6u7urHjtz5sxkfeLEicn6o48+mqzfcsstyfq6detyax999FFy7OzZs5P1Sy5htrkWHPmBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjm+UeAL774ouqx48ePr2Mn31bpewTLli1r6P5RP
Y78QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU8/wjgLsX3QJGIY78QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxBUxXl+M1st6QpJfe4+K9u2TNKtkvqzh3W6+yuNanK0O3DgQLJ+3HHHVf27d+7cmayfffbZVf9ujGzDOfKvkXT5ENt/7u6zs38IPjDCVAy/u78haV8TegHQRLW857/NzP5oZqvNbFLdOgLQFNWG/xeSfiBptqReST/Le6CZLTazspmV+/v78x4GoMmqCr+773H3r939iKRnJM1JPHalu5fcvdTW1lZtnwDqrKrwm9m0QXevlrS1Pu0AaJbhTPU9L2mupClm1iNpqaS5ZjZbkkvqlrSkgT0CaICK4Xf3BUNsXtWAXsLq7OxM1p955pmqf/f777+frDPPHxff8AOCIvxAUIQfCIrwA0ERfiAowg8ExaW7W8DcuXOT9SeeeKLq311pqg9xceQHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaCY528BV155ZbI+ffr0ZH3Xrl25Neb5kYcjPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ExTx/CxgzJv2f4eabb07WH3roodzatm3bquoJox9HfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IquI8v5m1S3pW0vclHZG00t0fM7MTJf1aUoekbknz3f0vjWs1rltvvTVZX758eW5t+/btybEHDx5M1seNG5esY+QazpH/sKQ73f2Hks6T9BMzmynpHkmvu/tpkl7P7gMYISqG39173f3t7PZ+SdskTZd0laS12cPWSprXqCYB1N93es9vZh2Szpb0B0knuXuvNPAHQtLUejcHoHGGHX4zmyDpN5J+6u6ff4dxi82sbGbl/v7+anoE0ADDCr+ZjdVA8H/p7r/NNu8xs2lZfZqkvqHGuvtKdy+5e6mtra0ePQOog4rhNzOTtErSNnd/dFCpS9Ki7PYiSS/Vvz0AjTKcU3ovkLRQ0ntm9m62rVPSCknrzOxmSR9JurYxLWLGjBnJ+hVXXJFb6+rqSo598803k/VLL700WcfIVTH87v6mJMspX1LfdgA0C9/wA4Ii/EBQhB8IivADQRF+ICjCDwTFpbtHgdQpv5Xm+Tds2JCsM88/enHkB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgmOcfBS688MKqx27evLmOnWAk4cgPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0Exzz8KnHDCCbm1U089NTn288+HvfIaRhmO/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QVMV5fjNrl/SspO9LOiJppbs/ZmbLJN0qqT97aKe7v9KoRlGdc845J1nfunVrkzpBqxnOl3wOS7rT3d82s+MkbTGzoys9/Nzd/7Nx7QFolIrhd/deSb3Z7f1mtk3S9EY3BqCxvtN7fjPrkHS2pD9km24zsz+a2Wozm5QzZrGZlc2s3N/fP9RDABRg2OE3swmSfiPpp+7+uaRfSPqBpNkaeGXws6HGuftKdy+5e6mtra0OLQOoh2GF38zGaiD4v3T330qSu+9x96/d/YikZyTNaVybAOqtYvjNzCStkrTN3R8dtH3aoIddLYmPjYERZDif9l8gaaGk98zs3Wxbp6QFZjZbkkvqlrSkIR2iJkuXLk3We3p6mtQJWs1wPu1/U5INUWJOHxjB+IYfEBThB4Ii/EBQhB8IivADQRF+ICgu3T3KzZo1q6Y6Ri+O/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QlLl783Zm1i/pw0Gbpkja27QGvptW7a1V+5LorVr17O0Udx/W9fKaGv5v7dys7O6lwhpIaNXeWrUvid6qVVRvvOwHgiL8QFBFh39lwftPadXeWrUvid6qVUhvhb7nB1Ccoo/8AApSSPjN7HIze9/MdprZPUX0kMfMus3sPTN718zKBfey2sz6zGzroG0nmtkGM9uR/RxymbSCeltmZruy5+5dM/vXgnprN7P/MbNtZvYnM/v3bHuhz12ir0Ket6a/7DezYyT9n6TLJPVIekvSAnf/36Y2ksPMuiWV3L3wOWEz+2dJByQ96+6zsm2PSNrn7iuyP5yT3P0/WqS3ZZIOFL1yc7agzLTBK0tLmifp31Tgc5foa74KeN6KOPLPkbTT3T9w90OSXpB0VQF9tDx3f0PSvm9svkrS2uz2Wg38z9N0Ob21BHfvdfe3s9v7JR1dWbrQ5y7RVyGKCP90SX8edL9HrbXkt0v6nZltMbPFRTczhJOyZdOPLp8+teB+vqniys3N9I2VpVvmuatmxet6KyL8Q63+00pTDhe4+z9J+rGkn2QvbzE8w1q5uVmGWFm6JVS74nW9FRH+Hkntg+7PkLS7gD6G5O67s599kl5U660+vOfoIqnZz76C+/mrVlq5eaiVpdUCz10rrXhdRPjfknSamZ1qZt+TdL2krgL6+BYzG599ECMzGy/pR2q91Ye7JC3Kbi+S9FKBvfyNVlm5OW9laRX83LXaiteFfMknm8r4L0nHSFrt7g83vYkhmNnfa+BoLw1c2fhXRfZmZs9LmquBs772SFoqab2kdZJOlvSRpGvdvekfvOX0NlcDL13/unLz0ffYTe7tQkkbJb0n6Ui2uVMD768Le+4SfS1QAc8b3/ADguIbfkBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgvp/hck5eusCwwAAAAAASUVORK5CYII=\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# note that imshow also works fine with scaled\n", "# images in [0, 1] range.\n", "plt.imshow(one_image.to(torch.device('cpu')).squeeze(), cmap='binary');" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "numpy 1.15.4\n", "pandas 0.23.4\n", "torchvision 0.2.1\n", "torch 1.0.0\n", "PIL.Image 5.3.0\n", "matplotlib 3.0.2\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }