{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.9.0\n", "\n", "torch 1.3.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Convolutional Autoencoder with Deconvolutions and Continuous Jaccard Distance" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A convolutional autoencoder using deconvolutional (transposed convolution) layers that compresses 784-pixel (28x28) MNIST images down to a 7x7x8 (392-dimensional) representation.\n", "\n", "\n", "This convolutional autoencoder uses a continuous Jaccard distance [1] as its reconstruction loss. 
That is, given two non-negative vectors $x$ and $y$:\n", "\n", "$$J(x, y)=1-\\frac{\\sum_{i} \\min \\left(x_{i}, y_{i}\\right)}{\\sum_{i} \\max \\left(x_{i}, y_{i}\\right)}$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Reference:\n", " \n", "- [1] https://en.wikipedia.org/wiki/Jaccard_index" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.6275)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "\n", "\n", "def continuous_jaccard(x, y):\n", " \"\"\"\n", " Implementation of the continuous version of the\n", " Jaccard distance:\n", " 1 - [sum_i min(x_i, y_i)] / [sum_i max(x_i, y_i)]\n", " \"\"\"\n", " c = torch.cat((x.view(-1).unsqueeze(1), y.view(-1).unsqueeze(1)), dim=1)\n", "\n", " numerator = torch.sum(torch.min(c, dim=1)[0])\n", " denominator = torch.sum(torch.max(c, dim=1)[0])\n", "\n", " return 1. - numerator/denominator\n", "\n", "\n", "\n", "# Example\n", "\n", "x = torch.tensor([7, 2, 3, 4, 5, 6]).float()\n", "y = torch.tensor([1, 8, 9, 10, 11, 4]).float()\n", "\n", "continuous_jaccard(x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional Imports" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "import torch\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "0it [00:00, ?it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Device: cuda:0\n", "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n" ] }, { 
"name": "stderr", "output_type": "stream", "text": [ "9920512it [00:02, 3410868.60it/s] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "32768it [00:00, 280881.47it/s] \n", "0it [00:00, ?it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n", "Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "1654784it [00:00, 1928783.37it/s] \n", "8192it [00:00, 113077.53it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n", "Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw\n", "Processing...\n", "Done!\n", "Image batch dimensions: torch.Size([128, 1, 28, 28])\n", "Image label dimensions: torch.Size([128])\n" ] } ], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Device\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "print('Device:', device)\n", "\n", "# Hyperparameters\n", "random_seed = 456\n", "learning_rate = 0.005\n", "num_epochs = 10\n", "batch_size = 128\n", "\n", "\n", "##########################\n", "### MNIST DATASET\n", "##########################\n", "\n", "# Note transforms.ToTensor() scales input images\n", "# to 0-1 range\n", "train_dataset = datasets.MNIST(root='data', \n", " train=True, \n", " transform=transforms.ToTensor(),\n", " download=True)\n", "\n", 
"test_dataset = datasets.MNIST(root='data', \n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=batch_size, \n", " shuffle=True)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " batch_size=batch_size, \n", " shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "\n", "class ConvolutionalAutoencoder(torch.nn.Module):\n", "\n", " def __init__(self):\n", " super(ConvolutionalAutoencoder, self).__init__()\n", " \n", " # calculate same padding:\n", " # (w - k + 2*p)/s + 1 = o\n", " # => p = (s(o-1) - w + k)/2\n", " \n", " ### ENCODER\n", " \n", " # 28x28x1 => 28x28x4\n", " self.conv_1 = torch.nn.Conv2d(in_channels=1,\n", " out_channels=4,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " # (1(28-1) - 28 + 3) / 2 = 1\n", " padding=1) \n", " # 28x28x4 => 14x14x4 \n", " self.pool_1 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n", " stride=(2, 2),\n", " # (2(14-1) - 28 + 2) / 2 = 0\n", " padding=0) \n", " # 14x14x4 => 14x14x8\n", " self.conv_2 = torch.nn.Conv2d(in_channels=4,\n", " out_channels=8,\n", " kernel_size=(3, 3),\n", " stride=(1, 1),\n", " # (1(14-1) - 14 + 3) / 2 = 1\n", " padding=1) \n", " # 14x14x8 => 7x7x8 \n", " self.pool_2 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n", " stride=(2, 2),\n", " # (2(7-1) - 14 + 2) / 2 = 0\n", " padding=0)\n", " \n", " ### DECODER\n", " \n", " # 7x7x8 => 15x15x4 \n", " self.deconv_1 = torch.nn.ConvTranspose2d(in_channels=8,\n", " out_channels=4,\n", " kernel_size=(3, 3),\n", " stride=(2, 2),\n", " padding=0)\n", " \n", " # 
15x15x4 => 31x31x1 \n", " self.deconv_2 = torch.nn.ConvTranspose2d(in_channels=4,\n", " out_channels=1,\n", " kernel_size=(3, 3),\n", " stride=(2, 2),\n", " padding=0)\n", " \n", " def forward(self, x):\n", " \n", " ### ENCODER\n", " x = self.conv_1(x)\n", " x = F.leaky_relu(x)\n", " x = self.pool_1(x)\n", " x = self.conv_2(x)\n", " x = F.leaky_relu(x)\n", " x = self.pool_2(x)\n", " \n", " ### DECODER\n", " x = self.deconv_1(x)\n", " x = F.leaky_relu(x)\n", " x = self.deconv_2(x)\n", " x = F.leaky_relu(x)\n", " # crop the 31x31x1 output back to the 28x28x1 input size\n", " logits = x[:, :, 2:30, 2:30]\n", " probas = torch.sigmoid(logits)\n", " return logits, probas\n", "\n", " \n", "torch.manual_seed(random_seed)\n", "model = ConvolutionalAutoencoder()\n", "model = model.to(device)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/010 | Batch 000/468 | Cost: 0.8663\n", "Epoch: 001/010 | Batch 050/468 | Cost: 0.8086\n", "Epoch: 001/010 | Batch 100/468 | Cost: 0.7729\n", "Epoch: 001/010 | Batch 150/468 | Cost: 0.7322\n", "Epoch: 001/010 | Batch 200/468 | Cost: 0.3983\n", "Epoch: 001/010 | Batch 250/468 | Cost: 0.2963\n", "Epoch: 001/010 | Batch 300/468 | Cost: 0.2927\n", "Epoch: 001/010 | Batch 350/468 | Cost: 0.2783\n", "Epoch: 001/010 | Batch 400/468 | Cost: 0.2780\n", "Epoch: 001/010 | Batch 450/468 | Cost: 0.2609\n", "Time elapsed: 0.14 min\n", "Epoch: 002/010 | Batch 000/468 | Cost: 0.2694\n", "Epoch: 002/010 | Batch 050/468 | Cost: 0.2671\n", "Epoch: 002/010 | Batch 100/468 | Cost: 0.2444\n", "Epoch: 002/010 | Batch 150/468 | Cost: 0.2378\n", "Epoch: 002/010 | Batch 200/468 | Cost: 0.2540\n", "Epoch: 002/010 | Batch 250/468 | Cost: 0.2515\n", "Epoch: 002/010 | Batch 300/468 | Cost: 0.2393\n", "Epoch: 002/010 | Batch 350/468 | Cost: 0.2528\n", "Epoch: 002/010 | Batch 400/468 | 
Cost: 0.2283\n", "Epoch: 002/010 | Batch 450/468 | Cost: 0.2420\n", "Time elapsed: 0.27 min\n", "Epoch: 003/010 | Batch 000/468 | Cost: 0.2317\n", "Epoch: 003/010 | Batch 050/468 | Cost: 0.2274\n", "Epoch: 003/010 | Batch 100/468 | Cost: 0.2489\n", "Epoch: 003/010 | Batch 150/468 | Cost: 0.2246\n", "Epoch: 003/010 | Batch 200/468 | Cost: 0.2178\n", "Epoch: 003/010 | Batch 250/468 | Cost: 0.2200\n", "Epoch: 003/010 | Batch 300/468 | Cost: 0.2200\n", "Epoch: 003/010 | Batch 350/468 | Cost: 0.2309\n", "Epoch: 003/010 | Batch 400/468 | Cost: 0.2215\n", "Epoch: 003/010 | Batch 450/468 | Cost: 0.2218\n", "Time elapsed: 0.40 min\n", "Epoch: 004/010 | Batch 000/468 | Cost: 0.2124\n", "Epoch: 004/010 | Batch 050/468 | Cost: 0.2191\n", "Epoch: 004/010 | Batch 100/468 | Cost: 0.2121\n", "Epoch: 004/010 | Batch 150/468 | Cost: 0.2184\n", "Epoch: 004/010 | Batch 200/468 | Cost: 0.2118\n", "Epoch: 004/010 | Batch 250/468 | Cost: 0.2090\n", "Epoch: 004/010 | Batch 300/468 | Cost: 0.2114\n", "Epoch: 004/010 | Batch 350/468 | Cost: 0.2150\n", "Epoch: 004/010 | Batch 400/468 | Cost: 0.2218\n", "Epoch: 004/010 | Batch 450/468 | Cost: 0.2015\n", "Time elapsed: 0.53 min\n", "Epoch: 005/010 | Batch 000/468 | Cost: 0.1985\n", "Epoch: 005/010 | Batch 050/468 | Cost: 0.2053\n", "Epoch: 005/010 | Batch 100/468 | Cost: 0.2067\n", "Epoch: 005/010 | Batch 150/468 | Cost: 0.2003\n", "Epoch: 005/010 | Batch 200/468 | Cost: 0.2004\n", "Epoch: 005/010 | Batch 250/468 | Cost: 0.2076\n", "Epoch: 005/010 | Batch 300/468 | Cost: 0.2006\n", "Epoch: 005/010 | Batch 350/468 | Cost: 0.2162\n", "Epoch: 005/010 | Batch 400/468 | Cost: 0.2137\n", "Epoch: 005/010 | Batch 450/468 | Cost: 0.2077\n", "Time elapsed: 0.67 min\n", "Epoch: 006/010 | Batch 000/468 | Cost: 0.1986\n", "Epoch: 006/010 | Batch 050/468 | Cost: 0.2048\n", "Epoch: 006/010 | Batch 100/468 | Cost: 0.2063\n", "Epoch: 006/010 | Batch 150/468 | Cost: 0.2069\n", "Epoch: 006/010 | Batch 200/468 | Cost: 0.2092\n", "Epoch: 006/010 | Batch 250/468 | 
Cost: 0.1947\n", "Epoch: 006/010 | Batch 300/468 | Cost: 0.2006\n", "Epoch: 006/010 | Batch 350/468 | Cost: 0.1927\n", "Epoch: 006/010 | Batch 400/468 | Cost: 0.2018\n", "Epoch: 006/010 | Batch 450/468 | Cost: 0.1964\n", "Time elapsed: 0.79 min\n", "Epoch: 007/010 | Batch 000/468 | Cost: 0.1809\n", "Epoch: 007/010 | Batch 050/468 | Cost: 0.1996\n", "Epoch: 007/010 | Batch 100/468 | Cost: 0.1942\n", "Epoch: 007/010 | Batch 150/468 | Cost: 0.1909\n", "Epoch: 007/010 | Batch 200/468 | Cost: 0.1894\n", "Epoch: 007/010 | Batch 250/468 | Cost: 0.1937\n", "Epoch: 007/010 | Batch 300/468 | Cost: 0.1956\n", "Epoch: 007/010 | Batch 350/468 | Cost: 0.1938\n", "Epoch: 007/010 | Batch 400/468 | Cost: 0.1963\n", "Epoch: 007/010 | Batch 450/468 | Cost: 0.2060\n", "Time elapsed: 0.92 min\n", "Epoch: 008/010 | Batch 000/468 | Cost: 0.1947\n", "Epoch: 008/010 | Batch 050/468 | Cost: 0.2044\n", "Epoch: 008/010 | Batch 100/468 | Cost: 0.1811\n", "Epoch: 008/010 | Batch 150/468 | Cost: 0.1980\n", "Epoch: 008/010 | Batch 200/468 | Cost: 0.1794\n", "Epoch: 008/010 | Batch 250/468 | Cost: 0.2008\n", "Epoch: 008/010 | Batch 300/468 | Cost: 0.1949\n", "Epoch: 008/010 | Batch 350/468 | Cost: 0.1843\n", "Epoch: 008/010 | Batch 400/468 | Cost: 0.1942\n", "Epoch: 008/010 | Batch 450/468 | Cost: 0.1932\n", "Time elapsed: 1.05 min\n", "Epoch: 009/010 | Batch 000/468 | Cost: 0.1901\n", "Epoch: 009/010 | Batch 050/468 | Cost: 0.1894\n", "Epoch: 009/010 | Batch 100/468 | Cost: 0.1976\n", "Epoch: 009/010 | Batch 150/468 | Cost: 0.1935\n", "Epoch: 009/010 | Batch 200/468 | Cost: 0.1949\n", "Epoch: 009/010 | Batch 250/468 | Cost: 0.1921\n", "Epoch: 009/010 | Batch 300/468 | Cost: 0.1917\n", "Epoch: 009/010 | Batch 350/468 | Cost: 0.1900\n", "Epoch: 009/010 | Batch 400/468 | Cost: 0.1913\n", "Epoch: 009/010 | Batch 450/468 | Cost: 0.1815\n", "Time elapsed: 1.19 min\n", "Epoch: 010/010 | Batch 000/468 | Cost: 0.1845\n", "Epoch: 010/010 | Batch 050/468 | Cost: 0.1910\n", "Epoch: 010/010 | Batch 100/468 | 
Cost: 0.1929\n", "Epoch: 010/010 | Batch 150/468 | Cost: 0.1919\n", "Epoch: 010/010 | Batch 200/468 | Cost: 0.1822\n", "Epoch: 010/010 | Batch 250/468 | Cost: 0.1974\n", "Epoch: 010/010 | Batch 300/468 | Cost: 0.1919\n", "Epoch: 010/010 | Batch 350/468 | Cost: 0.1750\n", "Epoch: 010/010 | Batch 400/468 | Cost: 0.1879\n", "Epoch: 010/010 | Batch 450/468 | Cost: 0.1785\n", "Time elapsed: 1.32 min\n", "Total Training Time: 1.32 min\n" ] } ], "source": [ "start_time = time.time()\n", "for epoch in range(num_epochs):\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " # labels are not needed, only the images (features)\n", " features = features.to(device)\n", "\n", " ### FORWARD AND BACK PROP\n", " logits, decoded = model(features)\n", " #cost = F.binary_cross_entropy_with_logits(logits, features)\n", " cost = continuous_jaccard(features, decoded)\n", " optimizer.zero_grad()\n", " \n", " cost.backward()\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 50:\n", " print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n", " %(epoch+1, num_epochs, batch_idx, \n", " len(train_dataset)//batch_size, cost))\n", " \n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "