{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.1.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Runs on CPU or GPU (if available)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- Ordinal Regression CNN -- Niu et al. 2016"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Implementation of the ordinal regression method by Niu et al. [1], applied to predicting age from face images in the Asian Face Age Dataset (AFAD) [1] using a simple AlexNet convolutional network architecture.\n",
    "\n",
    "Note that, in order to reduce training time, only a subset of AFAD (AFAD-Lite) is used.\n",
    "\n",
    "- [1] Niu, Zhenxing, Mo Zhou, Le Wang, Xinbo Gao, and Gang Hua. \"[Ordinal regression with multiple output cnn for age estimation](https://ieeexplore.ieee.org/document/7780901/).\" In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 4920-4928. 2016.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Downloading the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cloning into 'tarball-lite'...\n",
      "remote: Enumerating objects: 37, done.\n",
      "remote: Total 37 (delta 0), reused 0 (delta 0), pack-reused 37\n",
      "Unpacking objects: 100% (37/37), done.\n",
      "Checking out files: 100% (30/30), done.\n"
     ]
    }
   ],
   "source": [
    "!git clone https://github.com/afad-dataset/tarball-lite.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat tarball-lite/AFAD-Lite.tar.xz* > tarball-lite/AFAD-Lite.tar.xz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "!tar xf tarball-lite/AFAD-Lite.tar.xz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "rootDir = 'AFAD-Lite'\n",
    "\n",
    "files = [os.path.relpath(os.path.join(dirpath, file), rootDir)\n",
    "         for (dirpath, dirnames, filenames) in os.walk(rootDir) \n",
    "         for file in filenames if file.endswith('.jpg')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "59344"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = {}\n",
    "\n",
    "d['age'] = []\n",
    "d['gender'] = []\n",
    "d['file'] = []\n",
    "d['path'] = []\n",
    "\n",
    "for f in files:\n",
    "    age, gender, fname = f.split('/')\n",
    "    if gender == '111':\n",
    "        gender = 'male'\n",
    "    else:\n",
    "        gender = 'female'\n",
    "        \n",
    "    d['age'].append(age)\n",
    "    d['gender'].append(gender)\n",
    "    d['file'].append(fname)\n",
    "    d['path'].append(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "      <th>file</th>\n",
       "      <th>path</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>474596-0.jpg</td>\n",
       "      <td>39/112/474596-0.jpg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>397477-0.jpg</td>\n",
       "      <td>39/112/397477-0.jpg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>576466-0.jpg</td>\n",
       "      <td>39/112/576466-0.jpg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>399405-0.jpg</td>\n",
       "      <td>39/112/399405-0.jpg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>410524-0.jpg</td>\n",
       "      <td>39/112/410524-0.jpg</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  age  gender          file                 path\n",
       "0  39  female  474596-0.jpg  39/112/474596-0.jpg\n",
       "1  39  female  397477-0.jpg  39/112/397477-0.jpg\n",
       "2  39  female  576466-0.jpg  39/112/576466-0.jpg\n",
       "3  39  female  399405-0.jpg  39/112/399405-0.jpg\n",
       "4  39  female  410524-0.jpg  39/112/410524-0.jpg"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.DataFrame.from_dict(d)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'18'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['age'].min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['age'] = df['age'].values.astype(int) - 18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(123)\n",
    "msk = np.random.rand(len(df)) < 0.8\n",
    "df_train = df[msk]\n",
    "df_test = df[~msk]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train.set_index('file', inplace=True)\n",
    "df_train.to_csv('training_set_lite.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test.set_index('file', inplace=True)\n",
    "df_test.to_csv('test_set_lite.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22\n"
     ]
    }
   ],
   "source": [
    "num_ages = np.unique(df['age'].values).shape[0]\n",
    "print(num_ages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Device (the original run used \"cuda:1\"; adjust the index for your machine)\n",
    "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "NUM_WORKERS = 8\n",
    "\n",
    "NUM_CLASSES = 22\n",
    "BATCH_SIZE = 512\n",
    "NUM_EPOCHS = 150\n",
    "LEARNING_RATE = 0.0005\n",
    "RANDOM_SEED = 123\n",
    "\n",
    "TRAIN_CSV_PATH = 'training_set_lite.csv'\n",
    "TEST_CSV_PATH = 'test_set_lite.csv'\n",
    "IMAGE_PATH = 'AFAD-Lite'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset Loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AFADDatasetAge(Dataset):\n",
    "    \"\"\"Custom Dataset for loading AFAD face images\"\"\"\n",
    "\n",
    "    def __init__(self, csv_path, img_dir, transform=None):\n",
    "\n",
    "        df = pd.read_csv(csv_path, index_col=0)\n",
    "        self.img_dir = img_dir\n",
    "        self.csv_path = csv_path\n",
    "        self.img_paths = df['path'].values\n",
    "        self.y = df['age'].values\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img = Image.open(os.path.join(self.img_dir,\n",
    "                                      self.img_paths[index]))\n",
    "\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        label = self.y[index]\n",
    "        # Extended binary labels for the K-1 tasks: task k is 1 if label > k, else 0\n",
    "        levels = [1]*label + [0]*(NUM_CLASSES - 1 - label)\n",
    "        levels = torch.tensor(levels, dtype=torch.float32)\n",
    "\n",
    "        return img, label, levels\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.y.shape[0]\n",
    "\n",
    "\n",
    "custom_transform = transforms.Compose([transforms.Resize((128, 128)),\n",
    "                                       transforms.RandomCrop((120, 120)),\n",
    "                                       transforms.ToTensor()])\n",
    "\n",
    "train_dataset = AFADDatasetAge(csv_path=TRAIN_CSV_PATH,\n",
    "                               img_dir=IMAGE_PATH,\n",
    "                               transform=custom_transform)\n",
    "\n",
    "\n",
    "custom_transform2 = transforms.Compose([transforms.Resize((128, 128)),\n",
    "                                        transforms.CenterCrop((120, 120)),\n",
    "                                        transforms.ToTensor()])\n",
    "\n",
    "test_dataset = AFADDatasetAge(csv_path=TEST_CSV_PATH,\n",
    "                              img_dir=IMAGE_PATH,\n",
    "                              transform=custom_transform2)\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset,\n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          shuffle=True,\n",
    "                          num_workers=NUM_WORKERS)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset,\n",
    "                         batch_size=BATCH_SIZE,\n",
    "                         shuffle=False,\n",
    "                         num_workers=NUM_WORKERS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "# MODEL\n",
    "##########################\n",
    "\n",
    "class AlexNet(nn.Module):\n",
    "\n",
    "    def __init__(self, num_classes):\n",
    "        super(AlexNet, self).__init__()\n",
    "        self.num_classes = num_classes\n",
    "        self.features = nn.Sequential(\n",
    "            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "            nn.Conv2d(64, 192, kernel_size=5, padding=2),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "            nn.Conv2d(192, 384, kernel_size=3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(384, 256, kernel_size=3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(256, 256, kernel_size=3, padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "        )\n",
    "        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(256 * 6 * 6, 4096),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(4096, 4096),\n",
    "            nn.ReLU(inplace=True),\n",
    "        )\n",
    "\n",
    "        self.fc = nn.Linear(4096, (self.num_classes-1)*2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = self.avgpool(x)\n",
    "        x = x.view(x.size(0), 256 * 6 * 6)\n",
    "        x = self.classifier(x)\n",
    "\n",
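    "        # Two logits (negative/positive) for each of the num_classes-1 binary tasks\n",
    "        logits = self.fc(x)\n",
    "        logits = logits.view(-1, (self.num_classes-1), 2)\n",
    "        # Probability of the positive outcome for each binary task\n",
    "        probas = F.softmax(logits, dim=2)[:, :, 1]\n",
    "        return logits, probas\n"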
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "###########################################\n",
    "# Initialize Cost, Model, and Optimizer\n",
    "###########################################\n",
    "\n",
    "def cost_fn(logits, levels):\n",
    "    # Sum of the binary cross-entropy terms over all tasks, averaged over the batch\n",
    "    log_probas = F.log_softmax(logits, dim=2)\n",
    "    val = -torch.sum(log_probas[:, :, 1]*levels\n",
    "                     + log_probas[:, :, 0]*(1-levels), dim=1)\n",
    "    return torch.mean(val)\n",
    "\n",
    "\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "torch.cuda.manual_seed(RANDOM_SEED)\n",
    "model = AlexNet(NUM_CLASSES)\n",
    "\n",
    "model.to(DEVICE)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/150 | Batch 0000/0092 | Cost: 14.5346\n",
      "Time elapsed: 0.47 min\n",
      "Epoch: 002/150 | Batch 0000/0092 | Cost: 10.1990\n",
      "Time elapsed: 0.93 min\n",
      "Epoch: 003/150 | Batch 0000/0092 | Cost: 9.3695\n",
      "Time elapsed: 1.41 min\n",
      "Epoch: 004/150 | Batch 0000/0092 | Cost: 8.7663\n",
      "Time elapsed: 1.88 min\n",
      "Epoch: 005/150 | Batch 0000/0092 | Cost: 8.5821\n",
      "Time elapsed: 2.36 min\n",
      "Epoch: 006/150 | Batch 0000/0092 | Cost: 8.8506\n",
      "Time elapsed: 2.80 min\n",
      "Epoch: 007/150 | Batch 0000/0092 | Cost: 8.1522\n",
      "Time elapsed: 3.25 min\n",
      "Epoch: 008/150 | Batch 0000/0092 | Cost: 9.3045\n",
      "Time elapsed: 3.74 min\n",
      "Epoch: 009/150 | Batch 0000/0092 | Cost: 8.2951\n",
      "Time elapsed: 4.21 min\n",
      "Epoch: 010/150 | Batch 0000/0092 | Cost: 8.1094\n",
      "Time elapsed: 4.67 min\n",
      "Epoch: 011/150 | Batch 0000/0092 | Cost: 8.3870\n",
      "Time elapsed: 5.15 min\n",
      "Epoch: 012/150 | Batch 0000/0092 | Cost: 8.1078\n",
      "Time elapsed: 5.62 min\n",
      "Epoch: 013/150 | Batch 0000/0092 | Cost: 7.6846\n",
      "Time elapsed: 6.11 min\n",
      "Epoch: 014/150 | Batch 0000/0092 | Cost: 7.7015\n",
      "Time elapsed: 6.56 min\n",
      "Epoch: 015/150 | Batch 0000/0092 | Cost: 7.5693\n",
      "Time elapsed: 7.02 min\n",
      "Epoch: 016/150 | Batch 0000/0092 | Cost: 7.9339\n",
      "Time elapsed: 7.51 min\n",
      "Epoch: 017/150 | Batch 0000/0092 | Cost: 7.7916\n",
      "Time elapsed: 7.97 min\n",
      "Epoch: 018/150 | Batch 0000/0092 | Cost: 7.1257\n",
      "Time elapsed: 8.46 min\n",
      "Epoch: 019/150 | Batch 0000/0092 | Cost: 6.8873\n",
      "Time elapsed: 8.89 min\n",
      "Epoch: 020/150 | Batch 0000/0092 | Cost: 7.1145\n",
      "Time elapsed: 9.36 min\n",
      "Epoch: 021/150 | Batch 0000/0092 | Cost: 7.3858\n",
      "Time elapsed: 9.78 min\n",
      "Epoch: 022/150 | Batch 0000/0092 | Cost: 7.5949\n",
      "Time elapsed: 10.25 min\n",
      "Epoch: 023/150 | Batch 0000/0092 | Cost: 6.9696\n",
      "Time elapsed: 10.68 min\n",
      "Epoch: 024/150 | Batch 0000/0092 | Cost: 7.1473\n",
      "Time elapsed: 11.15 min\n",
      "Epoch: 025/150 | Batch 0000/0092 | Cost: 6.6147\n",
      "Time elapsed: 11.61 min\n",
      "Epoch: 026/150 | Batch 0000/0092 | Cost: 6.3201\n",
      "Time elapsed: 12.12 min\n",
      "Epoch: 027/150 | Batch 0000/0092 | Cost: 6.5789\n",
      "Time elapsed: 12.53 min\n",
      "Epoch: 028/150 | Batch 0000/0092 | Cost: 6.4532\n",
      "Time elapsed: 12.99 min\n",
      "Epoch: 029/150 | Batch 0000/0092 | Cost: 6.5590\n",
      "Time elapsed: 13.44 min\n",
      "Epoch: 030/150 | Batch 0000/0092 | Cost: 6.4204\n",
      "Time elapsed: 13.91 min\n",
      "Epoch: 031/150 | Batch 0000/0092 | Cost: 6.1485\n",
      "Time elapsed: 14.37 min\n",
      "Epoch: 032/150 | Batch 0000/0092 | Cost: 6.6225\n",
      "Time elapsed: 14.87 min\n",
      "Epoch: 033/150 | Batch 0000/0092 | Cost: 5.9400\n",
      "Time elapsed: 15.32 min\n",
      "Epoch: 034/150 | Batch 0000/0092 | Cost: 6.0526\n",
      "Time elapsed: 15.82 min\n",
      "Epoch: 035/150 | Batch 0000/0092 | Cost: 5.9100\n",
      "Time elapsed: 16.28 min\n",
      "Epoch: 036/150 | Batch 0000/0092 | Cost: 5.8563\n",
      "Time elapsed: 16.75 min\n",
      "Epoch: 037/150 | Batch 0000/0092 | Cost: 5.4942\n",
      "Time elapsed: 17.20 min\n",
      "Epoch: 038/150 | Batch 0000/0092 | Cost: 5.3825\n",
      "Time elapsed: 17.69 min\n",
      "Epoch: 039/150 | Batch 0000/0092 | Cost: 5.4557\n",
      "Time elapsed: 18.15 min\n",
      "Epoch: 040/150 | Batch 0000/0092 | Cost: 5.4534\n",
      "Time elapsed: 18.66 min\n",
      "Epoch: 041/150 | Batch 0000/0092 | Cost: 5.2443\n",
      "Time elapsed: 19.08 min\n",
      "Epoch: 042/150 | Batch 0000/0092 | Cost: 5.2351\n",
      "Time elapsed: 19.57 min\n",
      "Epoch: 043/150 | Batch 0000/0092 | Cost: 5.1354\n",
      "Time elapsed: 20.02 min\n",
      "Epoch: 044/150 | Batch 0000/0092 | Cost: 5.1245\n",
      "Time elapsed: 20.50 min\n",
      "Epoch: 045/150 | Batch 0000/0092 | Cost: 5.0352\n",
      "Time elapsed: 20.96 min\n",
      "Epoch: 046/150 | Batch 0000/0092 | Cost: 4.7361\n",
      "Time elapsed: 21.47 min\n",
      "Epoch: 047/150 | Batch 0000/0092 | Cost: 4.6973\n",
      "Time elapsed: 21.93 min\n",
      "Epoch: 048/150 | Batch 0000/0092 | Cost: 4.6416\n",
      "Time elapsed: 22.44 min\n",
      "Epoch: 049/150 | Batch 0000/0092 | Cost: 4.6076\n",
      "Time elapsed: 22.89 min\n",
      "Epoch: 050/150 | Batch 0000/0092 | Cost: 4.5119\n",
      "Time elapsed: 23.37 min\n",
      "Epoch: 051/150 | Batch 0000/0092 | Cost: 4.2692\n",
      "Time elapsed: 23.82 min\n",
      "Epoch: 052/150 | Batch 0000/0092 | Cost: 4.2506\n",
      "Time elapsed: 24.33 min\n",
      "Epoch: 053/150 | Batch 0000/0092 | Cost: 4.2682\n",
      "Time elapsed: 24.79 min\n",
      "Epoch: 054/150 | Batch 0000/0092 | Cost: 4.7041\n",
      "Time elapsed: 25.30 min\n",
      "Epoch: 055/150 | Batch 0000/0092 | Cost: 3.9781\n",
      "Time elapsed: 25.75 min\n",
      "Epoch: 056/150 | Batch 0000/0092 | Cost: 4.4825\n",
      "Time elapsed: 26.23 min\n",
      "Epoch: 057/150 | Batch 0000/0092 | Cost: 3.8956\n",
      "Time elapsed: 26.70 min\n",
      "Epoch: 058/150 | Batch 0000/0092 | Cost: 3.8620\n",
      "Time elapsed: 27.17 min\n",
      "Epoch: 059/150 | Batch 0000/0092 | Cost: 4.3340\n",
      "Time elapsed: 27.64 min\n",
      "Epoch: 060/150 | Batch 0000/0092 | Cost: 3.6278\n",
      "Time elapsed: 28.13 min\n",
      "Epoch: 061/150 | Batch 0000/0092 | Cost: 3.8466\n",
      "Time elapsed: 28.61 min\n",
      "Epoch: 062/150 | Batch 0000/0092 | Cost: 4.0128\n",
      "Time elapsed: 29.06 min\n",
      "Epoch: 063/150 | Batch 0000/0092 | Cost: 3.6341\n",
      "Time elapsed: 29.54 min\n",
      "Epoch: 064/150 | Batch 0000/0092 | Cost: 3.7518\n",
      "Time elapsed: 29.99 min\n",
      "Epoch: 065/150 | Batch 0000/0092 | Cost: 3.4808\n",
      "Time elapsed: 30.47 min\n",
      "Epoch: 066/150 | Batch 0000/0092 | Cost: 3.8870\n",
      "Time elapsed: 30.94 min\n",
      "Epoch: 067/150 | Batch 0000/0092 | Cost: 3.3818\n",
      "Time elapsed: 31.42 min\n",
      "Epoch: 068/150 | Batch 0000/0092 | Cost: 3.2848\n",
      "Time elapsed: 31.90 min\n",
      "Epoch: 069/150 | Batch 0000/0092 | Cost: 3.3755\n",
      "Time elapsed: 32.40 min\n",
      "Epoch: 070/150 | Batch 0000/0092 | Cost: 3.3649\n",
      "Time elapsed: 32.87 min\n",
      "Epoch: 071/150 | Batch 0000/0092 | Cost: 3.3467\n",
      "Time elapsed: 33.35 min\n",
      "Epoch: 072/150 | Batch 0000/0092 | Cost: 2.9973\n",
      "Time elapsed: 33.82 min\n",
      "Epoch: 073/150 | Batch 0000/0092 | Cost: 3.0707\n",
      "Time elapsed: 34.29 min\n",
      "Epoch: 074/150 | Batch 0000/0092 | Cost: 3.3438\n",
      "Time elapsed: 34.76 min\n",
      "Epoch: 075/150 | Batch 0000/0092 | Cost: 3.0139\n",
      "Time elapsed: 35.23 min\n",
      "Epoch: 076/150 | Batch 0000/0092 | Cost: 3.0865\n",
      "Time elapsed: 35.70 min\n",
      "Epoch: 077/150 | Batch 0000/0092 | Cost: 3.1229\n",
      "Time elapsed: 36.16 min\n",
      "Epoch: 078/150 | Batch 0000/0092 | Cost: 2.9919\n",
      "Time elapsed: 36.61 min\n",
      "Epoch: 079/150 | Batch 0000/0092 | Cost: 3.0086\n",
      "Time elapsed: 37.09 min\n",
      "Epoch: 080/150 | Batch 0000/0092 | Cost: 2.8564\n",
      "Time elapsed: 37.55 min\n",
      "Epoch: 081/150 | Batch 0000/0092 | Cost: 2.9466\n",
      "Time elapsed: 38.00 min\n",
      "Epoch: 082/150 | Batch 0000/0092 | Cost: 2.7265\n",
      "Time elapsed: 38.45 min\n",
      "Epoch: 083/150 | Batch 0000/0092 | Cost: 2.8810\n",
      "Time elapsed: 38.91 min\n",
      "Epoch: 084/150 | Batch 0000/0092 | Cost: 2.7061\n",
      "Time elapsed: 39.37 min\n",
      "Epoch: 085/150 | Batch 0000/0092 | Cost: 2.6701\n",
      "Time elapsed: 39.87 min\n",
      "Epoch: 086/150 | Batch 0000/0092 | Cost: 2.6754\n",
      "Time elapsed: 40.30 min\n",
      "Epoch: 087/150 | Batch 0000/0092 | Cost: 2.6844\n",
      "Time elapsed: 40.79 min\n",
      "Epoch: 088/150 | Batch 0000/0092 | Cost: 2.6610\n",
      "Time elapsed: 41.23 min\n",
      "Epoch: 089/150 | Batch 0000/0092 | Cost: 2.9059\n",
      "Time elapsed: 41.72 min\n",
      "Epoch: 090/150 | Batch 0000/0092 | Cost: 2.6932\n",
      "Time elapsed: 42.16 min\n",
      "Epoch: 091/150 | Batch 0000/0092 | Cost: 2.4559\n",
      "Time elapsed: 42.64 min\n",
      "Epoch: 092/150 | Batch 0000/0092 | Cost: 2.6640\n",
      "Time elapsed: 43.11 min\n",
      "Epoch: 093/150 | Batch 0000/0092 | Cost: 2.5860\n",
      "Time elapsed: 43.59 min\n",
      "Epoch: 094/150 | Batch 0000/0092 | Cost: 2.6070\n",
      "Time elapsed: 44.06 min\n",
      "Epoch: 095/150 | Batch 0000/0092 | Cost: 2.4515\n",
      "Time elapsed: 44.52 min\n",
      "Epoch: 096/150 | Batch 0000/0092 | Cost: 2.5023\n",
      "Time elapsed: 44.97 min\n",
      "Epoch: 097/150 | Batch 0000/0092 | Cost: 2.3886\n",
      "Time elapsed: 45.45 min\n",
      "Epoch: 098/150 | Batch 0000/0092 | Cost: 2.6237\n",
      "Time elapsed: 45.90 min\n",
      "Epoch: 099/150 | Batch 0000/0092 | Cost: 2.3156\n",
      "Time elapsed: 46.37 min\n",
      "Epoch: 100/150 | Batch 0000/0092 | Cost: 2.2186\n",
      "Time elapsed: 46.85 min\n",
      "Epoch: 101/150 | Batch 0000/0092 | Cost: 2.4205\n",
      "Time elapsed: 47.32 min\n",
      "Epoch: 102/150 | Batch 0000/0092 | Cost: 2.4243\n",
      "Time elapsed: 47.79 min\n",
      "Epoch: 103/150 | Batch 0000/0092 | Cost: 2.4262\n",
      "Time elapsed: 48.23 min\n",
      "Epoch: 104/150 | Batch 0000/0092 | Cost: 2.4243\n",
      "Time elapsed: 48.69 min\n",
      "Epoch: 105/150 | Batch 0000/0092 | Cost: 2.1756\n",
      "Time elapsed: 49.15 min\n",
      "Epoch: 106/150 | Batch 0000/0092 | Cost: 2.1816\n",
      "Time elapsed: 49.61 min\n",
      "Epoch: 107/150 | Batch 0000/0092 | Cost: 2.3446\n",
      "Time elapsed: 50.07 min\n",
      "Epoch: 108/150 | Batch 0000/0092 | Cost: 2.2174\n",
      "Time elapsed: 50.51 min\n",
      "Epoch: 109/150 | Batch 0000/0092 | Cost: 2.2063\n",
      "Time elapsed: 50.99 min\n",
      "Epoch: 110/150 | Batch 0000/0092 | Cost: 2.3621\n",
      "Time elapsed: 51.47 min\n",
      "Epoch: 111/150 | Batch 0000/0092 | Cost: 2.2048\n",
      "Time elapsed: 51.93 min\n",
      "Epoch: 112/150 | Batch 0000/0092 | Cost: 1.9002\n",
      "Time elapsed: 52.39 min\n",
      "Epoch: 113/150 | Batch 0000/0092 | Cost: 2.3146\n",
      "Time elapsed: 52.88 min\n",
      "Epoch: 114/150 | Batch 0000/0092 | Cost: 2.2218\n",
      "Time elapsed: 53.32 min\n",
      "Epoch: 115/150 | Batch 0000/0092 | Cost: 2.5772\n",
      "Time elapsed: 53.79 min\n",
      "Epoch: 116/150 | Batch 0000/0092 | Cost: 1.9954\n",
      "Time elapsed: 54.22 min\n",
      "Epoch: 117/150 | Batch 0000/0092 | Cost: 2.2189\n",
      "Time elapsed: 54.72 min\n",
      "Epoch: 118/150 | Batch 0000/0092 | Cost: 2.0534\n",
      "Time elapsed: 55.17 min\n",
      "Epoch: 119/150 | Batch 0000/0092 | Cost: 2.1909\n",
      "Time elapsed: 55.67 min\n",
      "Epoch: 120/150 | Batch 0000/0092 | Cost: 1.9588\n",
      "Time elapsed: 56.10 min\n",
      "Epoch: 121/150 | Batch 0000/0092 | Cost: 1.8609\n",
      "Time elapsed: 56.57 min\n",
      "Epoch: 122/150 | Batch 0000/0092 | Cost: 2.2825\n",
      "Time elapsed: 57.04 min\n",
      "Epoch: 123/150 | Batch 0000/0092 | Cost: 2.2175\n",
      "Time elapsed: 57.55 min\n",
      "Epoch: 124/150 | Batch 0000/0092 | Cost: 2.0279\n",
      "Time elapsed: 57.98 min\n",
      "Epoch: 125/150 | Batch 0000/0092 | Cost: 1.9375\n",
      "Time elapsed: 58.46 min\n",
      "Epoch: 126/150 | Batch 0000/0092 | Cost: 2.0164\n",
      "Time elapsed: 58.89 min\n",
      "Epoch: 127/150 | Batch 0000/0092 | Cost: 2.1515\n",
      "Time elapsed: 59.39 min\n",
      "Epoch: 128/150 | Batch 0000/0092 | Cost: 1.9873\n",
      "Time elapsed: 59.83 min\n",
      "Epoch: 129/150 | Batch 0000/0092 | Cost: 1.8686\n",
      "Time elapsed: 60.30 min\n",
      "Epoch: 130/150 | Batch 0000/0092 | Cost: 1.9796\n",
      "Time elapsed: 60.73 min\n",
      "Epoch: 131/150 | Batch 0000/0092 | Cost: 1.7672\n",
      "Time elapsed: 61.23 min\n",
      "Epoch: 132/150 | Batch 0000/0092 | Cost: 1.9022\n",
      "Time elapsed: 61.70 min\n",
      "Epoch: 133/150 | Batch 0000/0092 | Cost: 1.8617\n",
      "Time elapsed: 62.19 min\n",
      "Epoch: 134/150 | Batch 0000/0092 | Cost: 1.7341\n",
      "Time elapsed: 62.66 min\n",
      "Epoch: 135/150 | Batch 0000/0092 | Cost: 1.7973\n",
      "Time elapsed: 63.17 min\n",
      "Epoch: 136/150 | Batch 0000/0092 | Cost: 1.7751\n",
      "Time elapsed: 63.63 min\n",
      "Epoch: 137/150 | Batch 0000/0092 | Cost: 1.9271\n",
      "Time elapsed: 64.14 min\n",
      "Epoch: 138/150 | Batch 0000/0092 | Cost: 1.6380\n",
      "Time elapsed: 64.59 min\n",
      "Epoch: 139/150 | Batch 0000/0092 | Cost: 1.7169\n",
      "Time elapsed: 65.07 min\n",
      "Epoch: 140/150 | Batch 0000/0092 | Cost: 1.8063\n",
      "Time elapsed: 65.53 min\n",
      "Epoch: 141/150 | Batch 0000/0092 | Cost: 1.8708\n",
      "Time elapsed: 66.03 min\n",
      "Epoch: 142/150 | Batch 0000/0092 | Cost: 1.4449\n",
      "Time elapsed: 66.48 min\n",
      "Epoch: 143/150 | Batch 0000/0092 | Cost: 1.8047\n",
      "Time elapsed: 66.95 min\n",
      "Epoch: 144/150 | Batch 0000/0092 | Cost: 1.9332\n",
      "Time elapsed: 67.40 min\n",
      "Epoch: 145/150 | Batch 0000/0092 | Cost: 1.8951\n",
      "Time elapsed: 67.85 min\n",
      "Epoch: 146/150 | Batch 0000/0092 | Cost: 1.6895\n",
      "Time elapsed: 68.33 min\n",
      "Epoch: 147/150 | Batch 0000/0092 | Cost: 1.7324\n",
      "Time elapsed: 68.77 min\n",
      "Epoch: 148/150 | Batch 0000/0092 | Cost: 1.6677\n",
      "Time elapsed: 69.24 min\n",
      "Epoch: 149/150 | Batch 0000/0092 | Cost: 1.6663\n",
      "Time elapsed: 69.71 min\n",
      "Epoch: 150/150 | Batch 0000/0092 | Cost: 1.7063\n",
      "Time elapsed: 70.18 min\n"
     ]
    }
   ],
   "source": [
    "def compute_mae_and_mse(model, data_loader, device):\n",
    "    mae, mse, num_examples = 0, 0, 0\n",
    "    for i, (features, targets, levels) in enumerate(data_loader):\n",
    "\n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        logits, probas = model(features)\n",
    "        # Predicted label = number of binary tasks with probability > 0.5\n",
    "        predict_levels = probas > 0.5\n",
    "        predicted_labels = torch.sum(predict_levels, dim=1)\n",
    "        num_examples += targets.size(0)\n",
    "        mae += torch.sum(torch.abs(predicted_labels - targets))\n",
    "        mse += torch.sum((predicted_labels - targets)**2)\n",
    "    mae = mae.float() / num_examples\n",
    "    mse = mse.float() / num_examples\n",
    "    return mae, mse\n",
    "\n",
    "\n",
    "start_time = time.time()\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "\n",
    "    model.train()\n",
    "    for batch_idx, (features, targets, levels) in enumerate(train_loader):\n",
    "\n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "        levels = levels.to(DEVICE)\n",
    "\n",
    "        # FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = cost_fn(logits, levels)\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        cost.backward()\n",
    "\n",
    "        # UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "\n",
    "        # LOGGING\n",
    "        if not batch_idx % 150:\n",
    "            s = ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'\n",
    "                 % (epoch+1, NUM_EPOCHS, batch_idx,\n",
    "                     len(train_dataset)//BATCH_SIZE, cost))\n",
    "            print(s)\n",
    "\n",
    "    s = 'Time elapsed: %.2f min' % ((time.time() - start_time)/60)\n",
    "    print(s)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MAE/RMSE: | Train: 0.65/1.13 | Test: 3.91/5.40\n",
      "Total Training Time: 70.77 min\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "with torch.set_grad_enabled(False):  # save memory during inference\n",
    "\n",
    "    train_mae, train_mse = compute_mae_and_mse(model, train_loader,\n",
    "                                               device=DEVICE)\n",
    "    test_mae, test_mse = compute_mae_and_mse(model, test_loader,\n",
    "                                             device=DEVICE)\n",
    "\n",
    "    s = 'MAE/RMSE: | Train: %.2f/%.2f | Test: %.2f/%.2f' % (\n",
    "        train_mae, torch.sqrt(train_mse), test_mae, torch.sqrt(test_mse))\n",
    "    print(s)\n",
    "\n",
    "s = 'Total Training Time: %.2f min' % ((time.time() - start_time)/60)\n",
    "print(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.15.4\n",
      "pandas      0.23.4\n",
      "torch       1.1.0\n",
      "PIL.Image   5.3.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}