{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Custom DataLoader Example for PNG files"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook illustrates how to iterate efficiently over a custom image dataset with PyTorch's `Dataset` and `DataLoader` classes. It assumes that\n",
    "- `mnist_train`, `mnist_valid`, and `mnist_test` are folders containing your own PNG images\n",
    "- `mnist_train.csv`, `mnist_valid.csv`, and `mnist_test.csv` are tables that map each image file name to its class label"
   ]
  },
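  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The folder and table layout described above can be reproduced on a small scale. The following is a minimal sketch, not the script that produced the folders used in this notebook: it writes a few random 28x28 grayscale PNG files into a temporary directory, together with a CSV table that has the two columns `Class Label` and `File Name`. The helper name `make_toy_dataset`, the folder name `toy_images`, and the file name `toy.csv` are hypothetical and chosen only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "def make_toy_dataset(root, num_images=4, seed=0):\n",
    "    # Hypothetical helper: writes random 28x28 grayscale PNGs into\n",
    "    # root/toy_images and a matching label table to root/toy.csv\n",
    "    rng = np.random.RandomState(seed)\n",
    "    img_dir = os.path.join(root, 'toy_images')\n",
    "    os.makedirs(img_dir, exist_ok=True)\n",
    "\n",
    "    rows = []\n",
    "    for i in range(num_images):\n",
    "        pixels = rng.randint(0, 256, size=(28, 28)).astype(np.uint8)\n",
    "        file_name = '%d.png' % i\n",
    "        # a 2D uint8 array is saved as an 8-bit grayscale image\n",
    "        Image.fromarray(pixels).save(os.path.join(img_dir, file_name))\n",
    "        rows.append({'Class Label': int(rng.randint(0, 10)),\n",
    "                     'File Name': file_name})\n",
    "\n",
    "    csv_path = os.path.join(root, 'toy.csv')\n",
    "    df = pd.DataFrame(rows, columns=['Class Label', 'File Name'])\n",
    "    df.to_csv(csv_path, index=False)\n",
    "    return csv_path, img_dir\n",
    "\n",
    "\n",
    "tmp_root = tempfile.mkdtemp()\n",
    "toy_csv, toy_dir = make_toy_dataset(tmp_root)\n",
    "toy_df = pd.read_csv(toy_csv)\n",
    "print(toy_df.shape)\n",
    "print(sorted(os.listdir(toy_dir)))"
   ]
  },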
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.1\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.1\n",
      "pandas 0.24.0\n",
      "numpy 1.15.4\n",
      "matplotlib 3.0.2\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch,pandas,numpy,matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1) Inspecting the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x11ea383c8>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADidJREFUeJzt3X+MVfWZx/HPI1J/UBJRIhBBKajrbgixOjGbYAiTRqKmEWtSUyQrK6SDscRtsn8sEqUjpGo2trv9qzqNhMG0tjXKiqRuqcZUqxvjiFhtWUCbWcoyGTQ2AdRYfjz7xxw2I8z9npl7z487PO9XQu6957nnnic3fOacc8+Pr7m7AMRzVt0NAKgH4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/ENTZVS7MzDidECiZu9to3tfSmt/MbjSz3Wb2vpmtaeWzAFTLmj2338wmSNoj6QZJ+yW9KWmpu/8xMQ9rfqBkVaz5r5P0vrv/yd3/Kunnkpa08HkAKtRK+C+R9Odhr/dn077AzLrMrM/M+lpYFoCCtfKD30ibFqdt1rt7j6Qeic1+oJ20subfL2nWsNczJR1orR0AVWkl/G9KusLMvmJmX5L0LUlbi2kLQNma3ux392NmtlrSryVNkLTR3f9QWGcAStX0ob6mFsY+P1C6Sk7yATB+EX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxBU00N0S5KZ9Us6LOm4pGPu3lFEUyiOWXrA1ilTpiTrU6dOTdZXrFiRrC9btqxhbebMmcl580aQfuihh5L1+++/v2HtjjvuSM47f/78ZH3Pnj3Jem9vb7J+/PjxZL0KLYU/0+nuHxXwOQAqxGY/EFSr4XdJ283sLTPrKqIhANVodbN/gbsfMLOLJf3GzP7b3V8Z/obsjwJ/GIA209Ka390PZI8HJW2RdN0I7+lx9w5+DATaS9PhN7NJZjb55HNJiyW9V1RjAMrVymb/NElbskNJZ0v6mbv/ZyFdASid5R1LLXRhZtUtLJDbbrutYe3mm29OznvXXXcV3U5ljh07lqynjrV3dnYm550zZ05TPZ105ZVXJusffPBBS5+f4u7pkzsyHOoDgiL8QFCEHwiK8ANBEX4gKMIPBMWhvjZw1VVXJev33ntvsn7nnXc2rJ133nlN9YTWcKgPQNsi/EBQhB8IivADQRF+ICjCDwRF+IGgirh7L3LkXVb75JNPJusXXHBBke0UaseOHcn6hx9+2LB2/fXXJ+edNGlSUz1VYXBwMFn/9NNPK+qkeaz5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAojvMXYPLkycn6unXrkvUyj+M//fTTyXre7a9feOGFZP35559P1q+55pqGtUWLFiXnLdPrr7+erD/66KPJ+jvvvJOsDwwMjLmnqrHmB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgco/zm9lGSV+XdNDd52XTLpT0C0mzJfVLut3d/1Jem+3t3HPPTdYvuuiiUpe/efPmhrWVK1cm5z1x4kRLy161alWyvn79+oa1c845p6Vl59m2bVvD2tKlS5Pzjofr8Vs1mjX/Jkk3njJtjaSX3P0KSS9lrwGMI7nhd/dXJH18yuQlknqz572Sbi24LwAla3aff5q7D0hS9nhxcS0BqELp5/abWZekrrKXA2Bsml3zD5rZDEnKHg82eqO797h7h7t3NLksACVoNvxbJS3Pni+X9Fwx7QCoSm74zewpSf8l6W/MbL+ZrZT0iKQbzGyvpBuy1wDGkdx9fndvdED0awX3Mm6l7k0vSTt37kzW58yZ09Lyjxw50rCWdxzfLD2Ue09PT7K+fPnyZH3ChAnJekpe7xs2bEjWH374
4Ya1o0ePNtXTmYQz/ICgCD8QFOEHgiL8QFCEHwiK8ANBcevuCmzfvj1Z7+zsTNanTJmSrC9YsKBhbfbs2cl516xJX5C5YsWKZL1MqUN1UvpyYeRjzQ8ERfiBoAg/EBThB4Ii/EBQhB8IivADQZm7V7cws+oWNo4sW7YsWU/dmrvdpS437u7uTs772GOPJeufffZZMy2d8dw9fZ12hjU/EBThB4Ii/EBQhB8IivADQRF+ICjCDwTF9fzjQN5w0eeff35FnZzu0KFDyXrqfgBbtmwpuh2MAWt+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwgq93p+M9so6euSDrr7vGxat6RvSzo5NvVad/9V7sK4nr8pb7/9drI+f/78ijo5XW9vb7Je533/oyryev5Nkm4cYfq/ufvV2b/c4ANoL7nhd/dXJH1cQS8AKtTKPv9qM/u9mW00s/R4UgDaTrPh/7GkuZKuljQg6QeN3mhmXWbWZ2Z9TS4LQAmaCr+7D7r7cXc/Ieknkq5LvLfH3TvcvaPZJgEUr6nwm9mMYS+/Iem9YtoBUJXcS3rN7ClJiyRNNbP9kr4naZGZXS3JJfVLWlVijwBKkBt+d186wuQnSugF49DevXvrbgFN4gw/ICjCDwRF+IGgCD8QFOEHgiL8QFAM0V2B6dOnJ+v33Xdfsn733Xcn62efXd8d2F977bVkfeHChRV1gpMYohtAEuEHgiL8QFCEHwiK8ANBEX4gKMIPBMUQ3RWYO3dusr569eqKOinepZde2nR93759RbeDMWDNDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBcZy/ABMnTkzW165dW+ryjxw50rD24IMPJuddt25dsj558uRkfdasWcn6Pffc07D2wAMPJOc9evRoso7WsOYHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaBy79tvZrMkbZY0XdIJST3u/iMzu1DSLyTNltQv6XZ3/0vOZ52R9+2//PLLk/Xdu3eXuvzu7u6GtQ0bNiTn7ezsTNZffPHFZloalXnz5iXru3btKm3ZZ7Ii79t/TNI/u/vfSvp7Sd8xs7+TtEbSS+5+haSXstcAxonc8Lv7gLvvyJ4flrRL0iWSlkjqzd7WK+nWspoEULwx7fOb2WxJX5X0hqRp7j4gDf2BkHRx0c0BKM+oz+03sy9LekbSd939kNmoditkZl2SupprD0BZRrXmN7OJGgr+T9392WzyoJnNyOozJB0caV5373H3DnfvKKJhAMXIDb8NreKfkLTL3X84rLRV0vLs+XJJzxXfHoCyjGazf4Gkf5D0rpntzKatlfSIpF+a2UpJ+yR9s5wW299ll11W6/I3bdrUsHbWWem/7zfddFPB3WC8yA2/u/9OUqMd/K8V2w6AqnCGHxAU4QeCIvxAUIQfCIrwA0ERfiAobt1dgM8//7zW5W/durVhLe/219dee23R7XzBq6++2rDW399f6rKRxpofCIrwA0ERfiAowg8ERfiBoAg/EBThB4LKvXV3oQs7Q2/dnTdEd+o4vCQtXry4yHYqNTg4mKzfcsstDWt9fX1FtwMVe+tuAGcgwg8ERfiBoAg/EBThB4Ii/EBQhB8IiuP8FVi4cGGy/vLLL1fUyek++eSTZH39+vXJ+uOPP56sHz58eMw9oTUc5weQRPiBoAg/EBThB4Ii/EBQhB8IivADQeUe5zezWZI2S5ou6YSkHnf/kZl1S/q2pA+zt65191/lfFbI4/xAlUZ7nH804Z8haYa77zCzyZLeknSrpNslHXH3R0fbFOEHyjfa8OeO2OPuA5IGsueHzWyXpEtaaw9A3ca0z29msyV9VdIb2aTVZvZ7M9toZlMazNNlZn1mxj2bgDYy6nP7zezLkn4r6fvu/qyZTZP0kSSXtEFDuwYrcj6DzX6gZIXt80uSmU2UtE3Sr939hyPUZ0va5u7zcj6H8AMlK+zCHjMzSU9I2jU8+NkPgSd9Q9J7Y20SQH1G82v/9ZJelfSu
hg71SdJaSUslXa2hzf5+SauyHwdTn8WaHyhZoZv9RSH8QPm4nh9AEuEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiCo3Bt4FuwjSf8z7PXUbFo7atfe2rUvid6aVWRvl432jZVez3/aws363L2jtgYS2rW3du1Lordm1dUbm/1AUIQfCKru8PfUvPyUdu2tXfuS6K1ZtfRW6z4/gPrUveYHUJNawm9mN5rZbjN738zW1NFDI2bWb2bvmtnOuocYy4ZBO2hm7w2bdqGZ/cbM9maPIw6TVlNv3Wb2v9l3t9PMbq6pt1lm9rKZ7TKzP5jZP2XTa/3uEn3V8r1VvtlvZhMk7ZF0g6T9kt6UtNTd/1hpIw2YWb+kDnev/ZiwmS2UdETS5pOjIZnZv0r62N0fyf5wTnH3f2mT3ro1xpGbS+qt0cjS/6gav7siR7wuQh1r/uskve/uf3L3v0r6uaQlNfTR9tz9FUkfnzJ5iaTe7Hmvhv7zVK5Bb23B3QfcfUf2/LCkkyNL1/rdJfqqRR3hv0TSn4e93q/2GvLbJW03s7fMrKvuZkYw7eTISNnjxTX3c6rckZurdMrI0m3z3TUz4nXR6gj/SKOJtNMhhwXufo2kmyR9J9u8xej8WNJcDQ3jNiDpB3U2k40s/Yyk77r7oTp7GW6Evmr53uoI/35Js4a9ninpQA19jMjdD2SPByVt0dBuSjsZPDlIavZ4sOZ+/p+7D7r7cXc/IeknqvG7y0aWfkbST9392Wxy7d/dSH3V9b3VEf43JV1hZl8xsy9J+pakrTX0cRozm5T9ECMzmyRpsdpv9OGtkpZnz5dLeq7GXr6gXUZubjSytGr+7tptxOtaTvLJDmX8u6QJkja6+/crb2IEZjZHQ2t7aeiKx5/V2ZuZPSVpkYau+hqU9D1J/yHpl5IulbRP0jfdvfIf3hr0tkhjHLm5pN4ajSz9hmr87ooc8bqQfjjDD4iJM/yAoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwT1f7GaMHSyoBEtAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "im = Image.open('mnist_train/1.png')\n",
    "plt.imshow(im)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Array Dimensions (28, 28)\n",
      "\n",
      "[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0   0   0   0   1  18  38 136 227 255\n",
      "  254 132   0  90 136  98   3   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0   0   0  82 156 253 253 253 253 253\n",
      "  253 249 154 219 253 253  35   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0  40 150 244 253 253 253 253 253 253\n",
      "  253 253 253 253 253 253  35   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0  74 237 253 253 253 253 253 203 182 242\n",
      "  253 253 253 253 253 230  25   0   0   0]\n",
      " [  0   0   0   0   0   0   0  13 200 253 253 253 168 164  91  14  64 246\n",
      "  253 253 253 195  79  32   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0  21 219 253 253 159   2   0   0 103 233 253\n",
      "  253 253 177  10   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0 171 253 253 147   0   1 155 250 253 253\n",
      "  251 126   5   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0 101 236 253 206  32 152 253 253 253 253\n",
      "  130   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0  91 253 253 253 253 253 253 241 113\n",
      "    9   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0  91 243 253 253 253 253 239  81   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0   0 207 253 253 253 253 158   0   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0   0 207 253 253 253 253 121   0   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0  24 145 249 253 253 253 253 194   0   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0  59 253 253 253 253 253 253 224  30   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   5 181 253 253 241 114 240 253 253 136   5\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0  36 253 253 253 125   0  65 253 253 253  41\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0  67 253 253 253  29   2 138 253 253 253  41\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0  60 253 253 253 207 202 253 253 253 192   9\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   5 183 253 253 253 253 253 253 230  52   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0  62 253 253 253 253 242 116  13   0   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "    0   0   0   0   0   0   0   0   0   0]\n",
      " [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0\n",
      "    0   0   0   0   0   0   0   0   0   0]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "im_array = np.array(im)\n",
    "print('Array Dimensions', im_array.shape)\n",
    "print()\n",
    "print(im_array)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(256, 2)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Class Label</th>\n",
       "      <th>File Name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5</td>\n",
       "      <td>0.png</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>8</td>\n",
       "      <td>1.png</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>8</td>\n",
       "      <td>2.png</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>3.png</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>9</td>\n",
       "      <td>4.png</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Class Label File Name\n",
       "0            5     0.png\n",
       "1            8     1.png\n",
       "2            8     2.png\n",
       "3            0     3.png\n",
       "4            9     4.png"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train = pd.read_csv('mnist_train.csv')\n",
    "print(df_train.shape)\n",
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(256, 2)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Class Label</th>\n",
       "      <th>File Name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>256.png</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>8</td>\n",
       "      <td>257.png</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7</td>\n",
       "      <td>258.png</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>259.png</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>7</td>\n",
       "      <td>260.png</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Class Label File Name\n",
       "0            0   256.png\n",
       "1            8   257.png\n",
       "2            7   258.png\n",
       "3            4   259.png\n",
       "4            7   260.png"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_valid = pd.read_csv('mnist_valid.csv')\n",
    "print(df_valid.shape)\n",
    "df_valid.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(256, 2)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Class Label</th>\n",
       "      <th>File Name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4</td>\n",
       "      <td>512.png</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>513.png</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>6</td>\n",
       "      <td>514.png</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>8</td>\n",
       "      <td>515.png</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>516.png</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Class Label File Name\n",
       "0            4   512.png\n",
       "1            0   513.png\n",
       "2            6   514.png\n",
       "3            8   515.png\n",
       "4            4   516.png"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_test = pd.read_csv('mnist_test.csv')\n",
    "print(df_test.shape)\n",
    "df_test.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2) Custom Dataset Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "from PIL import Image\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "\n",
    "class MyDataset(Dataset):\n",
    "    # Custom Dataset that loads PNG images listed in a CSV file\n",
    "\n",
    "    def __init__(self, csv_path, img_dir, transform=None):\n",
    "        # The CSV file maps each image file name to its class label\n",
    "        df = pd.read_csv(csv_path)\n",
    "        self.img_dir = img_dir\n",
    "        # .values gives NumPy arrays, so indexing below is positional\n",
    "        self.img_names = df['File Name'].values\n",
    "        self.y = df['Class Label'].values\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # Images are read from disk lazily, one file per requested index\n",
    "        img = Image.open(os.path.join(self.img_dir,\n",
    "                                      self.img_names[index]))\n",
    "\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        label = self.y[index]\n",
    "        return img, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.y.shape[0]"
   ]
  },
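  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A map-style dataset only needs `__getitem__` and `__len__`, so a class like the one above can be tried out by indexing it directly before handing it to a `DataLoader`. The sketch below is self-contained and therefore uses a hypothetical in-memory stand-in (`ToyDataset`, filled with zero-valued tensors) instead of the image folders; the access pattern is the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "\n",
    "class ToyDataset(Dataset):\n",
    "    # Hypothetical stand-in that keeps zero-valued 'images' in memory\n",
    "    # instead of reading PNG files from disk\n",
    "\n",
    "    def __init__(self, num_examples=10):\n",
    "        self.x = torch.zeros(num_examples, 1, 28, 28)\n",
    "        self.y = torch.arange(num_examples) % 10\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.x[index], self.y[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.y.shape[0]\n",
    "\n",
    "\n",
    "toy_dataset = ToyDataset()\n",
    "toy_img, toy_label = toy_dataset[3]\n",
    "print(len(toy_dataset), toy_img.shape, int(toy_label))"
   ]
  },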
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3) Custom DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "# transforms.ToTensor() converts a PIL image to a float tensor\n",
    "# of shape (C, H, W) and, for 8-bit images, scales the pixel\n",
    "# values from [0, 255] to [0, 1], so no manual division is needed\n",
    "\n",
    "custom_transform = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "train_dataset = MyDataset(csv_path='mnist_train.csv',\n",
    "                          img_dir='mnist_train',\n",
    "                          transform=custom_transform)\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset,\n",
    "                          batch_size=32,\n",
    "                          shuffle=True, # reshuffle the data at every epoch\n",
    "                          num_workers=4) # number of subprocesses used for data loading"
   ]
  },
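  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the note about `transforms.ToTensor()` in the cell above concrete, the sketch below reproduces the same arithmetic in plain NumPy for an 8-bit grayscale image: the pixel values are cast to 32-bit floats, divided by 255, and a leading channel axis is added. It illustrates the scaling under that assumption and is not the torchvision implementation; the function name `to_unit_range` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def to_unit_range(img_uint8):\n",
    "    # Hypothetical helper mimicking the scaling for an 8-bit grayscale\n",
    "    # image: (H, W) uint8 in [0, 255] -> (1, H, W) float32 in [0, 1]\n",
    "    scaled = img_uint8.astype(np.float32) / 255.0\n",
    "    return scaled[np.newaxis, :, :]\n",
    "\n",
    "\n",
    "toy_pixels = np.array([[0, 51], [102, 255]], dtype=np.uint8)\n",
    "toy_scaled = to_unit_range(toy_pixels)\n",
    "print(toy_scaled.shape, toy_scaled.dtype)\n",
    "print(toy_scaled.min(), toy_scaled.max())"
   ]
  },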
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<br>\n",
    "<br>\n",
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4) Iterating Through the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | Batch index: 0 | Batch size: 32\n",
      "Epoch: 1 | Batch index: 1 | Batch size: 32\n",
      "Epoch: 1 | Batch index: 2 | Batch size: 32\n",
      "Epoch: 1 | Batch index: 3 | Batch size: 32\n",
      "Epoch: 1 | Batch index: 4 | Batch size: 32\n",
      "Epoch: 1 | Batch index: 5 | Batch size: 32\n",
      "Epoch: 1 | Batch index: 6 | Batch size: 32\n",
      "Epoch: 1 | Batch index: 7 | Batch size: 32\n",
      "Epoch: 2 | Batch index: 0 | Batch size: 32\n",
      "Epoch: 2 | Batch index: 1 | Batch size: 32\n",
      "Epoch: 2 | Batch index: 2 | Batch size: 32\n",
      "Epoch: 2 | Batch index: 3 | Batch size: 32\n",
      "Epoch: 2 | Batch index: 4 | Batch size: 32\n",
      "Epoch: 2 | Batch index: 5 | Batch size: 32\n",
      "Epoch: 2 | Batch index: 6 | Batch size: 32\n",
      "Epoch: 2 | Batch index: 7 | Batch size: 32\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.manual_seed(0)\n",
    "\n",
    "num_epochs = 2\n",
    "for epoch in range(num_epochs):\n",
    "\n",
    "    for batch_idx, (x, y) in enumerate(train_loader):\n",
    "        \n",
    "        print('Epoch:', epoch+1, end='')\n",
    "        print(' | Batch index:', batch_idx, end='')\n",
    "        print(' | Batch size:', y.size()[0])\n",
    "        \n",
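    "        # move the batch to the selected device (GPU if available);\n",
    "        # the model's forward and backward pass would follow here\n",
    "        x = x.to(device)\n",
    "        y = y.to(device)"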
   ]
  },
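  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output above shows eight batches per epoch because the training table has 256 rows and the batch size is 32. As a small sketch of that arithmetic (the helper name `batch_sizes` is hypothetical), the number of batches is the dataset size divided by the batch size, rounded up. The last batch is smaller whenever the division leaves a remainder, which is what a `DataLoader` yields with its default setting `drop_last=False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def batch_sizes(num_examples, batch_size):\n",
    "    # Hypothetical helper: sizes of the batches seen in one epoch\n",
    "    # when an incomplete final batch is kept\n",
    "    num_batches = math.ceil(num_examples / batch_size)\n",
    "    sizes = [batch_size] * num_batches\n",
    "    remainder = num_examples % batch_size\n",
    "    if remainder:\n",
    "        sizes[-1] = remainder\n",
    "    return sizes\n",
    "\n",
    "\n",
    "print(len(batch_sizes(256, 32)))\n",
    "print(batch_sizes(100, 32))"
   ]
  },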
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 1, 28, 28])\n"
     ]
    }
   ],
   "source": [
    "print(x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32, 784])\n"
     ]
    }
   ],
   "source": [
    "x_image_as_vector = x.view(-1, 28*28)\n",
    "print(x_image_as_vector.shape)"
   ]
  },
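  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reshaping step above turns each 1x28x28 image into a 784-dimensional vector, which is the input format a fully connected layer expects. A minimal sketch with a zero-valued stand-in batch (`toy_batch`, a hypothetical name, so that the cell runs without the image folders) shows that `view` and `torch.flatten` with `start_dim=1` produce the same shape:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# stand-in for one batch of 32 grayscale 28x28 images\n",
    "toy_batch = torch.zeros(32, 1, 28, 28)\n",
    "\n",
    "# -1 lets PyTorch infer the batch dimension\n",
    "as_vectors = toy_batch.view(-1, 28*28)\n",
    "# flattens every dimension after the batch dimension\n",
    "also_vectors = torch.flatten(toy_batch, start_dim=1)\n",
    "\n",
    "print(as_vectors.shape, also_vectors.shape)"
   ]
  },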
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          ...,\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          ...,\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          ...,\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        ...,\n",
       "\n",
       "\n",
       "        [[[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          ...,\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          ...,\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          ...,\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "          [0., 0., 0.,  ..., 0., 0., 0.]]]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}