{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.1.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Runs on CPU or GPU (if available)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- Ordinal Regression CNN -- Beckham and Pal 2016"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Implementation of a method for ordinal regression by Beckham and Pal [1] applied to predicting age from face images in the AFAD [2] (Asian Face) dataset using a simple ResNet-34 [3] convolutional network architecture.\n",
    "\n",
    "Note that in order to reduce training time, only a subset of AFAD (AFAD-Lite) is being used.\n",
    "\n",
    "- [1] Beckham, Christopher, and Christopher Pal. \"[A simple squared-error reformulation for ordinal classification](https://arxiv.org/abs/1612.00775).\" arXiv preprint arXiv:1612.00775 (2016).\n",
    "- [2] Niu, Zhenxing, Mo Zhou, Le Wang, Xinbo Gao, and Gang Hua. \"[Ordinal regression with multiple output cnn for age estimation](https://ieeexplore.ieee.org/document/7780901/).\" In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 4920-4928. 2016.\n",
    "- [3] He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. \"[Deep residual learning for image recognition](http://openaccess.thecvf.com/content_cvpr_2016/html/He_Deep_Residual_Learning_CVPR_2016_paper.html).\" In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770-778. 2016.\n",
    "\n",
    "\n"
   ]
  },
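   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "The squared-error reformulation implemented below (see the `forward` method and `cost_fn`) maps the network's class logits $\\mathbf{z} \\in \\mathbb{R}^K$ (here, $K = 22$ age classes) to a single continuous prediction via a learned pooling vector $\\mathbf{a} \\in \\mathbb{R}^K$:\n",
     "\n",
     "$$\\hat{y} = (K - 1) \\cdot \\sigma\\big(\\mathbf{p}^\\top \\mathbf{a}\\big), \\qquad \\mathbf{p} = \\mathrm{softmax}(\\mathbf{z}),$$\n",
     "\n",
     "where $\\sigma$ is the logistic sigmoid; training then minimizes the mean squared error $\\frac{1}{n} \\sum_{i=1}^{n} (y_i - \\hat{y}_i)^2$ between the integer age labels $y_i$ and these continuous predictions."
    ]
   },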
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Downloading the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cloning into 'tarball-lite'...\n",
      "remote: Enumerating objects: 37, done.\u001b[K\n",
      "remote: Total 37 (delta 0), reused 0 (delta 0), pack-reused 37\u001b[K\n",
      "Unpacking objects: 100% (37/37), done.\n",
      "Checking out files: 100% (30/30), done.\n"
     ]
    }
   ],
   "source": [
    "!git clone https://github.com/afad-dataset/tarball-lite.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat tarball-lite/AFAD-Lite.tar.xz* > tarball-lite/AFAD-Lite.tar.xz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "!tar xf tarball-lite/AFAD-Lite.tar.xz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "rootDir = 'AFAD-Lite'\n",
    "\n",
    "files = [os.path.relpath(os.path.join(dirpath, file), rootDir)\n",
    "         for (dirpath, dirnames, filenames) in os.walk(rootDir) \n",
    "         for file in filenames if file.endswith('.jpg')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "59344"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = {}\n",
    "\n",
    "d['age'] = []\n",
    "d['gender'] = []\n",
    "d['file'] = []\n",
    "d['path'] = []\n",
    "\n",
    "for f in files:\n",
    "    age, gender, fname = f.split('/')\n",
    "    if gender == '111':\n",
    "        gender = 'male'\n",
    "    else:\n",
    "        gender = 'female'\n",
    "        \n",
    "    d['age'].append(age)\n",
    "    d['gender'].append(gender)\n",
    "    d['file'].append(fname)\n",
    "    d['path'].append(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "      <th>file</th>\n",
       "      <th>path</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>474596-0.jpg</td>\n",
       "      <td>39/112/474596-0.jpg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>397477-0.jpg</td>\n",
       "      <td>39/112/397477-0.jpg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>576466-0.jpg</td>\n",
       "      <td>39/112/576466-0.jpg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>399405-0.jpg</td>\n",
       "      <td>39/112/399405-0.jpg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>410524-0.jpg</td>\n",
       "      <td>39/112/410524-0.jpg</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  age  gender          file                 path\n",
       "0  39  female  474596-0.jpg  39/112/474596-0.jpg\n",
       "1  39  female  397477-0.jpg  39/112/397477-0.jpg\n",
       "2  39  female  576466-0.jpg  39/112/576466-0.jpg\n",
       "3  39  female  399405-0.jpg  39/112/399405-0.jpg\n",
       "4  39  female  410524-0.jpg  39/112/410524-0.jpg"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.DataFrame.from_dict(d)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'18'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['age'].min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['age'] = df['age'].values.astype(int) - 18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(123)\n",
    "msk = np.random.rand(len(df)) < 0.8\n",
    "df_train = df[msk]\n",
    "df_test = df[~msk]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train.set_index('file', inplace=True)\n",
    "df_train.to_csv('training_set_lite.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test.set_index('file', inplace=True)\n",
    "df_test.to_csv('test_set_lite.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22\n"
     ]
    }
   ],
   "source": [
    "num_ages = np.unique(df['age'].values).shape[0]\n",
    "print(num_ages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Device\n",
    "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "NUM_WORKERS = 4\n",
    "\n",
    "NUM_CLASSES = 22\n",
    "BATCH_SIZE = 512\n",
    "NUM_EPOCHS = 150\n",
    "LEARNING_RATE = 0.0005\n",
    "RANDOM_SEED = 123\n",
    "GRAYSCALE = False\n",
    "\n",
    "TRAIN_CSV_PATH = 'training_set_lite.csv'\n",
    "TEST_CSV_PATH = 'test_set_lite.csv'\n",
    "IMAGE_PATH = 'AFAD-Lite'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset Loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AFADDatasetAge(Dataset):\n",
    "    \"\"\"Custom Dataset for loading AFAD face images\"\"\"\n",
    "\n",
    "    def __init__(self, csv_path, img_dir, transform=None):\n",
    "\n",
    "        df = pd.read_csv(csv_path, index_col=0)\n",
    "        self.img_dir = img_dir\n",
    "        self.csv_path = csv_path\n",
    "        self.img_paths = df['path']\n",
    "        self.y = df['age'].values\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img = Image.open(os.path.join(self.img_dir,\n",
    "                                      self.img_paths[index]))\n",
    "\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        label = self.y[index]\n",
    "\n",
    "        return img, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.y.shape[0]\n",
    "\n",
    "\n",
    "custom_transform = transforms.Compose([transforms.Resize((128, 128)),\n",
    "                                       transforms.RandomCrop((120, 120)),\n",
    "                                       transforms.ToTensor()])\n",
    "\n",
    "train_dataset = AFADDatasetAge(csv_path=TRAIN_CSV_PATH,\n",
    "                               img_dir=IMAGE_PATH,\n",
    "                               transform=custom_transform)\n",
    "\n",
    "\n",
    "custom_transform2 = transforms.Compose([transforms.Resize((128, 128)),\n",
    "                                        transforms.CenterCrop((120, 120)),\n",
    "                                        transforms.ToTensor()])\n",
    "\n",
    "test_dataset = AFADDatasetAge(csv_path=TEST_CSV_PATH,\n",
    "                              img_dir=IMAGE_PATH,\n",
    "                              transform=custom_transform2)\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset,\n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          shuffle=True,\n",
    "                          num_workers=NUM_WORKERS)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset,\n",
    "                         batch_size=BATCH_SIZE,\n",
    "                         shuffle=False,\n",
    "                         num_workers=NUM_WORKERS)"
   ]
  },
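   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "As a quick, optional sanity check (not part of the original run), drawing a single batch from the loader should yield image tensors of shape `[BATCH_SIZE, 3, 120, 120]` after the resize-and-crop transforms:"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "# optional sanity check (illustrative; not executed in the original run):\n",
     "# a training batch should be [BATCH_SIZE, 3, 120, 120] after cropping\n",
     "images, labels = next(iter(train_loader))\n",
     "print('Image batch dimensions:', images.shape)\n",
     "print('Label batch dimensions:', labels.shape)"
    ]
   },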
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "# MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "def conv3x3(in_planes, out_planes, stride=1):\n",
    "    \"\"\"3x3 convolution with padding\"\"\"\n",
    "    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n",
    "                     padding=1, bias=False)\n",
    "\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    expansion = 1\n",
    "\n",
    "    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
    "        super(BasicBlock, self).__init__()\n",
    "        self.conv1 = conv3x3(inplanes, planes, stride)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.conv2 = conv3x3(planes, planes)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.downsample = downsample\n",
    "        self.stride = stride\n",
    "\n",
    "    def forward(self, x):\n",
    "        residual = x\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "\n",
    "        if self.downsample is not None:\n",
    "            residual = self.downsample(x)\n",
    "\n",
    "        out += residual\n",
    "        out = self.relu(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "\n",
    "    def __init__(self, block, layers, num_classes, grayscale):\n",
    "        self.num_classes = num_classes\n",
    "        self.inplanes = 64\n",
    "        if grayscale:\n",
    "            in_dim = 1\n",
    "        else:\n",
    "            in_dim = 3\n",
    "        super(ResNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,\n",
    "                               bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    "        self.layer1 = self._make_layer(block, 64, layers[0])\n",
    "        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
    "        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
    "        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
    "        self.avgpool = nn.AvgPool2d(7, stride=1, padding=2)\n",
    "        self.fc = nn.Linear(2048 * block.expansion, num_classes)\n",
    "        self.a = torch.nn.Parameter(torch.zeros(\n",
    "            self.num_classes).float().normal_(0.0, 0.1).view(-1, 1))\n",
    "\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
    "                m.weight.data.normal_(0, (2. / n)**.5)\n",
    "            elif isinstance(m, nn.BatchNorm2d):\n",
    "                m.weight.data.fill_(1)\n",
    "                m.bias.data.zero_()\n",
    "\n",
    "    def _make_layer(self, block, planes, blocks, stride=1):\n",
    "        downsample = None\n",
    "        if stride != 1 or self.inplanes != planes * block.expansion:\n",
    "            downsample = nn.Sequential(\n",
    "                nn.Conv2d(self.inplanes, planes * block.expansion,\n",
    "                          kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(planes * block.expansion),\n",
    "            )\n",
    "\n",
    "        layers = []\n",
    "        layers.append(block(self.inplanes, planes, stride, downsample))\n",
    "        self.inplanes = planes * block.expansion\n",
    "        for i in range(1, blocks):\n",
    "            layers.append(block(self.inplanes, planes))\n",
    "\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.bn1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.maxpool(x)\n",
    "\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        x = self.layer3(x)\n",
    "        x = self.layer4(x)\n",
    "\n",
    "        x = self.avgpool(x)\n",
    "        x = x.view(x.size(0), -1)\n",
    "        logits = self.fc(x)\n",
    "        probas = torch.softmax(logits, dim=1)\n",
    "        predictions = ((self.num_classes-1)\n",
    "                       * torch.sigmoid(probas.mm(self.a).view(-1)))\n",
    "        return logits, probas, predictions\n",
    "\n",
    "\n",
    "def resnet34(num_classes, grayscale):\n",
    "    \"\"\"Constructs a ResNet-34 model.\"\"\"\n",
    "    model = ResNet(block=BasicBlock,\n",
    "                   layers=[3, 4, 6, 3],\n",
    "                   num_classes=num_classes,\n",
    "                   grayscale=grayscale)\n",
    "    return model"
   ]
  },
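   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "A minimal sketch (not part of the original run) to verify the three outputs of the forward pass on a dummy batch: `logits` and `probas` have shape `[batch_size, NUM_CLASSES]`, while `predictions` holds one continuous value per example:"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "# illustrative shape check on random inputs (assumes the settings above)\n",
     "_tmp_model = resnet34(NUM_CLASSES, GRAYSCALE)\n",
     "_tmp_model.eval()\n",
     "with torch.no_grad():\n",
     "    _logits, _probas, _preds = _tmp_model(torch.randn(2, 3, 120, 120))\n",
     "print(_logits.shape, _probas.shape, _preds.shape)\n",
     "# expected: torch.Size([2, 22]) torch.Size([2, 22]) torch.Size([2])"
    ]
   },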
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "###########################################\n",
    "# Initialize Cost, Model, and Optimizer\n",
    "###########################################\n",
    "\n",
    "def cost_fn(targets, predictions):\n",
    "    return torch.mean((targets.float() - predictions)**2)\n",
    "\n",
    "\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "torch.cuda.manual_seed(RANDOM_SEED)\n",
    "model = resnet34(NUM_CLASSES, GRAYSCALE)\n",
    "\n",
    "model.to(DEVICE)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)"
   ]
  },
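   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "A tiny worked example of `cost_fn` (illustrative only): for integer targets 5 and 10 and continuous predictions 4.5 and 10.5, the cost is the mean of the squared residuals, $(0.5^2 + 0.5^2)/2 = 0.25$:"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "# worked example (illustrative): expected result is tensor(0.2500)\n",
     "cost_fn(torch.tensor([5, 10]), torch.tensor([4.5, 10.5]))"
    ]
   },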
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/150 | Batch 0000/0092 | Cost: 42.0424\n",
      "Time elapsed: 0.92 min\n",
      "Epoch: 002/150 | Batch 0000/0092 | Cost: 41.3301\n",
      "Time elapsed: 1.85 min\n",
      "Epoch: 003/150 | Batch 0000/0092 | Cost: 40.5070\n",
      "Time elapsed: 2.78 min\n",
      "Epoch: 004/150 | Batch 0000/0092 | Cost: 40.4149\n",
      "Time elapsed: 3.72 min\n",
      "Epoch: 005/150 | Batch 0000/0092 | Cost: 38.0820\n",
      "Time elapsed: 4.67 min\n",
      "Epoch: 006/150 | Batch 0000/0092 | Cost: 38.6630\n",
      "Time elapsed: 5.61 min\n",
      "Epoch: 007/150 | Batch 0000/0092 | Cost: 36.5432\n",
      "Time elapsed: 6.54 min\n",
      "Epoch: 008/150 | Batch 0000/0092 | Cost: 38.1368\n",
      "Time elapsed: 7.49 min\n",
      "Epoch: 009/150 | Batch 0000/0092 | Cost: 37.4299\n",
      "Time elapsed: 8.44 min\n",
      "Epoch: 010/150 | Batch 0000/0092 | Cost: 32.8457\n",
      "Time elapsed: 9.38 min\n",
      "Epoch: 011/150 | Batch 0000/0092 | Cost: 32.5064\n",
      "Time elapsed: 10.33 min\n",
      "Epoch: 012/150 | Batch 0000/0092 | Cost: 31.5168\n",
      "Time elapsed: 11.28 min\n",
      "Epoch: 013/150 | Batch 0000/0092 | Cost: 29.1672\n",
      "Time elapsed: 12.23 min\n",
      "Epoch: 014/150 | Batch 0000/0092 | Cost: 29.7407\n",
      "Time elapsed: 13.18 min\n",
      "Epoch: 015/150 | Batch 0000/0092 | Cost: 30.3941\n",
      "Time elapsed: 14.12 min\n",
      "Epoch: 016/150 | Batch 0000/0092 | Cost: 26.1868\n",
      "Time elapsed: 15.06 min\n",
      "Epoch: 017/150 | Batch 0000/0092 | Cost: 28.6050\n",
      "Time elapsed: 16.01 min\n",
      "Epoch: 018/150 | Batch 0000/0092 | Cost: 28.7208\n",
      "Time elapsed: 16.95 min\n",
      "Epoch: 019/150 | Batch 0000/0092 | Cost: 27.9524\n",
      "Time elapsed: 17.88 min\n",
      "Epoch: 020/150 | Batch 0000/0092 | Cost: 23.9113\n",
      "Time elapsed: 18.82 min\n",
      "Epoch: 021/150 | Batch 0000/0092 | Cost: 24.4436\n",
      "Time elapsed: 19.78 min\n",
      "Epoch: 022/150 | Batch 0000/0092 | Cost: 23.9554\n",
      "Time elapsed: 20.72 min\n",
      "Epoch: 023/150 | Batch 0000/0092 | Cost: 20.7829\n",
      "Time elapsed: 21.68 min\n",
      "Epoch: 024/150 | Batch 0000/0092 | Cost: 22.3296\n",
      "Time elapsed: 22.63 min\n",
      "Epoch: 025/150 | Batch 0000/0092 | Cost: 21.1909\n",
      "Time elapsed: 23.57 min\n",
      "Epoch: 026/150 | Batch 0000/0092 | Cost: 21.9036\n",
      "Time elapsed: 24.53 min\n",
      "Epoch: 027/150 | Batch 0000/0092 | Cost: 20.2870\n",
      "Time elapsed: 25.49 min\n",
      "Epoch: 028/150 | Batch 0000/0092 | Cost: 20.3275\n",
      "Time elapsed: 26.44 min\n",
      "Epoch: 029/150 | Batch 0000/0092 | Cost: 20.5857\n",
      "Time elapsed: 27.39 min\n",
      "Epoch: 030/150 | Batch 0000/0092 | Cost: 20.6721\n",
      "Time elapsed: 28.35 min\n",
      "Epoch: 031/150 | Batch 0000/0092 | Cost: 21.0904\n",
      "Time elapsed: 29.30 min\n",
      "Epoch: 032/150 | Batch 0000/0092 | Cost: 16.7851\n",
      "Time elapsed: 30.25 min\n",
      "Epoch: 033/150 | Batch 0000/0092 | Cost: 17.6401\n",
      "Time elapsed: 31.21 min\n",
      "Epoch: 034/150 | Batch 0000/0092 | Cost: 15.3736\n",
      "Time elapsed: 32.16 min\n",
      "Epoch: 035/150 | Batch 0000/0092 | Cost: 17.7772\n",
      "Time elapsed: 33.10 min\n",
      "Epoch: 036/150 | Batch 0000/0092 | Cost: 17.6568\n",
      "Time elapsed: 34.04 min\n",
      "Epoch: 037/150 | Batch 0000/0092 | Cost: 18.5792\n",
      "Time elapsed: 34.97 min\n",
      "Epoch: 038/150 | Batch 0000/0092 | Cost: 15.3883\n",
      "Time elapsed: 35.90 min\n",
      "Epoch: 039/150 | Batch 0000/0092 | Cost: 14.2040\n",
      "Time elapsed: 36.84 min\n",
      "Epoch: 040/150 | Batch 0000/0092 | Cost: 14.3120\n",
      "Time elapsed: 37.77 min\n",
      "Epoch: 041/150 | Batch 0000/0092 | Cost: 15.3872\n",
      "Time elapsed: 38.71 min\n",
      "Epoch: 042/150 | Batch 0000/0092 | Cost: 14.0422\n",
      "Time elapsed: 39.64 min\n",
      "Epoch: 043/150 | Batch 0000/0092 | Cost: 14.3585\n",
      "Time elapsed: 40.59 min\n",
      "Epoch: 044/150 | Batch 0000/0092 | Cost: 14.6751\n",
      "Time elapsed: 41.53 min\n",
      "Epoch: 045/150 | Batch 0000/0092 | Cost: 11.8405\n",
      "Time elapsed: 42.50 min\n",
      "Epoch: 046/150 | Batch 0000/0092 | Cost: 11.0839\n",
      "Time elapsed: 43.45 min\n",
      "Epoch: 047/150 | Batch 0000/0092 | Cost: 12.4769\n",
      "Time elapsed: 44.40 min\n",
      "Epoch: 048/150 | Batch 0000/0092 | Cost: 12.0954\n",
      "Time elapsed: 45.35 min\n",
      "Epoch: 049/150 | Batch 0000/0092 | Cost: 12.3591\n",
      "Time elapsed: 46.30 min\n",
      "Epoch: 050/150 | Batch 0000/0092 | Cost: 11.3061\n",
      "Time elapsed: 47.25 min\n",
      "Epoch: 051/150 | Batch 0000/0092 | Cost: 10.1474\n",
      "Time elapsed: 48.19 min\n",
      "Epoch: 052/150 | Batch 0000/0092 | Cost: 9.5122\n",
      "Time elapsed: 49.14 min\n",
      "Epoch: 053/150 | Batch 0000/0092 | Cost: 10.0264\n",
      "Time elapsed: 50.08 min\n",
      "Epoch: 054/150 | Batch 0000/0092 | Cost: 9.0709\n",
      "Time elapsed: 51.03 min\n",
      "Epoch: 055/150 | Batch 0000/0092 | Cost: 8.8659\n",
      "Time elapsed: 51.97 min\n",
      "Epoch: 056/150 | Batch 0000/0092 | Cost: 9.0466\n",
      "Time elapsed: 52.92 min\n",
      "Epoch: 057/150 | Batch 0000/0092 | Cost: 8.6440\n",
      "Time elapsed: 53.87 min\n",
      "Epoch: 058/150 | Batch 0000/0092 | Cost: 9.7978\n",
      "Time elapsed: 54.82 min\n",
      "Epoch: 059/150 | Batch 0000/0092 | Cost: 9.1187\n",
      "Time elapsed: 55.78 min\n",
      "Epoch: 060/150 | Batch 0000/0092 | Cost: 8.0830\n",
      "Time elapsed: 56.73 min\n",
      "Epoch: 061/150 | Batch 0000/0092 | Cost: 7.3659\n",
      "Time elapsed: 57.67 min\n",
      "Epoch: 062/150 | Batch 0000/0092 | Cost: 7.4319\n",
      "Time elapsed: 58.62 min\n",
      "Epoch: 063/150 | Batch 0000/0092 | Cost: 7.5847\n",
      "Time elapsed: 59.55 min\n",
      "Epoch: 064/150 | Batch 0000/0092 | Cost: 6.5518\n",
      "Time elapsed: 60.51 min\n",
      "Epoch: 065/150 | Batch 0000/0092 | Cost: 7.4076\n",
      "Time elapsed: 61.44 min\n",
      "Epoch: 066/150 | Batch 0000/0092 | Cost: 7.5390\n",
      "Time elapsed: 62.38 min\n",
      "Epoch: 067/150 | Batch 0000/0092 | Cost: 6.8755\n",
      "Time elapsed: 63.33 min\n",
      "Epoch: 068/150 | Batch 0000/0092 | Cost: 5.7859\n",
      "Time elapsed: 64.27 min\n",
      "Epoch: 069/150 | Batch 0000/0092 | Cost: 6.5447\n",
      "Time elapsed: 65.21 min\n",
      "Epoch: 070/150 | Batch 0000/0092 | Cost: 8.7847\n",
      "Time elapsed: 66.14 min\n",
      "Epoch: 071/150 | Batch 0000/0092 | Cost: 5.4289\n",
      "Time elapsed: 67.08 min\n",
      "Epoch: 072/150 | Batch 0000/0092 | Cost: 7.3215\n",
      "Time elapsed: 68.02 min\n",
      "Epoch: 073/150 | Batch 0000/0092 | Cost: 5.3592\n",
      "Time elapsed: 68.96 min\n",
      "Epoch: 074/150 | Batch 0000/0092 | Cost: 6.3312\n",
      "Time elapsed: 69.91 min\n",
      "Epoch: 075/150 | Batch 0000/0092 | Cost: 6.5182\n",
      "Time elapsed: 70.85 min\n",
      "Epoch: 076/150 | Batch 0000/0092 | Cost: 5.0352\n",
      "Time elapsed: 71.79 min\n",
      "Epoch: 077/150 | Batch 0000/0092 | Cost: 6.1928\n",
      "Time elapsed: 72.72 min\n",
      "Epoch: 078/150 | Batch 0000/0092 | Cost: 4.3198\n",
      "Time elapsed: 73.66 min\n",
      "Epoch: 079/150 | Batch 0000/0092 | Cost: 4.4914\n",
      "Time elapsed: 74.59 min\n",
      "Epoch: 080/150 | Batch 0000/0092 | Cost: 4.5828\n",
      "Time elapsed: 75.52 min\n",
      "Epoch: 081/150 | Batch 0000/0092 | Cost: 5.8475\n",
      "Time elapsed: 76.45 min\n",
      "Epoch: 082/150 | Batch 0000/0092 | Cost: 4.5677\n",
      "Time elapsed: 77.38 min\n",
      "Epoch: 083/150 | Batch 0000/0092 | Cost: 4.7913\n",
      "Time elapsed: 78.31 min\n",
      "Epoch: 084/150 | Batch 0000/0092 | Cost: 4.3687\n",
      "Time elapsed: 79.24 min\n",
      "Epoch: 085/150 | Batch 0000/0092 | Cost: 5.0481\n",
      "Time elapsed: 80.17 min\n",
      "Epoch: 086/150 | Batch 0000/0092 | Cost: 4.0501\n",
      "Time elapsed: 81.10 min\n",
      "Epoch: 087/150 | Batch 0000/0092 | Cost: 4.0695\n",
      "Time elapsed: 82.03 min\n",
      "Epoch: 088/150 | Batch 0000/0092 | Cost: 4.5136\n",
      "Time elapsed: 82.96 min\n",
      "Epoch: 089/150 | Batch 0000/0092 | Cost: 3.8159\n",
      "Time elapsed: 83.89 min\n",
      "Epoch: 090/150 | Batch 0000/0092 | Cost: 4.0424\n",
      "Time elapsed: 84.82 min\n",
      "Epoch: 091/150 | Batch 0000/0092 | Cost: 3.9980\n",
      "Time elapsed: 85.75 min\n",
      "Epoch: 092/150 | Batch 0000/0092 | Cost: 3.6338\n",
      "Time elapsed: 86.68 min\n",
      "Epoch: 093/150 | Batch 0000/0092 | Cost: 3.8388\n",
      "Time elapsed: 87.61 min\n",
      "Epoch: 094/150 | Batch 0000/0092 | Cost: 3.3051\n",
      "Time elapsed: 88.54 min\n",
      "Epoch: 095/150 | Batch 0000/0092 | Cost: 3.4325\n",
      "Time elapsed: 89.47 min\n",
      "Epoch: 096/150 | Batch 0000/0092 | Cost: 3.1995\n",
      "Time elapsed: 90.40 min\n",
      "Epoch: 097/150 | Batch 0000/0092 | Cost: 4.0571\n",
      "Time elapsed: 91.33 min\n",
      "Epoch: 098/150 | Batch 0000/0092 | Cost: 3.4636\n",
      "Time elapsed: 92.25 min\n",
      "Epoch: 099/150 | Batch 0000/0092 | Cost: 3.0544\n",
      "Time elapsed: 93.18 min\n",
      "Epoch: 100/150 | Batch 0000/0092 | Cost: 2.8106\n",
      "Time elapsed: 94.10 min\n",
      "Epoch: 101/150 | Batch 0000/0092 | Cost: 3.0885\n",
      "Time elapsed: 95.03 min\n",
      "Epoch: 102/150 | Batch 0000/0092 | Cost: 2.8910\n",
      "Time elapsed: 95.96 min\n",
      "Epoch: 103/150 | Batch 0000/0092 | Cost: 3.0126\n",
      "Time elapsed: 96.90 min\n",
      "Epoch: 104/150 | Batch 0000/0092 | Cost: 2.8797\n",
      "Time elapsed: 97.83 min\n",
      "Epoch: 105/150 | Batch 0000/0092 | Cost: 2.7753\n",
      "Time elapsed: 98.76 min\n",
      "Epoch: 106/150 | Batch 0000/0092 | Cost: 2.9361\n",
      "Time elapsed: 99.69 min\n",
      "Epoch: 107/150 | Batch 0000/0092 | Cost: 2.4497\n",
      "Time elapsed: 100.62 min\n",
      "Epoch: 108/150 | Batch 0000/0092 | Cost: 2.6242\n",
      "Time elapsed: 101.55 min\n",
      "Epoch: 109/150 | Batch 0000/0092 | Cost: 2.4673\n",
      "Time elapsed: 102.48 min\n",
      "Epoch: 110/150 | Batch 0000/0092 | Cost: 2.6668\n",
      "Time elapsed: 103.40 min\n",
      "Epoch: 111/150 | Batch 0000/0092 | Cost: 2.2719\n",
      "Time elapsed: 104.33 min\n",
      "Epoch: 112/150 | Batch 0000/0092 | Cost: 2.5014\n",
      "Time elapsed: 105.26 min\n",
      "Epoch: 113/150 | Batch 0000/0092 | Cost: 2.4812\n",
      "Time elapsed: 106.19 min\n",
      "Epoch: 114/150 | Batch 0000/0092 | Cost: 2.3502\n",
      "Time elapsed: 107.11 min\n",
      "Epoch: 115/150 | Batch 0000/0092 | Cost: 2.3428\n",
      "Time elapsed: 108.05 min\n",
      "Epoch: 116/150 | Batch 0000/0092 | Cost: 2.3403\n",
      "Time elapsed: 108.98 min\n",
      "Epoch: 117/150 | Batch 0000/0092 | Cost: 2.6313\n",
      "Time elapsed: 109.90 min\n",
      "Epoch: 118/150 | Batch 0000/0092 | Cost: 2.0320\n",
      "Time elapsed: 110.83 min\n",
      "Epoch: 119/150 | Batch 0000/0092 | Cost: 1.8235\n",
      "Time elapsed: 111.76 min\n",
      "Epoch: 120/150 | Batch 0000/0092 | Cost: 2.0651\n",
      "Time elapsed: 112.69 min\n",
      "Epoch: 121/150 | Batch 0000/0092 | Cost: 2.1958\n",
      "Time elapsed: 113.61 min\n",
      "Epoch: 122/150 | Batch 0000/0092 | Cost: 1.9154\n",
      "Time elapsed: 114.54 min\n",
      "Epoch: 123/150 | Batch 0000/0092 | Cost: 1.8475\n",
      "Time elapsed: 115.47 min\n",
      "Epoch: 124/150 | Batch 0000/0092 | Cost: 2.0797\n",
      "Time elapsed: 116.40 min\n",
      "Epoch: 125/150 | Batch 0000/0092 | Cost: 2.4172\n",
      "Time elapsed: 117.33 min\n",
      "Epoch: 126/150 | Batch 0000/0092 | Cost: 1.8864\n",
      "Time elapsed: 118.26 min\n",
      "Epoch: 127/150 | Batch 0000/0092 | Cost: 1.7855\n",
      "Time elapsed: 119.19 min\n",
      "Epoch: 128/150 | Batch 0000/0092 | Cost: 1.7424\n",
      "Time elapsed: 120.12 min\n",
      "Epoch: 129/150 | Batch 0000/0092 | Cost: 1.7705\n",
      "Time elapsed: 121.05 min\n",
      "Epoch: 130/150 | Batch 0000/0092 | Cost: 1.8574\n",
      "Time elapsed: 121.97 min\n",
      "Epoch: 131/150 | Batch 0000/0092 | Cost: 1.8506\n",
      "Time elapsed: 122.90 min\n",
      "Epoch: 132/150 | Batch 0000/0092 | Cost: 1.8032\n",
      "Time elapsed: 123.83 min\n",
      "Epoch: 133/150 | Batch 0000/0092 | Cost: 2.3258\n",
      "Time elapsed: 124.76 min\n",
      "Epoch: 134/150 | Batch 0000/0092 | Cost: 1.5483\n",
      "Time elapsed: 125.69 min\n",
      "Epoch: 135/150 | Batch 0000/0092 | Cost: 1.6174\n",
      "Time elapsed: 126.62 min\n",
      "Epoch: 136/150 | Batch 0000/0092 | Cost: 1.8305\n",
      "Time elapsed: 127.55 min\n",
      "Epoch: 137/150 | Batch 0000/0092 | Cost: 1.6682\n",
      "Time elapsed: 128.48 min\n",
      "Epoch: 138/150 | Batch 0000/0092 | Cost: 1.4051\n",
      "Time elapsed: 129.41 min\n",
      "Epoch: 139/150 | Batch 0000/0092 | Cost: 1.5446\n",
      "Time elapsed: 130.34 min\n",
      "Epoch: 140/150 | Batch 0000/0092 | Cost: 1.5014\n",
      "Time elapsed: 131.27 min\n",
      "Epoch: 141/150 | Batch 0000/0092 | Cost: 1.5583\n",
      "Time elapsed: 132.20 min\n",
      "Epoch: 142/150 | Batch 0000/0092 | Cost: 1.5508\n",
      "Time elapsed: 133.13 min\n",
      "Epoch: 143/150 | Batch 0000/0092 | Cost: 1.4848\n",
      "Time elapsed: 134.06 min\n",
      "Epoch: 144/150 | Batch 0000/0092 | Cost: 1.4201\n",
      "Time elapsed: 134.99 min\n",
      "Epoch: 145/150 | Batch 0000/0092 | Cost: 1.5778\n",
      "Time elapsed: 135.92 min\n",
      "Epoch: 146/150 | Batch 0000/0092 | Cost: 1.5876\n",
      "Time elapsed: 136.84 min\n",
      "Epoch: 147/150 | Batch 0000/0092 | Cost: 1.4196\n",
      "Time elapsed: 137.77 min\n",
      "Epoch: 148/150 | Batch 0000/0092 | Cost: 1.4015\n",
      "Time elapsed: 138.70 min\n",
      "Epoch: 149/150 | Batch 0000/0092 | Cost: 1.5134\n",
      "Time elapsed: 139.63 min\n",
      "Epoch: 150/150 | Batch 0000/0092 | Cost: 1.2708\n",
      "Time elapsed: 140.56 min\n"
     ]
    }
   ],
   "source": [
    "def compute_mae_and_mse(model, data_loader):\n",
    "    mae, mse, num_examples = torch.tensor([0.]), torch.tensor([0.]), 0\n",
    "    for features, targets in data_loader:\n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.float().to(DEVICE)\n",
    "        logits, probas, predictions = model(features)\n",
    "        assert len(targets.size()) == 1\n",
    "        assert len(predictions.size()) == 1\n",
    "        predicted_labels = torch.round(predictions).float()\n",
    "        num_examples += targets.size(0)\n",
    "        mae += torch.abs(predicted_labels - targets).sum()\n",
    "        mse += torch.sum((predicted_labels - targets)**2)\n",
    "    mae = mae / num_examples\n",
    "    mse = mse / num_examples\n",
    "    return mae, mse\n",
    "\n",
    "\n",
    "start_time = time.time()\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "\n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "\n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "\n",
    "        # FORWARD AND BACK PROP\n",
    "        logits, probas, predictions = model(features)\n",
    "        assert len(targets.size()) == 1\n",
    "        assert len(predictions.size()) == 1\n",
    "        cost = cost_fn(targets, predictions)\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        cost.backward()\n",
    "\n",
    "        # UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "\n",
    "        # LOGGING\n",
    "        if not batch_idx % 150:\n",
    "            s = ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'\n",
    "                 % (epoch+1, NUM_EPOCHS, batch_idx,\n",
    "                     len(train_dataset)//BATCH_SIZE, cost))\n",
    "            print(s)\n",
    "\n",
    "    s = 'Time elapsed: %.2f min' % ((time.time() - start_time)/60)\n",
    "    print(s)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MAE/RMSE: | Train: 0.90/1.23 | Test: 3.40/4.61\n",
      "Total Training Time: 141.35 min\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "with torch.set_grad_enabled(False):  # save memory during inference\n",
    "\n",
    "    train_mae, train_mse = compute_mae_and_mse(model, train_loader)\n",
    "    test_mae, test_mse = compute_mae_and_mse(model, test_loader)\n",
    "\n",
    "    s = 'MAE/RMSE: | Train: %.2f/%.2f | Test: %.2f/%.2f' % (\n",
    "        train_mae, torch.sqrt(train_mse), test_mae, torch.sqrt(test_mse))\n",
    "    print(s)\n",
    "\n",
    "s = 'Total Training Time: %.2f min' % ((time.time() - start_time)/60)\n",
    "print(s)"
   ]
  },
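   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "As a final illustrative example (not part of the original run), the trained model can be used to predict the age of a single test image; since the labels were shifted to start at 0 earlier, adding 18 recovers the actual age in years:"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "# illustrative example: age prediction for one test image\n",
     "# (labels were shifted by -18 above, so add 18 to get the age in years)\n",
     "model.eval()\n",
     "with torch.no_grad():\n",
     "    image, label = test_dataset[0]\n",
     "    _, _, prediction = model(image.unsqueeze(0).to(DEVICE))\n",
     "print('Predicted age: %.1f years' % (prediction.item() + 18))\n",
     "print('True age: %d years' % (label + 18))"
    ]
   },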
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.15.4\n",
      "pandas      0.23.4\n",
      "torch       1.1.0\n",
      "PIL.Image   5.3.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}