{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.6.8\n", "IPython 7.2.0\n", "\n", "torch 1.1.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Zoo -- Ordinal Regression CNN -- CORAL" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementation of CORAL [1] (COnsistent RAnk Logits), a method for ordinal regression, applied to predicting age from face images in the Asian Face Age Dataset (AFAD) [2], using a simple ResNet-34 [3] convolutional network architecture.\n", "\n", "Note that, to reduce training time, only a subset of AFAD (AFAD-Lite) is used.\n", "\n", "- [1] Cao, Wenzhi, Vahid Mirjalili, and Sebastian Raschka. \"[Consistent Rank Logits for Ordinal Regression with Convolutional Neural Networks](https://arxiv.org/abs/1901.07884).\" arXiv preprint arXiv:1901.07884 (2019).\n", "- [2] Niu, Zhenxing, Mo Zhou, Le Wang, Xinbo Gao, and Gang Hua. \"[Ordinal regression with multiple output cnn for age estimation](https://ieeexplore.ieee.org/document/7780901/).\" In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 4920-4928. 2016.\n", "- [3] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. 
\"[Deep residual learning for image recognition](http://openaccess.thecvf.com/content_cvpr_2016/html/He_Deep_Residual_Learning_CVPR_2016_paper.html).\" In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770-778. 2016." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "import pandas as pd\n", "import os\n", "\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch\n", "\n", "from torch.utils.data import Dataset\n", "from torch.utils.data import DataLoader\n", "from torchvision import transforms\n", "from PIL import Image\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Downloading the Dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'tarball-lite'...\n", "remote: Enumerating objects: 37, done.\u001b[K\n", "remote: Total 37 (delta 0), reused 0 (delta 0), pack-reused 37\u001b[K\n", "Unpacking objects: 100% (37/37), done.\n", "Checking out files: 100% (30/30), done.\n" ] } ], "source": [ "!git clone https://github.com/afad-dataset/tarball-lite.git" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "!cat tarball-lite/AFAD-Lite.tar.xz* > tarball-lite/AFAD-Lite.tar.xz" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "!tar xf tarball-lite/AFAD-Lite.tar.xz" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "rootDir = 'AFAD-Lite'\n", "\n", "files = [os.path.relpath(os.path.join(dirpath, file), rootDir)\n", " for (dirpath, dirnames, filenames) in os.walk(rootDir) \n", " for file in filenames if file.endswith('.jpg')]" 
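] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The parsing cells below rely on the directory layout of AFAD-Lite: each relative path in `files` has the form `age/gender_code/file_name.jpg`, and the code treats gender code `111` as male and any other code (`112` in the `df.head()` output below) as female. The following minimal sketch, under that assumption, uses a hypothetical helper `parse_afad_path` (not part of the original notebook) to parse one such path and to shift the age to the 0-based ordinal rank used for training, assuming the minimum age of 18 reported below." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical helper, not part of the original notebook.\n", "# Assumes paths of the form 'age/gender_code/file_name.jpg' and a minimum age of 18.\n", "def parse_afad_path(rel_path, min_age=18):\n", "    age, gender_code, fname = rel_path.split('/')\n", "    gender = 'male' if gender_code == '111' else 'female'\n", "    # Shift ages so that the youngest age maps to ordinal rank 0\n", "    rank = int(age) - min_age\n", "    return rank, gender, fname\n", "\n", "\n", "parse_afad_path('39/112/474596-0.jpg')  # -> (21, 'female', '474596-0.jpg')"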
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "59344" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(files)" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "d = {}\n", "\n", "d['age'] = []\n", "d['gender'] = []\n", "d['file'] = []\n", "d['path'] = []\n", "\n", "for f in files:\n", "    age, gender, fname = f.split('/')\n", "    if gender == '111':\n", "        gender = 'male'\n", "    else:\n", "        gender = 'female'\n", "\n", "    d['age'].append(age)\n", "    d['gender'].append(gender)\n", "    d['file'].append(fname)\n", "    d['path'].append(f)" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>age</th>\n", "      <th>gender</th>\n", "      <th>file</th>\n", "      <th>path</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>39</td>\n", "      <td>female</td>\n", "      <td>474596-0.jpg</td>\n", "      <td>39/112/474596-0.jpg</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>39</td>\n", "      <td>female</td>\n", "      <td>397477-0.jpg</td>\n", "      <td>39/112/397477-0.jpg</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>39</td>\n", "      <td>female</td>\n", "      <td>576466-0.jpg</td>\n", "      <td>39/112/576466-0.jpg</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>39</td>\n", "      <td>female</td>\n", "      <td>399405-0.jpg</td>\n", "      <td>39/112/399405-0.jpg</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>39</td>\n", "      <td>female</td>\n", "      <td>410524-0.jpg</td>\n", "      <td>39/112/410524-0.jpg</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ "   age  gender          file                 path\n", "0   39  female  474596-0.jpg  39/112/474596-0.jpg\n", "1   39  female  397477-0.jpg  39/112/397477-0.jpg\n", "2   39  female  576466-0.jpg  39/112/576466-0.jpg\n", "3   39  female  399405-0.jpg  39/112/399405-0.jpg\n", "4   39  female  410524-0.jpg  39/112/410524-0.jpg" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.DataFrame.from_dict(d)\n", "df.head()" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'18'" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['age'].min()" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "df['age'] = df['age'].values.astype(int) - 18" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "np.random.seed(123)\n", "msk = np.random.rand(len(df)) < 0.8\n", "df_train = df[msk]\n", "df_test = df[~msk]" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "df_train.set_index('file', inplace=True)\n", "df_train.to_csv('training_set_lite.csv')" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "df_test.set_index('file', inplace=True)\n", "df_test.to_csv('test_set_lite.csv')" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "22\n" ] } ], "source": [ "num_ages = np.unique(df['age'].values).shape[0]\n", "print(num_ages)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Settings" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Device\n", "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "NUM_WORKERS = 4\n", "\n", "NUM_CLASSES = 
num_ages\n", "BATCH_SIZE = 512\n", "NUM_EPOCHS = 150\n", "LEARNING_RATE = 0.0005\n", "RANDOM_SEED = 123\n", "GRAYSCALE = False\n", "\n", "TRAIN_CSV_PATH = 'training_set_lite.csv'\n", "TEST_CSV_PATH = 'test_set_lite.csv'\n", "IMAGE_PATH = 'AFAD-Lite'" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset Loaders" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "class AFADDatasetAge(Dataset):\n", "    \"\"\"Custom Dataset for loading AFAD face images\"\"\"\n", "\n", "    def __init__(self, csv_path, img_dir, transform=None):\n", "\n", "        df = pd.read_csv(csv_path, index_col=0)\n", "        self.img_dir = img_dir\n", "        self.csv_path = csv_path\n", "        # NumPy array, so that img_paths[index] indexes by position\n", "        self.img_paths = df['path'].values\n", "        self.y = df['age'].values\n", "        self.transform = transform\n", "\n", "    def __getitem__(self, index):\n", "        img = Image.open(os.path.join(self.img_dir,\n", "                                      self.img_paths[index]))\n", "\n", "        if self.transform is not None:\n", "            img = self.transform(img)\n", "\n", "        label = self.y[index]\n", "        # CORAL extended binary labels: levels[k] = 1 if label > k else 0\n", "        levels = [1]*label + [0]*(NUM_CLASSES - 1 - label)\n", "        levels = torch.tensor(levels, dtype=torch.float32)\n", "\n", "        return img, label, levels\n", "\n", "    def __len__(self):\n", "        return self.y.shape[0]\n", "\n", "\n", "custom_transform = transforms.Compose([transforms.Resize((128, 128)),\n", "                                       transforms.RandomCrop((120, 120)),\n", "                                       transforms.ToTensor()])\n", "\n", "train_dataset = AFADDatasetAge(csv_path=TRAIN_CSV_PATH,\n", "                               img_dir=IMAGE_PATH,\n", "                               transform=custom_transform)\n", "\n", "\n", "custom_transform2 = transforms.Compose([transforms.Resize((128, 128)),\n", "                                        transforms.CenterCrop((120, 120)),\n", "                                        transforms.ToTensor()])\n", "\n", "test_dataset = AFADDatasetAge(csv_path=TEST_CSV_PATH,\n", "                              img_dir=IMAGE_PATH,\n", "                              transform=custom_transform2)\n", "\n", "\n", "train_loader = DataLoader(dataset=train_dataset,\n", "                          batch_size=BATCH_SIZE,\n", "                          shuffle=True,\n", "                          num_workers=NUM_WORKERS)\n", "\n", "test_loader = 
DataLoader(dataset=test_dataset,\n", "                         batch_size=BATCH_SIZE,\n", "                         shuffle=False,\n", "                         num_workers=NUM_WORKERS)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "##########################\n", "# MODEL\n", "##########################\n", "\n", "\n", "def conv3x3(in_planes, out_planes, stride=1):\n", "    \"\"\"3x3 convolution with padding\"\"\"\n", "    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n", "                     padding=1, bias=False)\n", "\n", "\n", "class BasicBlock(nn.Module):\n", "    expansion = 1\n", "\n", "    def __init__(self, inplanes, planes, stride=1, downsample=None):\n", "        super(BasicBlock, self).__init__()\n", "        self.conv1 = conv3x3(inplanes, planes, stride)\n", "        self.bn1 = nn.BatchNorm2d(planes)\n", "        self.relu = nn.ReLU(inplace=True)\n", "        self.conv2 = conv3x3(planes, planes)\n", "        self.bn2 = nn.BatchNorm2d(planes)\n", "        self.downsample = downsample\n", "        self.stride = stride\n", "\n", "    def forward(self, x):\n", "        residual = x\n", "\n", "        out = self.conv1(x)\n", "        out = self.bn1(out)\n", "        out = self.relu(out)\n", "\n", "        out = self.conv2(out)\n", "        out = self.bn2(out)\n", "\n", "        if self.downsample is not None:\n", "            residual = self.downsample(x)\n", "\n", "        out += residual\n", "        out = self.relu(out)\n", "\n", "        return out\n", "\n", "\n", "class ResNet(nn.Module):\n", "\n", "    def __init__(self, block, layers, num_classes, grayscale):\n", "        super(ResNet, self).__init__()\n", "        self.num_classes = num_classes\n", "        self.inplanes = 64\n", "        if grayscale:\n", "            in_dim = 1\n", "        else:\n", "            in_dim = 3\n", "        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,\n", "                               bias=False)\n", "        self.bn1 = nn.BatchNorm2d(64)\n", "        self.relu = nn.ReLU(inplace=True)\n", "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n", "        self.layer1 = self._make_layer(block, 64, layers[0])\n", "        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n", "        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n", "        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n", "        # For 120x120 inputs, layer4 outputs 512 x 4 x 4 and this pooling\n", "        # yields 512 x 2 x 2 = 2048 features\n", "        self.avgpool = nn.AvgPool2d(7, stride=1, padding=2)\n", "        # CORAL: one output unit whose weights are shared by all K-1 binary tasks...\n", "        self.fc = nn.Linear(2048 * block.expansion, 1, bias=False)\n", "        # ...plus K-1 independent bias terms, one per rank threshold\n", "        self.linear_1_bias = nn.Parameter(torch.zeros(self.num_classes-1).float())\n", "\n", "        for m in self.modules():\n", "            if isinstance(m, nn.Conv2d):\n", "                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n", "                m.weight.data.normal_(0, (2. / n)**.5)\n", "            elif isinstance(m, nn.BatchNorm2d):\n", "                m.weight.data.fill_(1)\n", "                m.bias.data.zero_()\n", "\n", "    def _make_layer(self, block, planes, blocks, stride=1):\n", "        downsample = None\n", "        if stride != 1 or self.inplanes != planes * block.expansion:\n", "            downsample = nn.Sequential(\n", "                nn.Conv2d(self.inplanes, planes * block.expansion,\n", "                          kernel_size=1, stride=stride, bias=False),\n", "                nn.BatchNorm2d(planes * block.expansion),\n", "            )\n", "\n", "        layers = []\n", "        layers.append(block(self.inplanes, planes, stride, downsample))\n", "        self.inplanes = planes * block.expansion\n", "        for i in range(1, blocks):\n", "            layers.append(block(self.inplanes, planes))\n", "\n", "        return nn.Sequential(*layers)\n", "\n", "    def forward(self, x):\n", "        x = self.conv1(x)\n", "        x = self.bn1(x)\n", "        x = self.relu(x)\n", "        x = self.maxpool(x)\n", "\n", "        x = self.layer1(x)\n", "        x = self.layer2(x)\n", "        x = self.layer3(x)\n", "        x = self.layer4(x)\n", "\n", "        x = self.avgpool(x)\n", "        x = x.view(x.size(0), -1)\n", "        logits = self.fc(x)\n", "        logits = logits + self.linear_1_bias\n", "        probas = torch.sigmoid(logits)\n", "        return logits, probas\n", "\n", "\n", "def resnet34(num_classes, grayscale):\n", "    \"\"\"Constructs a ResNet-34 model.\"\"\"\n", "    model = ResNet(block=BasicBlock,\n", "                   layers=[3, 4, 6, 3],\n", "                   num_classes=num_classes,\n", "                   grayscale=grayscale)\n", "    return model" ] }, { "cell_type": "code", 
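"execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch, not part of the original notebook: a toy check of the CORAL pieces used in\n", "# this notebook (extended binary labels, loss, and rank prediction). Assumptions: K = 5\n", "# ordinal classes instead of NUM_CLASSES = 22, and hand-picked logits standing in for\n", "# the model output fc(x) + linear_1_bias.\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "K = 5\n", "labels = torch.tensor([0, 2, 4])\n", "\n", "# Extended binary labels: task k is 1 iff label > k, matching\n", "# [1]*label + [0]*(NUM_CLASSES - 1 - label) in AFADDatasetAge\n", "levels = torch.stack([(labels > k).float() for k in range(K - 1)], dim=1)\n", "print(levels)  # rows: [0,0,0,0], [1,1,0,0], [1,1,1,1]\n", "\n", "logits = torch.tensor([[-2.0, -3.0, -4.0, -5.0],\n", "                       [3.0, 1.0, -1.0, -2.0],\n", "                       [4.0, 3.0, 2.0, 1.0]])\n", "\n", "# log(1 - sigmoid(z)) == logsigmoid(z) - z, the identity used by cost_fn below\n", "assert torch.allclose(torch.log(1 - torch.sigmoid(logits)),\n", "                      F.logsigmoid(logits) - logits, atol=1e-5)\n", "\n", "# The CORAL loss equals binary cross-entropy summed over the K-1 tasks\n", "# and averaged over the examples in the batch\n", "coral_loss = -torch.sum(F.logsigmoid(logits) * levels\n", "                        + (F.logsigmoid(logits) - logits) * (1 - levels), dim=1).mean()\n", "bce = F.binary_cross_entropy_with_logits(logits, levels, reduction='sum') / logits.size(0)\n", "assert torch.allclose(coral_loss, bce)\n", "\n", "# Predicted rank = number of tasks with probability > 0.5, as in compute_mae_and_mse\n", "predicted = torch.sum(torch.sigmoid(logits) > 0.5, dim=1)\n", "print(predicted)  # tensor([0, 2, 4])" ] }, { "cell_type": "code", 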
"execution_count": 19, "metadata": {}, "outputs": [], "source": [ "###########################################\n", "# Initialize Cost, Model, and Optimizer\n", "###########################################\n", "\n", "def cost_fn(logits, levels):\n", " val = (-torch.sum((F.logsigmoid(logits)*levels\n", " + (F.logsigmoid(logits) - logits)*(1-levels)),\n", " dim=1))\n", " return torch.mean(val)\n", "\n", "\n", "torch.manual_seed(RANDOM_SEED)\n", "torch.cuda.manual_seed(RANDOM_SEED)\n", "model = resnet34(NUM_CLASSES, GRAYSCALE)\n", "\n", "model.to(DEVICE)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/150 | Batch 0000/0092 | Cost: 15.0424\n", "Time elapsed: 0.91 min\n", "Epoch: 002/150 | Batch 0000/0092 | Cost: 12.5222\n", "Time elapsed: 1.83 min\n", "Epoch: 003/150 | Batch 0000/0092 | Cost: 12.0170\n", "Time elapsed: 2.77 min\n", "Epoch: 004/150 | Batch 0000/0092 | Cost: 11.6722\n", "Time elapsed: 3.71 min\n", "Epoch: 005/150 | Batch 0000/0092 | Cost: 11.2609\n", "Time elapsed: 4.65 min\n", "Epoch: 006/150 | Batch 0000/0092 | Cost: 10.9205\n", "Time elapsed: 5.59 min\n", "Epoch: 007/150 | Batch 0000/0092 | Cost: 11.2049\n", "Time elapsed: 6.54 min\n", "Epoch: 008/150 | Batch 0000/0092 | Cost: 10.4912\n", "Time elapsed: 7.50 min\n", "Epoch: 009/150 | Batch 0000/0092 | Cost: 10.2098\n", "Time elapsed: 8.46 min\n", "Epoch: 010/150 | Batch 0000/0092 | Cost: 10.0003\n", "Time elapsed: 9.41 min\n", "Epoch: 011/150 | Batch 0000/0092 | Cost: 9.9253\n", "Time elapsed: 10.36 min\n", "Epoch: 012/150 | Batch 0000/0092 | Cost: 9.5460\n", "Time elapsed: 11.31 min\n", "Epoch: 013/150 | Batch 0000/0092 | Cost: 9.3959\n", "Time elapsed: 12.26 min\n", "Epoch: 014/150 | Batch 0000/0092 | Cost: 9.2571\n", "Time elapsed: 13.21 min\n", "Epoch: 
015/150 | Batch 0000/0092 | Cost: 8.9584\n", "Time elapsed: 14.16 min\n", "Epoch: 016/150 | Batch 0000/0092 | Cost: 8.7167\n", "Time elapsed: 15.13 min\n", "Epoch: 017/150 | Batch 0000/0092 | Cost: 8.5358\n", "Time elapsed: 16.09 min\n", "Epoch: 018/150 | Batch 0000/0092 | Cost: 8.2750\n", "Time elapsed: 17.05 min\n", "Epoch: 019/150 | Batch 0000/0092 | Cost: 8.0949\n", "Time elapsed: 18.02 min\n", "Epoch: 020/150 | Batch 0000/0092 | Cost: 7.8891\n", "Time elapsed: 19.00 min\n", "Epoch: 021/150 | Batch 0000/0092 | Cost: 7.7796\n", "Time elapsed: 19.96 min\n", "Epoch: 022/150 | Batch 0000/0092 | Cost: 7.5751\n", "Time elapsed: 20.92 min\n", "Epoch: 023/150 | Batch 0000/0092 | Cost: 7.4735\n", "Time elapsed: 21.88 min\n", "Epoch: 024/150 | Batch 0000/0092 | Cost: 7.3455\n", "Time elapsed: 22.84 min\n", "Epoch: 025/150 | Batch 0000/0092 | Cost: 7.1206\n", "Time elapsed: 23.80 min\n", "Epoch: 026/150 | Batch 0000/0092 | Cost: 7.0157\n", "Time elapsed: 24.75 min\n", "Epoch: 027/150 | Batch 0000/0092 | Cost: 6.8013\n", "Time elapsed: 25.71 min\n", "Epoch: 028/150 | Batch 0000/0092 | Cost: 6.6981\n", "Time elapsed: 26.67 min\n", "Epoch: 029/150 | Batch 0000/0092 | Cost: 6.6510\n", "Time elapsed: 27.62 min\n", "Epoch: 030/150 | Batch 0000/0092 | Cost: 6.4859\n", "Time elapsed: 28.58 min\n", "Epoch: 031/150 | Batch 0000/0092 | Cost: 6.3101\n", "Time elapsed: 29.54 min\n", "Epoch: 032/150 | Batch 0000/0092 | Cost: 6.2179\n", "Time elapsed: 30.49 min\n", "Epoch: 033/150 | Batch 0000/0092 | Cost: 6.2418\n", "Time elapsed: 31.45 min\n", "Epoch: 034/150 | Batch 0000/0092 | Cost: 6.0992\n", "Time elapsed: 32.41 min\n", "Epoch: 035/150 | Batch 0000/0092 | Cost: 5.9214\n", "Time elapsed: 33.37 min\n", "Epoch: 036/150 | Batch 0000/0092 | Cost: 5.8149\n", "Time elapsed: 34.33 min\n", "Epoch: 037/150 | Batch 0000/0092 | Cost: 5.7312\n", "Time elapsed: 35.30 min\n", "Epoch: 038/150 | Batch 0000/0092 | Cost: 5.6387\n", "Time elapsed: 36.28 min\n", "Epoch: 039/150 | Batch 0000/0092 | 
Cost: 5.5805\n", "Time elapsed: 37.24 min\n", "Epoch: 040/150 | Batch 0000/0092 | Cost: 5.3195\n", "Time elapsed: 38.20 min\n", "Epoch: 041/150 | Batch 0000/0092 | Cost: 5.5065\n", "Time elapsed: 39.16 min\n", "Epoch: 042/150 | Batch 0000/0092 | Cost: 5.6153\n", "Time elapsed: 40.13 min\n", "Epoch: 043/150 | Batch 0000/0092 | Cost: 5.2801\n", "Time elapsed: 41.10 min\n", "Epoch: 044/150 | Batch 0000/0092 | Cost: 5.2717\n", "Time elapsed: 42.07 min\n", "Epoch: 045/150 | Batch 0000/0092 | Cost: 5.1263\n", "Time elapsed: 43.06 min\n", "Epoch: 046/150 | Batch 0000/0092 | Cost: 5.0700\n", "Time elapsed: 44.03 min\n", "Epoch: 047/150 | Batch 0000/0092 | Cost: 5.1728\n", "Time elapsed: 45.01 min\n", "Epoch: 048/150 | Batch 0000/0092 | Cost: 5.0284\n", "Time elapsed: 45.98 min\n", "Epoch: 049/150 | Batch 0000/0092 | Cost: 4.9178\n", "Time elapsed: 46.95 min\n", "Epoch: 050/150 | Batch 0000/0092 | Cost: 5.0401\n", "Time elapsed: 47.93 min\n", "Epoch: 051/150 | Batch 0000/0092 | Cost: 4.7706\n", "Time elapsed: 48.92 min\n", "Epoch: 052/150 | Batch 0000/0092 | Cost: 4.8608\n", "Time elapsed: 49.90 min\n", "Epoch: 053/150 | Batch 0000/0092 | Cost: 4.7105\n", "Time elapsed: 50.87 min\n", "Epoch: 054/150 | Batch 0000/0092 | Cost: 4.7156\n", "Time elapsed: 51.85 min\n", "Epoch: 055/150 | Batch 0000/0092 | Cost: 4.6754\n", "Time elapsed: 52.81 min\n", "Epoch: 056/150 | Batch 0000/0092 | Cost: 4.5800\n", "Time elapsed: 53.79 min\n", "Epoch: 057/150 | Batch 0000/0092 | Cost: 4.4490\n", "Time elapsed: 54.76 min\n", "Epoch: 058/150 | Batch 0000/0092 | Cost: 4.4306\n", "Time elapsed: 55.74 min\n", "Epoch: 059/150 | Batch 0000/0092 | Cost: 4.4310\n", "Time elapsed: 56.70 min\n", "Epoch: 060/150 | Batch 0000/0092 | Cost: 4.4331\n", "Time elapsed: 57.67 min\n", "Epoch: 061/150 | Batch 0000/0092 | Cost: 4.2809\n", "Time elapsed: 58.64 min\n", "Epoch: 062/150 | Batch 0000/0092 | Cost: 4.3698\n", "Time elapsed: 59.62 min\n", "Epoch: 063/150 | Batch 0000/0092 | Cost: 4.3086\n", "Time elapsed: 
60.59 min\n", "Epoch: 064/150 | Batch 0000/0092 | Cost: 4.2474\n", "Time elapsed: 61.57 min\n", "Epoch: 065/150 | Batch 0000/0092 | Cost: 4.2255\n", "Time elapsed: 62.53 min\n", "Epoch: 066/150 | Batch 0000/0092 | Cost: 4.1545\n", "Time elapsed: 63.52 min\n", "Epoch: 067/150 | Batch 0000/0092 | Cost: 4.1680\n", "Time elapsed: 64.49 min\n", "Epoch: 068/150 | Batch 0000/0092 | Cost: 4.1133\n", "Time elapsed: 65.46 min\n", "Epoch: 069/150 | Batch 0000/0092 | Cost: 4.0342\n", "Time elapsed: 66.42 min\n", "Epoch: 070/150 | Batch 0000/0092 | Cost: 4.1035\n", "Time elapsed: 67.38 min\n", "Epoch: 071/150 | Batch 0000/0092 | Cost: 4.0500\n", "Time elapsed: 68.34 min\n", "Epoch: 072/150 | Batch 0000/0092 | Cost: 3.8781\n", "Time elapsed: 69.31 min\n", "Epoch: 073/150 | Batch 0000/0092 | Cost: 3.8854\n", "Time elapsed: 70.29 min\n", "Epoch: 074/150 | Batch 0000/0092 | Cost: 3.9859\n", "Time elapsed: 71.25 min\n", "Epoch: 075/150 | Batch 0000/0092 | Cost: 4.0262\n", "Time elapsed: 72.22 min\n", "Epoch: 076/150 | Batch 0000/0092 | Cost: 4.3140\n", "Time elapsed: 73.21 min\n", "Epoch: 077/150 | Batch 0000/0092 | Cost: 4.1002\n", "Time elapsed: 74.20 min\n", "Epoch: 078/150 | Batch 0000/0092 | Cost: 3.9676\n", "Time elapsed: 75.19 min\n", "Epoch: 079/150 | Batch 0000/0092 | Cost: 3.6617\n", "Time elapsed: 76.18 min\n", "Epoch: 080/150 | Batch 0000/0092 | Cost: 3.7342\n", "Time elapsed: 77.15 min\n", "Epoch: 081/150 | Batch 0000/0092 | Cost: 3.5710\n", "Time elapsed: 78.12 min\n", "Epoch: 082/150 | Batch 0000/0092 | Cost: 3.6218\n", "Time elapsed: 79.08 min\n", "Epoch: 083/150 | Batch 0000/0092 | Cost: 3.4883\n", "Time elapsed: 80.04 min\n", "Epoch: 084/150 | Batch 0000/0092 | Cost: 3.5037\n", "Time elapsed: 81.01 min\n", "Epoch: 085/150 | Batch 0000/0092 | Cost: 3.4316\n", "Time elapsed: 81.97 min\n", "Epoch: 086/150 | Batch 0000/0092 | Cost: 3.4448\n", "Time elapsed: 82.94 min\n", "Epoch: 087/150 | Batch 0000/0092 | Cost: 3.3413\n", "Time elapsed: 83.89 min\n", "Epoch: 088/150 | 
Batch 0000/0092 | Cost: 3.4418\n", "Time elapsed: 84.86 min\n", "Epoch: 089/150 | Batch 0000/0092 | Cost: 3.4258\n", "Time elapsed: 85.82 min\n", "Epoch: 090/150 | Batch 0000/0092 | Cost: 3.3049\n", "Time elapsed: 86.78 min\n", "Epoch: 091/150 | Batch 0000/0092 | Cost: 3.2554\n", "Time elapsed: 87.73 min\n", "Epoch: 092/150 | Batch 0000/0092 | Cost: 3.2919\n", "Time elapsed: 88.69 min\n", "Epoch: 093/150 | Batch 0000/0092 | Cost: 3.3172\n", "Time elapsed: 89.65 min\n", "Epoch: 094/150 | Batch 0000/0092 | Cost: 3.5744\n", "Time elapsed: 90.62 min\n", "Epoch: 095/150 | Batch 0000/0092 | Cost: 4.5396\n", "Time elapsed: 91.58 min\n", "Epoch: 096/150 | Batch 0000/0092 | Cost: 3.7548\n", "Time elapsed: 92.54 min\n", "Epoch: 097/150 | Batch 0000/0092 | Cost: 3.4449\n", "Time elapsed: 93.49 min\n", "Epoch: 098/150 | Batch 0000/0092 | Cost: 3.3186\n", "Time elapsed: 94.46 min\n", "Epoch: 099/150 | Batch 0000/0092 | Cost: 3.2050\n", "Time elapsed: 95.42 min\n", "Epoch: 100/150 | Batch 0000/0092 | Cost: 3.1218\n", "Time elapsed: 96.38 min\n", "Epoch: 101/150 | Batch 0000/0092 | Cost: 3.0612\n", "Time elapsed: 97.35 min\n", "Epoch: 102/150 | Batch 0000/0092 | Cost: 3.0640\n", "Time elapsed: 98.31 min\n", "Epoch: 103/150 | Batch 0000/0092 | Cost: 2.8820\n", "Time elapsed: 99.27 min\n", "Epoch: 104/150 | Batch 0000/0092 | Cost: 2.9511\n", "Time elapsed: 100.23 min\n", "Epoch: 105/150 | Batch 0000/0092 | Cost: 2.9219\n", "Time elapsed: 101.19 min\n", "Epoch: 106/150 | Batch 0000/0092 | Cost: 2.9429\n", "Time elapsed: 102.15 min\n", "Epoch: 107/150 | Batch 0000/0092 | Cost: 2.8934\n", "Time elapsed: 103.11 min\n", "Epoch: 108/150 | Batch 0000/0092 | Cost: 2.8541\n", "Time elapsed: 104.06 min\n", "Epoch: 109/150 | Batch 0000/0092 | Cost: 2.8962\n", "Time elapsed: 105.03 min\n", "Epoch: 110/150 | Batch 0000/0092 | Cost: 2.8225\n", "Time elapsed: 105.99 min\n", "Epoch: 111/150 | Batch 0000/0092 | Cost: 2.7968\n", "Time elapsed: 106.96 min\n", "Epoch: 112/150 | Batch 0000/0092 | Cost: 
2.7319\n", "Time elapsed: 107.93 min\n", "Epoch: 113/150 | Batch 0000/0092 | Cost: 2.6711\n", "Time elapsed: 108.89 min\n", "Epoch: 114/150 | Batch 0000/0092 | Cost: 2.8028\n", "Time elapsed: 109.86 min\n", "Epoch: 115/150 | Batch 0000/0092 | Cost: 2.7948\n", "Time elapsed: 110.83 min\n", "Epoch: 116/150 | Batch 0000/0092 | Cost: 2.7675\n", "Time elapsed: 111.79 min\n", "Epoch: 117/150 | Batch 0000/0092 | Cost: 2.8945\n", "Time elapsed: 112.75 min\n", "Epoch: 118/150 | Batch 0000/0092 | Cost: 4.3488\n", "Time elapsed: 113.71 min\n", "Epoch: 119/150 | Batch 0000/0092 | Cost: 3.8014\n", "Time elapsed: 114.67 min\n", "Epoch: 120/150 | Batch 0000/0092 | Cost: 3.3284\n", "Time elapsed: 115.64 min\n", "Epoch: 121/150 | Batch 0000/0092 | Cost: 2.9553\n", "Time elapsed: 116.59 min\n", "Epoch: 122/150 | Batch 0000/0092 | Cost: 2.8341\n", "Time elapsed: 117.56 min\n", "Epoch: 123/150 | Batch 0000/0092 | Cost: 2.6916\n", "Time elapsed: 118.52 min\n", "Epoch: 124/150 | Batch 0000/0092 | Cost: 2.6589\n", "Time elapsed: 119.48 min\n", "Epoch: 125/150 | Batch 0000/0092 | Cost: 2.6671\n", "Time elapsed: 120.45 min\n", "Epoch: 126/150 | Batch 0000/0092 | Cost: 2.5647\n", "Time elapsed: 121.41 min\n", "Epoch: 127/150 | Batch 0000/0092 | Cost: 2.5726\n", "Time elapsed: 122.38 min\n", "Epoch: 128/150 | Batch 0000/0092 | Cost: 2.5466\n", "Time elapsed: 123.34 min\n", "Epoch: 129/150 | Batch 0000/0092 | Cost: 2.4779\n", "Time elapsed: 124.30 min\n", "Epoch: 130/150 | Batch 0000/0092 | Cost: 2.5617\n", "Time elapsed: 125.26 min\n", "Epoch: 131/150 | Batch 0000/0092 | Cost: 2.4226\n", "Time elapsed: 126.22 min\n", "Epoch: 132/150 | Batch 0000/0092 | Cost: 2.3563\n", "Time elapsed: 127.18 min\n", "Epoch: 133/150 | Batch 0000/0092 | Cost: 2.4614\n", "Time elapsed: 128.14 min\n", "Epoch: 134/150 | Batch 0000/0092 | Cost: 2.3724\n", "Time elapsed: 129.10 min\n", "Epoch: 135/150 | Batch 0000/0092 | Cost: 2.5513\n", "Time elapsed: 130.07 min\n", "Epoch: 136/150 | Batch 0000/0092 | Cost: 
3.1409\n", "Time elapsed: 131.02 min\n", "Epoch: 137/150 | Batch 0000/0092 | Cost: 3.1343\n", "Time elapsed: 131.98 min\n", "Epoch: 138/150 | Batch 0000/0092 | Cost: 3.0905\n", "Time elapsed: 132.94 min\n", "Epoch: 139/150 | Batch 0000/0092 | Cost: 2.8391\n", "Time elapsed: 133.90 min\n", "Epoch: 140/150 | Batch 0000/0092 | Cost: 2.6408\n", "Time elapsed: 134.86 min\n", "Epoch: 141/150 | Batch 0000/0092 | Cost: 2.4640\n", "Time elapsed: 135.83 min\n", "Epoch: 142/150 | Batch 0000/0092 | Cost: 2.4268\n", "Time elapsed: 136.79 min\n", "Epoch: 143/150 | Batch 0000/0092 | Cost: 2.4114\n", "Time elapsed: 137.75 min\n", "Epoch: 144/150 | Batch 0000/0092 | Cost: 2.3011\n", "Time elapsed: 138.71 min\n", "Epoch: 145/150 | Batch 0000/0092 | Cost: 2.2850\n", "Time elapsed: 139.67 min\n", "Epoch: 146/150 | Batch 0000/0092 | Cost: 2.3117\n", "Time elapsed: 140.63 min\n", "Epoch: 147/150 | Batch 0000/0092 | Cost: 2.3350\n", "Time elapsed: 141.58 min\n", "Epoch: 148/150 | Batch 0000/0092 | Cost: 2.1746\n", "Time elapsed: 142.54 min\n", "Epoch: 149/150 | Batch 0000/0092 | Cost: 2.3144\n", "Time elapsed: 143.49 min\n", "Epoch: 150/150 | Batch 0000/0092 | Cost: 2.2799\n", "Time elapsed: 144.45 min\n" ] } ], "source": [ "def compute_mae_and_mse(model, data_loader, device):\n", "    mae, mse, num_examples = 0, 0, 0\n", "    for i, (features, targets, levels) in enumerate(data_loader):\n", "\n", "        features = features.to(device)\n", "        targets = targets.to(device)\n", "\n", "        logits, probas = model(features)\n", "        # Predicted rank = number of binary tasks with probability > 0.5\n", "        predict_levels = probas > 0.5\n", "        predicted_labels = torch.sum(predict_levels, dim=1)\n", "        num_examples += targets.size(0)\n", "        mae += torch.sum(torch.abs(predicted_labels - targets))\n", "        mse += torch.sum((predicted_labels - targets)**2)\n", "    mae = mae.float() / num_examples\n", "    mse = mse.float() / num_examples\n", "    return mae, mse\n", "\n", "\n", "start_time = time.time()\n", "for epoch in range(NUM_EPOCHS):\n", "\n", "    model.train()\n", "    for batch_idx, (features, targets, levels) in enumerate(train_loader):\n", "\n", "        features = features.to(DEVICE)\n", "        targets = targets.to(DEVICE)\n", "        levels = levels.to(DEVICE)\n", "\n", "        # FORWARD AND BACK PROP\n", "        logits, probas = model(features)\n", "        cost = cost_fn(logits, levels)\n", "        optimizer.zero_grad()\n", "\n", "        cost.backward()\n", "\n", "        # UPDATE MODEL PARAMETERS\n", "        optimizer.step()\n", "\n", "        # LOGGING\n", "        if not batch_idx % 150:\n", "            s = ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'\n", "                 % (epoch+1, NUM_EPOCHS, batch_idx,\n", "                    len(train_dataset)//BATCH_SIZE, cost))\n", "            print(s)\n", "\n", "    s = 'Time elapsed: %.2f min' % ((time.time() - start_time)/60)\n", "    print(s)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MAE/RMSE: | Train: 0.55/0.88 | Test: 3.38/4.71\n", "Total Training Time: 145.23 min\n" ] } ], "source": [ "model.eval()\n", "with torch.set_grad_enabled(False):  # save memory during inference\n", "\n", "    train_mae, train_mse = compute_mae_and_mse(model, train_loader,\n", "                                               device=DEVICE)\n", "    test_mae, test_mse = compute_mae_and_mse(model, test_loader,\n", "                                             device=DEVICE)\n", "\n", "    s = 'MAE/RMSE: | Train: %.2f/%.2f | Test: %.2f/%.2f' % (\n", "        train_mae, torch.sqrt(train_mse), test_mae, torch.sqrt(test_mse))\n", "    print(s)\n", "\n", "s = 'Total Training Time: %.2f min' % ((time.time() - start_time)/60)\n", "print(s)" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "numpy 1.15.4\n", "pandas 0.23.4\n", "torch 1.1.0\n", "PIL.Image 5.3.0\n", "\n" ] } ], "source": [ "%watermark -iv" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": 
"python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }