{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "09726bcc-9410-4397-8306-6e4eff4b9b5e", "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import seaborn as sns\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", "from torchvision import transforms, datasets\n", "import torch.nn.functional as F\n", "from sklearn.metrics import confusion_matrix\n", "from tqdm import tqdm" ] }, { "cell_type": "markdown", "id": "445fafab-6fa0-4279-b9de-1dcf0483be2a", "metadata": {}, "source": [ "## Download dataset" ] }, { "cell_type": "code", "execution_count": 2, "id": "60f60c74-9ca3-4514-a42f-b50389b70e27", "metadata": {}, "outputs": [], "source": [ "ROOT_DATA_DIR = \"FashionMNISTDir\"\n", "\n", "train_data = datasets.FashionMNIST(\n", " root = ROOT_DATA_DIR,\n", " train = True,\n", " download = True,\n", " transform = transforms.ToTensor()\n", " )\n", "\n", "\n", "test_data = datasets.FashionMNIST(\n", " root = ROOT_DATA_DIR,\n", " train = False, ## <<< Test data\n", " download = True,\n", " transform = transforms.ToTensor()\n", " )" ] }, { "cell_type": "code", "execution_count": 3, "id": "c2f04852-f194-4812-a6fe-8557fb145ee1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([60000, 28, 28])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data.data.shape" ] }, { "cell_type": "code", "execution_count": 4, "id": "adf58954-e425-47a6-97de-3674a2e209da", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([10000, 28, 28])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_data.data.shape" ] }, { "cell_type": "code", "execution_count": 5, "id": "50d627a6-55c8-445f-86ce-2f2202bcdf3e", "metadata": {}, "outputs": [], "source": [ "label_map = {\n", " 0: 'T-shirt/top',\n", " 1: 'Trouser',\n", " 2: 'Pullover',\n", " 3:' Dress',\n", " 4: 'Coat',\n", " 5: 'Sandal',\n", " 6: 'Shirt',\n", " 7: 'Sneaker',\n", " 8: 'Bag',\n", " 9: 'Ankle boot',\n", " }" ] }, { "cell_type": "markdown", "id": "a5ee85fb-9e21-4b3c-a5c6-2a088e80f2e7", "metadata": {}, "source": [ "## Visualize one sample" ] }, { "cell_type": "code", "execution_count": 6, "id": "1cc1b28d-a533-4566-81ae-a3f37471bcfd", "metadata": {}, "outputs": [], "source": [ "def view_sample_img(data, index, label_map):\n", " plt.imshow(data.data[index], cmap=\"gray\")\n", " plt.title(f\"data label: {label_map[data.targets[index].item()]}\")\n", " plt.axis(\"off\")" ] }, { "cell_type": "code", "execution_count": 7, "id": "76adb871-bff6-4292-b94c-94e47e8f07b1", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAOcAAAD3CAYAAADmIkO7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAPRUlEQVR4nO3dfYxc5XXH8d+xvWt718FrTAHbrbFqwJZrAQXUmigukZuCkga1hDSmJUj9wxLhD6xWoUWt1BZVcZMKKiWVFZWqEm/uS2RSqBoRUJFqlYogp7Ud4VBiagm8YHDtrdfvxmt4+sdcK8N27znLXo/3zOT7kVbs+sxz7527+9s7O4fnuVZKEYB8Zkz3AQCYGOEEkiKcQFKEE0iKcAJJEU4gqZ4Pp5k9ZmZfybYvM9tmZhumuJ8pj53i/h40sy1O/Ydm9skLdTw/KXo+nB/Fhf6h7zQz+yszO159nDGzsbavv3u+9lNK+blSyjbnOCYMt5n1m9khM5vXa+f+fCCcPayU8qVSyrxSyjxJfybpW+e+LqV8+kIcg5nNcsq/JGlXKeX4hTiWbtNz4TSznzezHWZ2zMy+JWlOW22BmX3HzA6a2eHq85+uapskrZW0ubqybK7+/RtmNmxmR83sP81s7SSPo3ZfbZab2fZq2/9kZhe3jV9jZi+Z2aiZ/eBCvGw0swfM7O3q3P3IzH65rdxvZk9UtR+a2Y1t494ws09Vnz9oZk+Z2RYzOyrpS5L+UNL66rz+oG2bn5H0rHPuP25m3zezI9V/P962z21m9tW689cTSik98yGpX9Kbkn5XUp+kz0sak/SVqr5Q0h2SBiR9TNJWSc+0jd8macO4bX6xGjdL0pclvStpTs3+H/uI+3pb0mpJg5K+LWlLVVsiaUStH94Zkn6l+vqnxh+npKWSRiUtDc7Ng+e2X1NfIWlY0uLq62WSlreNPV0dz0xJX5X0ctvYNyR9qu2xY5J+vTr2uXX7lvSapBUTnXtJF0s6LOnu6tz/ZvX1wuj89cpHr10516gVyq+XUsZKKU9J+v65YillpJTy7VLKyVLKMUmbJN3sbbCUsqUad7aU8heSZqv1g+ya5L6eLKXsLqWckPRHkr5gZjPV+oXwbCnl2VLKB6WUf5H0H2qFY/x+9pVShkop+6JjCrxfPbdVZtZXSnmjlLK3rf7v1fG8L+lJSdc62/peKeWZ6thPTfQAM1suaVYp5Uc12/hVSa+XUp6szv3fqxXm29oeU3f+ekKvhXOxpLdL9au18ua5T8xswMweMbM3q5dc/yZpyPuGmtn9ZvZf1UurUUnzJV0SHcgk9zU87jj7qm1fIek3qpe0o9V+PyFpUbTfyTKz77a9OXRXKeW/Jf2OWle5/zGzfzCzxW1D3m37/KSkOc7fk8M1/97uM5K8N6UWq+17V3lTrVcVE+2n/fz1hF4L5zuSlpiZtf3b0rbPv6zWVe8XSykXqfWGhCSde/yHpuhUf1/+vqQvSFpQShmSdKTt8Z5oX5L0M+OOc0zSIbV+6J6srojnPgZLKV+bxH4npZTy6fLjN4f+tvq3vyulfEKtXw5F0p9PdfPB11L196bzmP3VcbRbqtZL2XPqzl9P6LVwfk/SWUkbzazPzD4n6Rfa6h+TdErSaPXmwZ+MG39A0s+Oe/xZSQclzTKzP5Z00SSPJdqXJH3RzFaZ2YCkP5X0VPWycYuk28zsVjObaWZzzOyTE7yhdN6Y2QozW2dms9X6+/KUpA/O0+YPSFpmZjOqfQ2o9X3513GPaT/3z0q62sx+y8xmmdl6SaskfaftMXXnryf0VDhLKWckfU7Sb0v6X0nrJf1j20O+rtYbFIckvSzpuXGb+Iakz1fvrv6lpOerx+xR62XTaU3uJdtk9iW1/nZ7TNWbTJI2Vs9jWNKvqfUu58Fqn7+nCb5fZra0emm6dHztI5ot6WvV8b4r6VJJf9Bwm+dsrf47YmY7JK1T6+/S022P+dC5L6WMSPqsWq9ARtR6BfPZUkr7lXHC89cr7MN/ngGdZ2bflLS7lPLNBtvYpta7s39z3g4sGa9BDHTKLkn/PN0HkR3hxAVXSvnr6T6GbsDLWiCpnnpDCOgl7staM0t7Wf1wK/P/m85XBCtXrnTrmzdvrq1t3bq1tiZJO3fudOtnzpxx62NjY2599erVtbXbb7/dHbt37163/tBDD7n10dFRt96rSikT/jBz5QSSIpxAUoQTSIpwAkkRTiApwgkkRTiBpNz/Q6iTfc7p7FNed911bv3OO+9063fccYdbf/99f9bS4OBgbW3u3Lnu2IULF7r1TtqzZ49b/+ADf4bZihX+AhIHDhyorT3//PPu2Icfftit7969261PJ/qcQJchnEBShBNIinACSRFOICnCCSRFOIGkpq3P2dRFF/krVD7xxBO1tWuuucYdO2OG/zvr2LFjbv306dNu3ZtTGfVI+/r63Pr8+fPd+okTJ9y616vs9BzZOXPm1Nai/m9/f79bf/HFF9363Xff7dY7iT4n0GUIJ5AU4QSSIpxAUoQTSIpwAkl1bSvlhRdecOtXXDH+7nE/NjIy4o6Npj7NmuUvlH/27Fm3Hk2X80RtnmhpzJkzp35v2WjfndR0iuGiRf6tTW+99Va3/tprr7n1JmilAF2GcAJJEU4gKcIJJEU4gaQIJ5AU4QSSSntn6xtuuMGte31MSTp06FBtLepTRr1Ab2qTJC1ZssStDwwM1NaiXmJ0C7/ouUVT0rx+YjRdLervRlPt3nrrrSlvOxI97w0bNrj1+++/v9H+p4IrJ5AU4QSSIpxAUoQTSIpwAkkRTiApwgkklXY+Z9RX2rhxo1v3+pzRfM2ozxn1zB555BG3vn///tqa1+uTpMWLF7v1d955x603mQ86e/Zsd+y8efPc+vXXX+/W77vvvtqa9/2U4v5utJRqNH7ZsmVuvQnmcwJdhnACSRFOICnCCSRFOIGkCCeQFOEEkkrb53z55Zfd+qWXXurWvbmD0dquUb/uyJEjbn3NmjVu/ZZbbqmtRXNBH330Ubd+zz33uPXdu3e7de9We1H/98CBA259165dbv3111+vrUVzQaM5ttF80JUrV7r11atX19b27Nnjjo3Q5wS6DOEEkiKcQFKEE0iKcAJJEU4gqbRLY1577bVufXh42K17U6OiqU+RaPpR5LnnnqutnThxwh27atUqtx5NtXv66afd+m233VZbi6ZV7dixw61Hy5167Y7BwUF3bDSNL5omuG/fPrd+00031daatlLqcOUEkiKcQFKEE0iKcAJJEU4gKcIJJEU4gaSmrc/pTcGRpIMHD7r1aAqQN73Ju82d5E+bkqSRkRG3HvGe+3vvveeOXbRokVvftGmTW4+eu3eLwWis1wucDG/J0GgqXdM+56lTp9z62rVra2uPP/64O3aquHICSRFOICnCCSRFOIGkCCeQFOEEkiKcQFLT1ud84IEH3HrUazx+/Lhb9/pe0bZPnz7t1qMe64033ujWFy5cWFu7+OKL3bF9fX1u/bLLLnPrXh9T8p97f3+/O3ZoaMitr1+/3q0vWLCgthb1Ie
fPn+/Wo/HRc4u+p53AlRNIinACSRFOICnCCSRFOIGkCCeQFOEEkpq2PudLL73k1i+//HK3fuWVV7p1b23ZaA1U71Z0Ujx3MLp9oTe3MJp3GO07uk1ftPasN2cz2re3VrAU38bPW/91YGDAHRs97+jYvLmkkvTMM8+49U7gygkkRTiBpAgnkBThBJIinEBShBNIinACSVkppb5oVl+cZt7cP0m66qqramv33nuvO/bmm29269G9QaO5haOjo7W1aL5m1M/rpGjd2qiXGM2T9c7bK6+84o6966673HpmpZQJTyxXTiApwgkkRTiBpAgnkBThBJIinEBS0zZlrKnDhw+79e3bt9fWotvsrVu3zq177ScpXmbRm7IWtUqiKWWRqB3i1aN9z549262fOXPGrc+ZM6e2Fk0x7EVcOYGkCCeQFOEEkiKcQFKEE0iKcAJJEU4gqbR9zqgfF02t8npqUZ/y6NGjbj3qRUZLSEb790Tnpcm2O63JdDdvmt352HfUw52O88qVE0iKcAJJEU4gKcIJJEU4gaQIJ5AU4QSSStvnjPpKY2NjU9723r173XrU54xuoxfNW/REz7vTfc5o+57oeUe9aU/0PYlEy3ZGvenpwJUTSIpwAkkRTiApwgkkRTiBpAgnkBThBJJK2+eMNOlbnTp1yh0b9eui9VnPnj3r1r0+adM+ZpN1aSX/vEb7jtYDHhgYcOvesUXntBdx5QSSIpxAUoQTSIpwAkkRTiApwgkkRTiBpLq2z9lk3mK0RmnTdWejetSj9UTH3mRtWMnvNUbHHT3v6Nib9FgjmdfzrcOVE0iKcAJJEU4gKcIJJEU4gaQIJ5BU17ZSOmnJkiVu/fDhw249amd4b+tH7YomS1d2WnTs0XKm3nNr2iLqRlw5gaQIJ5AU4QSSIpxAUoQTSIpwAkkRTiCpru1zdnIKUNNlGPv7+926NyWt6dKWnVxaM5ryFd3iL1o60zu2JrcPjLadFVdOICnCCSRFOIGkCCeQFOEEkiKcQFKEE0iqa/ucnRT146K5hVGf1Bsf9RKjfl10bNHtDb3te7cujMZK0smTJ926Z2hoaMpjuxVXTiApwgkkRTiBpAgnkBThBJIinEBShBNIij7nBKJeY1PenMmm8w47ue5tk7mgkxnv9Yfnzp3rjo0wnxPAeUM4gaQIJ5AU4QSSIpxAUoQTSIpWygSidkRTnXxbfzpbKdG+m7RSBgYG3LG9iCsnkBThBJIinEBShBNIinACSRFOICnCCSTVtX3O6ZwCFC0/2UTTaVmRJsfe6els3q0RO3nOs+LKCSRFOIGkCCeQFOEEkiKcQFKEE0iKcAJJdW2fs+kyjJ7oNnmdnFsYLcvZ9PaDnTxvTXWyz8nSmADOG8IJJEU4gaQIJ5AU4QSSIpxAUoQTSKpr+5zTqcm8RMnvNUbbblqP+qjTua6th/mcANIgnEBShBNIinACSRFOICnCCSRFOIGkurbP2cn5efv373frV199tVuP5lR6vcaoD9nX1zflbU+m7p3XqH87a1azHydv38znBJAG4QSSIpxAUoQTSIpwAkkRTiCprm2ldNLQ0JBbHxwcdOtRS+GSSy6prTWdEha1WpqIWilRu2N4eNite0uOLl++3B0baTqVbjpw5QSSIpxAUoQTSIpwAkkRTiApwgkkRTiBpLq2z9nJW9nt3LnTrb/66qtufXR01K036UVG/brjx4+79ei8eOe1yVQ4Kb614oIFC2pr27dvd8dGMvYxI1w5gaQIJ5AU4QSSIpxAUoQTSIpwAkkRTiAp68YlA4GfBFw5gaQIJ5AU4QSSIpxAUoQTSIpwAkn9H5vkccLt/ncCAAAAAElFTkSuQmCC", "text/plain": [ "
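{ "cell_type": "markdown", "id": "added-grid-note", "metadata": {}, "source": [ "An added sketch (not part of the original run): the same idea extended to a small grid, to eyeball several classes at once. It only reuses `train_data` and `label_map` from above." ] },
{ "cell_type": "code", "execution_count": null, "id": "added-grid-code", "metadata": {}, "outputs": [], "source": [ "## added sketch: view the first 8 training images in a 2x4 grid\n", "fig, axes = plt.subplots(2, 4, figsize=(8, 4))\n", "for i, ax in enumerate(axes.flat):\n", " ax.imshow(train_data.data[i], cmap=\"gray\")\n", " ax.set_title(label_map[train_data.targets[i].item()])\n", " ax.axis(\"off\")\n", "plt.tight_layout()" ] },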
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "view_sample_img(train_data, index=1, label_map=label_map)" ] }, { "cell_type": "markdown", "id": "13642d17-c534-40fb-86e7-e5234a6127d0", "metadata": {}, "source": [ "## Create the dataloader" ] }, { "cell_type": "code", "execution_count": 8, "id": "4b7a8c70-14d0-4a5f-9700-0fd317929074", "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 64\n", "\n", "train_data_loader = DataLoader(\n", " dataset = train_data,\n", " batch_size = BATCH_SIZE,\n", " shuffle = True\n", " )\n", "\n", "test_data_loader = DataLoader(\n", " dataset = test_data,\n", " batch_size = BATCH_SIZE,\n", " shuffle = True\n", " )" ] }, { "cell_type": "code", "execution_count": 9, "id": "c216a67f-9aab-4013-a270-5dbc8cbd8ff5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([64, 1, 28, 28])\n", "torch.Size([64])\n" ] } ], "source": [ "for data, label in test_data_loader:\n", " print(data.shape) \n", " print(label.shape)\n", " break" ] }, { "cell_type": "markdown", "id": "697498b5-fbd2-47e1-8823-c983cc9c85c2", "metadata": {}, "source": [ "## CNN architecture\n", "\n", "pytorch doc - [reference](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html?highlight=conv2d#torch.nn.Conv2d)" ] }, { "cell_type": "code", "execution_count": 10, "id": "a837c0ec-c0ef-4bc9-96e4-e1cbc52a44d2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'cuda'" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "device" ] }, { "cell_type": "code", "execution_count": 11, "id": "94510485-d718-4912-95c9-1f06f22257ad", "metadata": {}, "outputs": [], "source": [ "class CNN(nn.Module):\n", " def __init__(self, in_, out_):\n", " super(CNN, self).__init__()\n", " \n", " self.conv_pool_01 = nn.Sequential(\n", " nn.Conv2d(in_channels=in_, out_channels=8, kernel_size=5, stride=1, padding=0),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=2, stride=2)\n", " )\n", " \n", " self.conv_pool_02 = nn.Sequential(\n", " nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=0),\n", " nn.ReLU(),\n", " nn.MaxPool2d(kernel_size=2, stride=2)\n", " )\n", " \n", " self.Flatten = nn.Flatten()\n", " self.FC_01 = nn.Linear(in_features=16*4*4, out_features=128)\n", " self.FC_02 = nn.Linear(in_features=128, out_features=64)\n", " self.FC_03 = nn.Linear(in_features=64, out_features=out_)\n", " \n", " \n", " def forward(self, x):\n", " x = self.conv_pool_01(x)\n", " x = self.conv_pool_02(x)\n", " x = self.Flatten(x)\n", " x = self.FC_01(x)\n", " x = self.FC_02(x) \n", " x = self.FC_03(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 12, "id": "62ee8c3b-01e2-4c06-b3a0-25ababc457e6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CNN(\n", " (conv_pool_01): Sequential(\n", " (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))\n", " (1): ReLU()\n", " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (conv_pool_02): Sequential(\n", " (0): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1))\n", " (1): ReLU()\n", " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (Flatten): Flatten(start_dim=1, end_dim=-1)\n", " (FC_01): Linear(in_features=256, out_features=128, bias=True)\n", " (FC_02): Linear(in_features=128, out_features=64, bias=True)\n", " (FC_03): Linear(in_features=64, 
{ "cell_type": "code", "execution_count": 12, "id": "62ee8c3b-01e2-4c06-b3a0-25ababc457e6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CNN(\n", " (conv_pool_01): Sequential(\n", " (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))\n", " (1): ReLU()\n", " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (conv_pool_02): Sequential(\n", " (0): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1))\n", " (1): ReLU()\n", " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (Flatten): Flatten(start_dim=1, end_dim=-1)\n", " (FC_01): Linear(in_features=256, out_features=128, bias=True)\n", " (FC_02): Linear(in_features=128, out_features=64, bias=True)\n", " (FC_03): Linear(in_features=64, out_features=10, bias=True)\n", ")" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = CNN(1, 10)\n", "model.to(device)" ] },
{ "cell_type": "markdown", "id": "1ce4336d-8d13-4a8d-a586-393229fe248d", "metadata": {}, "source": [ "## Count no. of trainable params" ] },
{ "cell_type": "code", "execution_count": 13, "id": "81b2554a-c504-442f-9363-7a6b51ed1364", "metadata": {}, "outputs": [ { "data": { "text/html": [ "<table>\n", "<caption>Total trainable parameters: 45226</caption>\n", "<tr><th></th><th>Modules</th><th>Parameters</th></tr>\n", "<tr><td>0</td><td>conv_pool_01.0.weight</td><td>200</td></tr>\n", "<tr><td>1</td><td>conv_pool_01.0.bias</td><td>8</td></tr>\n", "<tr><td>2</td><td>conv_pool_02.0.weight</td><td>3200</td></tr>\n", "<tr><td>3</td><td>conv_pool_02.0.bias</td><td>16</td></tr>\n", "<tr><td>4</td><td>FC_01.weight</td><td>32768</td></tr>\n", "<tr><td>5</td><td>FC_01.bias</td><td>128</td></tr>\n", "<tr><td>6</td><td>FC_02.weight</td><td>8192</td></tr>\n", "<tr><td>7</td><td>FC_02.bias</td><td>64</td></tr>\n", "<tr><td>8</td><td>FC_03.weight</td><td>640</td></tr>\n", "<tr><td>9</td><td>FC_03.bias</td><td>10</td></tr>\n", "</table>" ], "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def count_params(model):\n", " model_params = {\"Modules\": list(), \"Parameters\": list()}\n", " total = 0\n", " for name, parameters in model.named_parameters():\n", " if not parameters.requires_grad:\n", " continue\n", " param = parameters.numel()\n", " model_params[\"Modules\"].append(name)\n", " model_params[\"Parameters\"].append(param)\n", " total += param\n", " df = pd.DataFrame(model_params)\n", " df = df.style.set_caption(f\"Total trainable parameters: {total}\")\n", " return df\n", "\n", "count_params(model)" ] },
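{ "cell_type": "markdown", "id": "added-param-count-note", "metadata": {}, "source": [ "Where these numbers come from (an added note): a `Conv2d` layer has `out_channels * in_channels * kH * kW` weights plus `out_channels` biases, and a `Linear` layer has `out_features * in_features` weights plus `out_features` biases. For example, `conv_pool_01.0.weight` = 8 * 1 * 5 * 5 = 200 and `FC_01.weight` = 128 * 256 = 32768; summing all ten rows gives the 45226 total in the caption above." ] },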
\n" ], "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def count_params(model):\n", " model_params = {\"Modules\": list(), \"Parameters\": list()}\n", " total = 0\n", " for name, parameters in model.named_parameters():\n", " if not parameters.requires_grad:\n", " continue\n", " param = parameters.numel()\n", " model_params[\"Modules\"].append(name)\n", " model_params[\"Parameters\"].append(param)\n", " total += param\n", " df = pd.DataFrame(model_params)\n", " df = df.style.set_caption(f\"Total trainable parameters: {total}\")\n", " return df\n", "\n", "count_params(model)" ] }, { "cell_type": "code", "execution_count": 14, "id": "f9c79c55-2bc0-416b-adfd-c5f6d05ad02d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "next(model.parameters()).is_cuda" ] }, { "cell_type": "markdown", "id": "27898820-48dc-4d82-a5b3-33d431baf79e", "metadata": {}, "source": [ "## Traning loop" ] }, { "cell_type": "code", "execution_count": 15, "id": "4a5d7fad-e433-49f1-a06a-fe48fb7684fd", "metadata": {}, "outputs": [], "source": [ "learning_rate = 0.001\n", "num_epochs = 20" ] }, { "cell_type": "code", "execution_count": 16, "id": "9124fc31-1b9d-4a16-995d-9d2dd178238f", "metadata": {}, "outputs": [], "source": [ "criterion = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)" ] }, { "cell_type": "code", "execution_count": 17, "id": "ee45f38a-da2b-4835-9507-3f368198eb41", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "938" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n_total_steps = len(train_data_loader)\n", "n_total_steps" ] }, { "cell_type": "code", "execution_count": 18, "id": "9b3fe689-eaba-4c86-b221-48cc89428248", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "937.5" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "60000/BATCH_SIZE" ] }, { "cell_type": "code", "execution_count": 19, "id": "5860758d-95b1-42c2-a285-544254fb9936", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1/20: 100%|██████████| 938/938 [00:20<00:00, 46.28it/s, loss=0.405]\n", "Epoch 2/20: 100%|██████████| 938/938 [00:14<00:00, 62.66it/s, loss=0.461]\n", "Epoch 3/20: 100%|██████████| 938/938 [00:15<00:00, 58.97it/s, loss=0.279]\n", "Epoch 4/20: 100%|██████████| 938/938 [00:15<00:00, 59.80it/s, loss=0.144] \n", "Epoch 5/20: 100%|██████████| 938/938 [00:15<00:00, 61.09it/s, loss=0.336]\n", "Epoch 6/20: 100%|██████████| 938/938 [00:15<00:00, 60.48it/s, loss=0.443] \n", "Epoch 7/20: 100%|██████████| 938/938 [00:15<00:00, 60.35it/s, loss=0.334] \n", "Epoch 8/20: 100%|██████████| 938/938 [00:15<00:00, 60.15it/s, loss=0.413] \n", "Epoch 9/20: 100%|██████████| 938/938 [00:15<00:00, 59.85it/s, loss=0.209] \n", "Epoch 10/20: 100%|██████████| 938/938 [00:15<00:00, 60.09it/s, loss=0.228] \n", "Epoch 11/20: 100%|██████████| 938/938 [00:15<00:00, 58.82it/s, loss=0.212] \n", "Epoch 12/20: 100%|██████████| 938/938 [00:15<00:00, 59.74it/s, loss=0.0203]\n", "Epoch 13/20: 100%|██████████| 938/938 [00:15<00:00, 59.96it/s, loss=0.437] \n", "Epoch 14/20: 100%|██████████| 938/938 [00:15<00:00, 60.11it/s, loss=0.38] \n", "Epoch 15/20: 100%|██████████| 938/938 [00:15<00:00, 60.20it/s, loss=0.292] \n", "Epoch 16/20: 100%|██████████| 938/938 [00:15<00:00, 59.15it/s, loss=0.177] \n", "Epoch 
], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" } }, "nbformat": 4, "nbformat_minor": 5 }