{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Custom DataLoader Example for PNG files" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Illustration of how we can efficiently iterate through custom (image) datasets. For this, suppose \n", "- mnist_train, mnist_valid, and mnist_test are image folders you created with your own custom images\n", "- mnist_train.csv, mnist_valid.csv, and mnist_test.csv are tables that store the image names with their associated class labels" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.1\n", "IPython 7.2.0\n", "\n", "torch 1.0.1\n", "pandas 0.24.0\n", "numpy 1.15.4\n", "matplotlib 3.0.2\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch,pandas,numpy,matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 1) Inspecting the Dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "from PIL import Image" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADidJREFUeJzt3X+MVfWZx/HPI1J/UBJRIhBBKajrbgixOjGbYAiTRqKmEWtSUyQrK6SDscRtsn8sEqUjpGo2trv9qzqNhMG0tjXKiqRuqcZUqxvjiFhtWUCbWcoyGTQ2AdRYfjz7xxw2I8z9npl7z487PO9XQu6957nnnic3fOacc8+Pr7m7AMRzVt0NAKgH4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/ENTZVS7MzDidECiZu9to3tfSmt/MbjSz3Wb2vpmtaeWzAFTLmj2338wmSNoj6QZJ+yW9KWmpu/8xMQ9rfqBkVaz5r5P0vrv/yd3/Kunnkpa08HkAKtRK+C+R9Odhr/dn077AzLrMrM/M+lpYFoCCtfKD30ibFqdt1rt7j6Qeic1+oJ20subfL2nWsNczJR1orR0AVWkl/G9KusLMvmJmX5L0LUlbi2kLQNma3ux392NmtlrSryVNkLTR3f9QWGcAStX0ob6mFsY+P1C6Sk7yATB+EX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxBU00N0S5KZ9Us6LOm4pGPu3lFEUyiOWXrA1ilTpiTrU6dOTdZXrFiRrC9btqxhbebMmcl580aQfuihh5L1+++/v2HtjjvuSM47f/78ZH3Pnj3Jem9vb7J+/PjxZL0KLYU/0+nuHxXwOQAqxGY/EFSr4XdJ283sLTPrKqIhANVodbN/gbsfMLOLJf3GzP7b3V8Z/obsjwJ/GIA209Ka390PZI8HJW2RdN0I7+lx9w5+DATaS9PhN7NJZjb55HNJiyW9V1RjAMrVymb/NElbskNJZ0v6mbv/ZyFdASid5R1LLXRhZtUtLJDbbrutYe3mm29OznvXXXcV3U5ljh07lqynjrV3dnYm550zZ05TPZ105ZVXJusffPBBS5+f4u7pkzsyHOoDgiL8QFCEHwiK8ANBEX4gKMIPBMWhvjZw1VVXJev33ntvsn7nnXc2rJ133nlN9YTWcKgPQNsi/EBQhB8IivADQRF+ICjCDwRF+IGgirh7L3LkXVb75JNPJusXXHBBke0UaseOHcn6hx9+2LB2/fXXJ+edNGlSUz1VYXBwMFn/9NNPK+qkeaz5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAojvMXYPLkycn6unXrkvUyj+M//fTTyXre7a9feOGFZP35559P1q+55pqGtUWLFiXnLdPrr7+erD/66KPJ+jvvvJOsDwwMjLmnqrHmB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgco/zm9lGSV+XdNDd52XTLpT0C0mzJfVLut3d/1Jem+3t3HPPTdYvuuiiUpe/efPmhrWVK1cm5z1x4kRLy161alWyvn79+oa1c845p6Vl59m2bVvD2tKlS5Pzjofr8Vs1mjX/Jkk3njJtjaSX3P0KSS9lrwGMI7nhd/dXJH18yuQlknqz572Sbi24LwAla3aff5q7D0hS9nhxcS0BqELp5/abWZekrrKXA2Bsml3zD5rZDEnKHg82eqO797h7h7t3NLksACVoNvxbJS3Pni+X9Fwx7QCoSm74zewpSf8l6W/MbL+ZrZT0iKQbzGyvpBuy1wDGkdx9fndvdED0awX3Mm6l7k0vSTt37kzW58yZ09Lyjxw50rCWdxzfLD2Ue09PT7K+fPnyZH3ChAnJekpe7xs2bEjWH3744Ya1o0ePNtXTmYQz/IC
gCD8QFOEHgiL8QFCEHwiK8ANBcevuCmzfvj1Z7+zsTNanTJmSrC9YsKBhbfbs2cl516xJX5C5YsWKZL1MqUN1UvpyYeRjzQ8ERfiBoAg/EBThB4Ii/EBQhB8IivADQZm7V7cws+oWNo4sW7YsWU/dmrvdpS437u7uTs772GOPJeufffZZMy2d8dw9fZ12hjU/EBThB4Ii/EBQhB8IivADQRF+ICjCDwTF9fzjQN5w0eeff35FnZzu0KFDyXrqfgBbtmwpuh2MAWt+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwgq93p+M9so6euSDrr7vGxat6RvSzo5NvVad/9V7sK4nr8pb7/9drI+f/78ijo5XW9vb7Je533/oyryev5Nkm4cYfq/ufvV2b/c4ANoL7nhd/dXJH1cQS8AKtTKPv9qM/u9mW00s/R4UgDaTrPh/7GkuZKuljQg6QeN3mhmXWbWZ2Z9TS4LQAmaCr+7D7r7cXc/Ieknkq5LvLfH3TvcvaPZJgEUr6nwm9mMYS+/Iem9YtoBUJXcS3rN7ClJiyRNNbP9kr4naZGZXS3JJfVLWlVijwBKkBt+d186wuQnSugF49DevXvrbgFN4gw/ICjCDwRF+IGgCD8QFOEHgiL8QFAM0V2B6dOnJ+v33Xdfsn733Xcn62efXd8d2F977bVkfeHChRV1gpMYohtAEuEHgiL8QFCEHwiK8ANBEX4gKMIPBMUQ3RWYO3dusr569eqKOinepZde2nR93759RbeDMWDNDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBcZy/ABMnTkzW165dW+ryjxw50rD24IMPJuddt25dsj558uRkfdasWcn6Pffc07D2wAMPJOc9evRoso7WsOYHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaBy79tvZrMkbZY0XdIJST3u/iMzu1DSLyTNltQv6XZ3/0vOZ52R9+2//PLLk/Xdu3eXuvzu7u6GtQ0bNiTn7ezsTNZffPHFZloalXnz5iXru3btKm3ZZ7Ii79t/TNI/u/vfSvp7Sd8xs7+TtEbSS+5+haSXstcAxonc8Lv7gLvvyJ4flrRL0iWSlkjqzd7WK+nWspoEULwx7fOb2WxJX5X0hqRp7j4gDf2BkHRx0c0BKM+oz+03sy9LekbSd939kNmoditkZl2SupprD0BZRrXmN7OJGgr+T9392WzyoJnNyOozJB0caV5373H3DnfvKKJhAMXIDb8NreKfkLTL3X84rLRV0vLs+XJJzxXfHoCyjGazf4Gkf5D0rpntzKatlfSIpF+a2UpJ+yR9s5wW299ll11W6/I3bdrUsHbWWem/7zfddFPB3WC8yA2/u/9OUqMd/K8V2w6AqnCGHxAU4QeCIvxAUIQfCIrwA0ERfiAobt1dgM8//7zW5W/durVhLe/219dee23R7XzBq6++2rDW399f6rKRxpofCIrwA0ERfiAowg8ERfiBoAg/EBThB4LKvXV3oQs7Q2/dnTdEd+o4vCQtXry4yHYqNTg4mKzfcsstDWt9fX1FtwMVe+tuAGcgwg8ERfiBoAg/EBThB4Ii/EBQhB8IiuP8FVi4cGGy/vLLL1fUyek++eSTZH39+vXJ+uOPP56sHz58eMw9oTUc5weQRPiBoAg/EBThB4Ii/EBQhB8IivADQeUe5zezWZI2S5ou6YSkHnf/kZl1S/q2pA+zt65191/lfFbI4/xAlUZ7nH804Z8haYa77zCzyZLeknSrpNslHXH3R0fbFOEHyjfa8OeO2OPuA5IGsueHzWyXpEtaaw9A3ca0z29msyV9VdIb2aTVZvZ7M9toZlMazNNlZn1mxj2bgDYy6nP7zezLkn4r6fvu/qyZTZP0kSSXtEFDuwYrcj6DzX6gZIXt80uSmU2UtE3Sr939hyPUZ0va5u7zcj6H8AMlK+zCHjMzSU9I2jU8+NkPgSd9Q9J7Y20SQH1G82v/9ZJelfSuhg71SdJaSUslXa2hzf5
+SauyHwdTn8WaHyhZoZv9RSH8QPm4nh9AEuEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiCo3Bt4FuwjSf8z7PXUbFo7atfe2rUvid6aVWRvl432jZVez3/aws363L2jtgYS2rW3du1Lordm1dUbm/1AUIQfCKru8PfUvPyUdu2tXfuS6K1ZtfRW6z4/gPrUveYHUJNawm9mN5rZbjN738zW1NFDI2bWb2bvmtnOuocYy4ZBO2hm7w2bdqGZ/cbM9maPIw6TVlNv3Wb2v9l3t9PMbq6pt1lm9rKZ7TKzP5jZP2XTa/3uEn3V8r1VvtlvZhMk7ZF0g6T9kt6UtNTd/1hpIw2YWb+kDnev/ZiwmS2UdETS5pOjIZnZv0r62N0fyf5wTnH3f2mT3ro1xpGbS+qt0cjS/6gav7siR7wuQh1r/uskve/uf3L3v0r6uaQlNfTR9tz9FUkfnzJ5iaTe7Hmvhv7zVK5Bb23B3QfcfUf2/LCkkyNL1/rdJfqqRR3hv0TSn4e93q/2GvLbJW03s7fMrKvuZkYw7eTISNnjxTX3c6rckZurdMrI0m3z3TUz4nXR6gj/SKOJtNMhhwXufo2kmyR9J9u8xej8WNJcDQ3jNiDpB3U2k40s/Yyk77r7oTp7GW6Evmr53uoI/35Js4a9ninpQA19jMjdD2SPByVt0dBuSjsZPDlIavZ4sOZ+/p+7D7r7cXc/IeknqvG7y0aWfkbST9392Wxy7d/dSH3V9b3VEf43JV1hZl8xsy9J+pakrTX0cRozm5T9ECMzmyRpsdpv9OGtkpZnz5dLeq7GXr6gXUZubjSytGr+7tptxOtaTvLJDmX8u6QJkja6+/crb2IEZjZHQ2t7aeiKx5/V2ZuZPSVpkYau+hqU9D1J/yHpl5IulbRP0jfdvfIf3hr0tkhjHLm5pN4ajSz9hmr87ooc8bqQfjjDD4iJM/yAoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwT1f7GaMHSyoBEtAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "im = Image.open('mnist_train/1.png')\n", "plt.imshow(im)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Array Dimensions (28, 28)\n", "\n", "[[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 0 0 0 1 18 38 136 227 255\n", " 254 132 0 90 136 98 3 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 0 0 82 156 253 253 253 253 253\n", " 253 249 154 219 253 253 35 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 40 150 244 253 253 253 253 253 253\n", " 253 253 253 253 253 253 35 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 74 237 253 253 253 253 253 203 182 242\n", " 253 253 253 253 253 230 25 0 0 0]\n", " [ 0 0 0 0 0 0 0 13 200 253 253 253 168 164 91 14 64 246\n", " 253 253 253 195 79 32 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 21 219 253 253 159 2 0 0 103 233 253\n", " 253 253 177 10 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 171 253 253 147 0 1 155 250 253 253\n", " 251 126 5 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 101 236 253 206 32 152 253 253 253 253\n", " 130 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 91 253 253 253 253 253 253 241 113\n", " 9 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 91 243 253 253 253 253 239 81 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 0 207 253 253 253 253 158 0 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 0 207 253 253 253 253 121 0 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 24 145 249 253 253 253 253 194 0 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 59 253 253 253 253 253 253 224 30 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 5 181 253 253 241 114 
240 253 253 136 5\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 36 253 253 253 125 0 65 253 253 253 41\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 67 253 253 253 29 2 138 253 253 253 41\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 60 253 253 253 207 202 253 253 253 192 9\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 5 183 253 253 253 253 253 253 230 52 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 62 253 253 253 253 242 116 13 0 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0]\n", " [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0]]\n" ] } ], "source": [ "import numpy as np\n", "\n", "im_array = np.array(im)\n", "print('Array Dimensions', im_array.shape)\n", "print()\n", "print(im_array)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "
\n", "
" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(256, 2)\n" ] }, { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "  <thead>\n",
 "    <tr><th></th><th>Class Label</th><th>File Name</th></tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr><th>0</th><td>5</td><td>0.png</td></tr>\n",
 "    <tr><th>1</th><td>8</td><td>1.png</td></tr>\n",
 "    <tr><th>2</th><td>8</td><td>2.png</td></tr>\n",
 "    <tr><th>3</th><td>0</td><td>3.png</td></tr>\n",
 "    <tr><th>4</th><td>9</td><td>4.png</td></tr>\n",
 "  </tbody>\n",
 "</table>\n",
 "</div>
" ], "text/plain": [ " Class Label File Name\n", "0 5 0.png\n", "1 8 1.png\n", "2 8 2.png\n", "3 0 3.png\n", "4 9 4.png" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train = pd.read_csv('mnist_train.csv')\n", "print(df_train.shape)\n", "df_train.head()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(256, 2)\n" ] }, { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "  <thead>\n",
 "    <tr><th></th><th>Class Label</th><th>File Name</th></tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr><th>0</th><td>0</td><td>256.png</td></tr>\n",
 "    <tr><th>1</th><td>8</td><td>257.png</td></tr>\n",
 "    <tr><th>2</th><td>7</td><td>258.png</td></tr>\n",
 "    <tr><th>3</th><td>4</td><td>259.png</td></tr>\n",
 "    <tr><th>4</th><td>7</td><td>260.png</td></tr>\n",
 "  </tbody>\n",
 "</table>\n",
 "</div>
" ], "text/plain": [ " Class Label File Name\n", "0 0 256.png\n", "1 8 257.png\n", "2 7 258.png\n", "3 4 259.png\n", "4 7 260.png" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_valid = pd.read_csv('mnist_valid.csv')\n", "print(df_valid.shape)\n", "df_valid.head()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(256, 2)\n" ] }, { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "  <thead>\n",
 "    <tr><th></th><th>Class Label</th><th>File Name</th></tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr><th>0</th><td>4</td><td>512.png</td></tr>\n",
 "    <tr><th>1</th><td>0</td><td>513.png</td></tr>\n",
 "    <tr><th>2</th><td>6</td><td>514.png</td></tr>\n",
 "    <tr><th>3</th><td>8</td><td>515.png</td></tr>\n",
 "    <tr><th>4</th><td>4</td><td>516.png</td></tr>\n",
 "  </tbody>\n",
 "</table>\n",
 "</div>
" ], "text/plain": [ " Class Label File Name\n", "0 4 512.png\n", "1 0 513.png\n", "2 6 514.png\n", "3 8 515.png\n", "4 4 516.png" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_test = pd.read_csv('mnist_test.csv')\n", "print(df_test.shape)\n", "df_test.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2) Custom Dataset Class" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from PIL import Image\n", "from torch.utils.data import Dataset\n", "import os\n", "\n", "\n", "\n", "class MyDataset(Dataset):\n", "\n", " def __init__(self, csv_path, img_dir, transform=None):\n", " \n", " df = pd.read_csv(csv_path)\n", " self.img_dir = img_dir\n", " self.img_names = df['File Name']\n", " self.y = df['Class Label']\n", " self.transform = transform\n", "\n", " def __getitem__(self, index):\n", " img = Image.open(os.path.join(self.img_dir,\n", " self.img_names[index]))\n", " \n", " if self.transform is not None:\n", " img = self.transform(img)\n", " \n", " label = self.y[index]\n", " return img, label\n", "\n", " def __len__(self):\n", " return self.y.shape[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3) Custom Dataloader" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from torchvision import transforms\n", "from torch.utils.data import DataLoader\n", "\n", "\n", "# Note that transforms.ToTensor()\n", "# already divides pixels by 255. internally\n", "\n", "custom_transform = transforms.Compose([#transforms.Lambda(lambda x: x/255.), # not necessary\n", " transforms.ToTensor()\n", " ])\n", "\n", "train_dataset = MyDataset(csv_path='mnist_train.csv',\n", " img_dir='mnist_train',\n", " transform=custom_transform)\n", "\n", "train_loader = DataLoader(dataset=train_dataset,\n", " batch_size=32,\n", " shuffle=True, # want to shuffle the dataset\n", " num_workers=4) # number processes/CPUs to use" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4) Iterating Through the Dataset" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1 | Batch index: 0 | Batch size: 32\n", "Epoch: 1 | Batch index: 1 | Batch size: 32\n", "Epoch: 1 | Batch index: 2 | Batch size: 32\n", "Epoch: 1 | Batch index: 3 | Batch size: 32\n", "Epoch: 1 | Batch index: 4 | Batch size: 32\n", "Epoch: 1 | Batch index: 5 | Batch size: 32\n", "Epoch: 1 | Batch index: 6 | Batch size: 32\n", "Epoch: 1 | Batch index: 7 | Batch size: 32\n", "Epoch: 2 | Batch index: 0 | Batch size: 32\n", "Epoch: 2 | Batch index: 1 | Batch size: 32\n", "Epoch: 2 | Batch index: 2 | Batch size: 32\n", "Epoch: 2 | Batch index: 3 | Batch size: 32\n", "Epoch: 2 | Batch index: 4 | Batch size: 32\n", "Epoch: 2 | Batch index: 5 | Batch size: 32\n", "Epoch: 2 | Batch index: 6 | Batch size: 32\n", "Epoch: 2 | Batch index: 7 | Batch size: 32\n" ] } ], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "torch.manual_seed(0)\n", "\n", "num_epochs = 2\n", "for epoch in range(num_epochs):\n", "\n", " for batch_idx, (x, y) in enumerate(train_loader):\n", " \n", " print('Epoch:', epoch+1, end='')\n", " print(' | Batch index:', batch_idx, end='')\n", " print(' | Batch size:', y.size()[0])\n", " \n", " x = x.to(device)\n", " y = y.to(device)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([32, 1, 28, 28])\n" ] } ], "source": [ "print(x.shape)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([32, 784])\n" ] } ], "source": [ "x_image_as_vector = x.view(-1, 28*28)\n", "print(x_image_as_vector.shape)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { 
"text/plain": [ "tensor([[[[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]],\n", "\n", "\n", " [[[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]],\n", "\n", "\n", " [[[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]],\n", "\n", "\n", " ...,\n", "\n", "\n", " [[[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]],\n", "\n", "\n", " [[[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]],\n", "\n", "\n", " [[[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]]])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" } }, "nbformat": 4, "nbformat_minor": 4 }