{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.1\n", "IPython 7.2.0\n", "\n", "torch 1.0.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Convolutional Autoencoder with Deconvolutions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A convolutional autoencoder using deconvolutional (transposed-convolution) layers that compresses 784-pixel (28x28) MNIST images down to a 7x7x8 (392-dimensional) representation."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] },
{
 "cell_type": "code",
 "execution_count": 2,
 "metadata": {},
 "outputs": [],
 "source": [
  "# standard library\n",
  "import time\n",
  "\n",
  "# third-party\n",
  "import numpy as np\n",
  "import torch\n",
  "import torch.nn.functional as F\n",
  "from torch.utils.data import DataLoader\n",
  "from torchvision import datasets, transforms\n",
  "\n",
  "\n",
  "if torch.cuda.is_available():\n",
  "    # request deterministic cuDNN kernels so GPU runs are repeatable\n",
  "    torch.backends.cudnn.deterministic = True"
 ]
},
{
 "cell_type": "code",
 "execution_count": 3,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "Device: cuda:0\n",
    "Image batch dimensions: torch.Size([128, 1, 28, 28])\n",
    "Image label dimensions: torch.Size([128])\n"
   ]
  }
 ],
 "source": [
  "##########################\n",
  "### SETTINGS\n",
  "##########################\n",
  "\n",
  "# Device: first CUDA GPU if one is available, otherwise the CPU\n",
  "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
  "print('Device:', device)\n",
  "\n",
  "# Hyperparameters\n",
  "random_seed = 456\n",
  "learning_rate = 0.005\n",
  "num_epochs = 10\n",
  "batch_size = 128\n",
  "\n",
  "\n",
  "##########################\n",
  "### MNIST DATASET\n",
  "##########################\n",
  "\n",
  "# transforms.ToTensor() turns each image into a float tensor\n",
  "# with values scaled to the 0-1 range\n",
  "train_dataset = datasets.MNIST(root='data', train=True,\n",
  "                               transform=transforms.ToTensor(),\n",
  "                               download=True)\n",
  "test_dataset = datasets.MNIST(root='data', train=False,\n",
  "                              transform=transforms.ToTensor())\n",
  "\n",
  "train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,\n",
  "                          shuffle=True)\n",
  "test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,\n",
  "                         shuffle=False)\n",
  "\n",
  "# Sanity check: pull a single training batch and report its shapes\n",
  "images, labels = next(iter(train_loader))\n",
  "print('Image batch dimensions:', images.shape)\n",
  "print('Image label dimensions:', labels.shape)"
 ]
},
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Model"
] },
{
 "cell_type": "code",
 "execution_count": 4,
 "metadata": {},
 "outputs": [],
 "source": [
  "##########################\n",
  "### MODEL\n",
  "##########################\n",
  "\n",
  "\n",
  "class ConvolutionalAutoencoder(torch.nn.Module):\n",
  "    \"\"\"Convolutional autoencoder for single-channel 28x28 images.\n",
  "\n",
  "    Encoder: two (3x3 conv -> leaky ReLU -> 2x2 max-pool) stages map the\n",
  "    28x28x1 input to a 7x7x8 code.\n",
  "    Decoder: two stride-2 transposed convolutions upsample the code\n",
  "    7x7 -> 15x15 -> 31x31; the result is cropped back to 28x28.\n",
  "    \"\"\"\n",
  "\n",
  "    def __init__(self):\n",
  "        super(ConvolutionalAutoencoder, self).__init__()\n",
  "\n",
  "        # calculate same padding:\n",
  "        # (w - k + 2*p)/s + 1 = o\n",
  "        # => p = (s(o-1) - w + k)/2\n",
  "\n",
  "        ### ENCODER\n",
  "\n",
  "        # 28x28x1 => 28x28x4\n",
  "        self.conv_1 = torch.nn.Conv2d(in_channels=1,\n",
  "                                      out_channels=4,\n",
  "                                      kernel_size=(3, 3),\n",
  "                                      stride=(1, 1),\n",
  "                                      # (1(28-1) - 28 + 3) / 2 = 1\n",
  "                                      padding=1)\n",
  "        # 28x28x4 => 14x14x4\n",
  "        self.pool_1 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
  "                                         stride=(2, 2),\n",
  "                                         # (2(14-1) - 28 + 2) / 2 = 0\n",
  "                                         padding=0)\n",
  "        # 14x14x4 => 14x14x8\n",
  "        self.conv_2 = torch.nn.Conv2d(in_channels=4,\n",
  "                                      out_channels=8,\n",
  "                                      kernel_size=(3, 3),\n",
  "                                      stride=(1, 1),\n",
  "                                      # (1(14-1) - 14 + 3) / 2 = 1\n",
  "                                      padding=1)\n",
  "        # 14x14x8 => 7x7x8\n",
  "        self.pool_2 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
  "                                         stride=(2, 2),\n",
  "                                         # (2(7-1) - 14 + 2) / 2 = 0\n",
  "                                         padding=0)\n",
  "\n",
  "        ### DECODER\n",
  "        # transposed-conv output size: o = (w - 1)*s - 2*p + k\n",
  "\n",
  "        # 7x7x8 => 15x15x4: (7-1)*2 - 0 + 3 = 15\n",
  "        self.deconv_1 = torch.nn.ConvTranspose2d(in_channels=8,\n",
  "                                                 out_channels=4,\n",
  "                                                 kernel_size=(3, 3),\n",
  "                                                 stride=(2, 2),\n",
  "                                                 padding=0)\n",
  "\n",
  "        # 15x15x4 => 31x31x1: (15-1)*2 - 0 + 3 = 31\n",
  "        self.deconv_2 = torch.nn.ConvTranspose2d(in_channels=4,\n",
  "                                                 out_channels=1,\n",
  "                                                 kernel_size=(3, 3),\n",
  "                                                 stride=(2, 2),\n",
  "                                                 padding=0)\n",
  "\n",
  "    def forward(self, x):\n",
  "        \"\"\"Encode and reconstruct a batch of images.\n",
  "\n",
  "        Returns a ``(logits, probas)`` tuple: ``logits`` are the raw,\n",
  "        unbounded 28x28 decoder outputs (the input expected by\n",
  "        ``F.binary_cross_entropy_with_logits``) and ``probas`` is\n",
  "        ``torch.sigmoid(logits)``, i.e. reconstructed intensities in (0, 1).\n",
  "        \"\"\"\n",
  "        ### ENCODER\n",
  "        x = self.conv_1(x)\n",
  "        x = F.leaky_relu(x)\n",
  "        x = self.pool_1(x)\n",
  "        x = self.conv_2(x)\n",
  "        x = F.leaky_relu(x)\n",
  "        x = self.pool_2(x)\n",
  "\n",
  "        ### DECODER\n",
  "        x = self.deconv_1(x)\n",
  "        x = F.leaky_relu(x)\n",
  "        x = self.deconv_2(x)\n",
  "        # No activation after the last layer: its outputs are logits.\n",
  "        # (A leaky ReLU here would shrink negative logits by 100x, so very\n",
  "        # large negative pre-activations would be needed to push the\n",
  "        # sigmoid toward 0 for background pixels.)\n",
  "        # Crop 31x31 -> 28x28: drop 2 rows/cols at the top/left and 1 at\n",
  "        # the bottom/right.\n",
  "        logits = x[:, :, 2:30, 2:30]\n",
  "        probas = torch.sigmoid(logits)\n",
  "        return logits, probas\n",
  "\n",
  "\n",
  "torch.manual_seed(random_seed)\n",
  "model = ConvolutionalAutoencoder()\n",
  "model = model.to(device)\n",
  "\n",
  "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
 ]
},
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] },
{
 "cell_type": "code",
 "execution_count": 5,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "Epoch: 001/010 | Batch 000/468 | Cost: 0.7191\n",
    "Epoch: 001/010 | Batch 050/468 | Cost: 0.6911\n",
    "Epoch: 001/010 | Batch 100/468 | Cost: 0.5994\n",
    "Epoch: 001/010 | Batch 150/468 | Cost: 0.3594\n",
    "Epoch: 001/010 | Batch 200/468 | Cost: 0.3035\n",
    "Epoch: 001/010 | Batch 250/468 | Cost: 0.2480\n",
    "Epoch: 001/010 | Batch 300/468 | Cost: 0.2085\n",
    "Epoch: 001/010 | Batch 350/468 | Cost: 0.1780\n",
    "Epoch: 001/010 | Batch 400/468 | Cost: 0.1586\n",
    "Epoch: 001/010 | Batch 450/468 | Cost: 0.1532\n",
    "Epoch: 002/010 | Batch 000/468 | Cost: 0.1442\n",
    "Epoch: 002/010 | Batch 050/468 | Cost: 0.1366\n",
    "Epoch: 002/010 | Batch 100/468 | Cost: 0.1362\n",
    "Epoch: 002/010 | Batch 150/468 | Cost: 0.1293\n",
    "Epoch: 002/010 | Batch 200/468 | Cost: 0.1286\n",
    "Epoch: 002/010 | Batch 250/468 | Cost: 0.1244\n",
    "Epoch: 002/010 | Batch 300/468 | Cost: 0.1226\n",
    "Epoch: 002/010 | Batch 350/468 | Cost: 0.1240\n",
    "Epoch: 002/010 | Batch 400/468 | Cost: 0.1239\n",
    "Epoch: 002/010 | Batch 450/468 | Cost: 0.1193\n",
    "Epoch: 003/010 | Batch 000/468 | Cost: 0.1196\n",
    "Epoch: 003/010 | Batch 050/468 | Cost: 0.1197\n",
    "Epoch: 003/010 | Batch 100/468 | Cost: 0.1217\n",
    "Epoch: 003/010 | Batch 150/468 | Cost: 0.1167\n",
    "Epoch: 003/010 | Batch 200/468 | Cost: 0.1115\n",
    "Epoch: 003/010 | Batch 250/468 | Cost: 0.1144\n",
    "Epoch: 003/010 | Batch 300/468 | Cost: 0.1092\n",
    "Epoch: 003/010 | Batch 350/468 | Cost: 0.1164\n",
    "Epoch: 003/010 | Batch 400/468 | Cost: 0.1141\n",
    "Epoch: 003/010 | Batch 450/468 | Cost: 0.1071\n",
    "Epoch: 004/010 | Batch 000/468 | Cost: 0.1121\n",
    "Epoch: 004/010 | Batch 050/468 | Cost: 0.1130\n",
    "Epoch: 004/010 | Batch 100/468 | Cost: 0.1043\n",
    "Epoch: 004/010 | Batch 150/468 | Cost: 0.1098\n",
    "Epoch: 004/010 | Batch 200/468 | Cost: 0.1104\n",
    "Epoch: 004/010 | Batch 250/468 | Cost: 0.1095\n",
    "Epoch: 004/010 | Batch 300/468 | Cost: 0.1105\n",
    "Epoch: 004/010 | Batch 350/468 | Cost: 0.1088\n",
    "Epoch: 004/010 | Batch 400/468 | Cost: 0.1040\n",
    "Epoch: 004/010 | Batch 450/468 | Cost: 0.1098\n",
    "Epoch: 005/010 | Batch 000/468 | Cost: 0.1045\n",
    "Epoch: 005/010 | Batch 050/468 | Cost: 0.1030\n",
    "Epoch: 005/010 | Batch 100/468 | Cost: 0.1029\n",
    "Epoch: 005/010 | Batch 150/468 | Cost: 0.1063\n",
    "Epoch: 005/010 | Batch 200/468 | Cost: 0.1056\n",
    "Epoch: 005/010 | Batch 250/468 | Cost: 0.1046\n",
    "Epoch: 005/010 | Batch 300/468 | Cost: 0.1074\n",
    "Epoch: 005/010 | Batch 350/468 | Cost: 0.1062\n",
    "Epoch: 005/010 | Batch 400/468 | Cost: 0.1029\n",
    "Epoch: 005/010 | Batch 450/468 | Cost: 0.1074\n",
    "Epoch: 006/010 | Batch 000/468 | Cost: 0.1051\n",
    "Epoch: 006/010 | Batch 050/468 | Cost: 0.0983\n",
    "Epoch: 006/010 | Batch 100/468 | Cost: 0.1031\n",
    "Epoch: 006/010 | Batch 150/468 | Cost: 0.1060\n",
    "Epoch: 006/010 | Batch 200/468 | Cost: 0.1044\n",
    "Epoch: 006/010 | Batch 250/468 | Cost: 0.1013\n",
    "Epoch: 006/010 | Batch 300/468 | Cost: 0.0992\n",
    "Epoch: 006/010 | Batch 350/468 | Cost: 0.1010\n",
    "Epoch: 006/010 | Batch 400/468 | Cost: 0.1020\n",
    "Epoch: 006/010 | Batch 450/468 | Cost: 0.1047\n",
    "Epoch: 007/010 | Batch 000/468 | Cost: 0.0979\n",
    "Epoch: 007/010 | Batch 050/468 | Cost: 0.0978\n",
    "Epoch: 007/010 | Batch 100/468 | Cost: 0.1001\n",
    "Epoch: 007/010 | Batch 150/468 | Cost: 0.1023\n",
    "Epoch: 007/010 | Batch 200/468 | Cost: 0.1008\n",
    "Epoch: 007/010 | Batch 250/468 | Cost: 0.0943\n",
    "Epoch: 007/010 | Batch 300/468 | Cost: 0.0968\n",
    "Epoch: 007/010 | Batch 350/468 | Cost: 0.1017\n",
    "Epoch: 007/010 | Batch 400/468 | Cost: 0.0988\n",
    "Epoch: 007/010 | Batch 450/468 | Cost: 0.0992\n",
"Epoch: 008/010 | Batch 000/468 | Cost: 0.1015\n",
    "Epoch: 008/010 | Batch 050/468 | Cost: 0.0995\n",
    "Epoch: 008/010 | Batch 100/468 | Cost: 0.0988\n",
    "Epoch: 008/010 | Batch 150/468 | Cost: 0.0980\n",
    "Epoch: 008/010 | Batch 200/468 | Cost: 0.0986\n",
    "Epoch: 008/010 | Batch 250/468 | Cost: 0.0958\n",
    "Epoch: 008/010 | Batch 300/468 | Cost: 0.0958\n",
    "Epoch: 008/010 | Batch 350/468 | Cost: 0.0928\n",
    "Epoch: 008/010 | Batch 400/468 | Cost: 0.0986\n",
    "Epoch: 008/010 | Batch 450/468 | Cost: 0.0972\n",
    "Epoch: 009/010 | Batch 000/468 | Cost: 0.0985\n",
    "Epoch: 009/010 | Batch 050/468 | Cost: 0.0958\n",
    "Epoch: 009/010 | Batch 100/468 | Cost: 0.1002\n",
    "Epoch: 009/010 | Batch 150/468 | Cost: 0.0980\n",
    "Epoch: 009/010 | Batch 200/468 | Cost: 0.0973\n",
    "Epoch: 009/010 | Batch 250/468 | Cost: 0.0966\n",
    "Epoch: 009/010 | Batch 300/468 | Cost: 0.0948\n",
    "Epoch: 009/010 | Batch 350/468 | Cost: 0.0983\n",
    "Epoch: 009/010 | Batch 400/468 | Cost: 0.0986\n",
    "Epoch: 009/010 | Batch 450/468 | Cost: 0.0976\n",
    "Epoch: 010/010 | Batch 000/468 | Cost: 0.0975\n",
    "Epoch: 010/010 | Batch 050/468 | Cost: 0.0971\n",
    "Epoch: 010/010 | Batch 100/468 | Cost: 0.0976\n",
    "Epoch: 010/010 | Batch 150/468 | Cost: 0.0953\n",
    "Epoch: 010/010 | Batch 200/468 | Cost: 0.0982\n",
    "Epoch: 010/010 | Batch 250/468 | Cost: 0.0964\n",
    "Epoch: 010/010 | Batch 300/468 | Cost: 0.1003\n",
    "Epoch: 010/010 | Batch 350/468 | Cost: 0.0914\n",
    "Epoch: 010/010 | Batch 400/468 | Cost: 0.0971\n",
    "Epoch: 010/010 | Batch 450/468 | Cost: 0.0959\n"
   ]
  }
 ],
 "source": [
  "start_time = time.time()\n",
  "for epoch in range(num_epochs):\n",
  "\n",
  "    # (Re-)enter training mode: the evaluation cell below calls\n",
  "    # model.eval(), which would otherwise persist if this cell is re-run.\n",
  "    model.train()\n",
  "\n",
  "    for batch_idx, (features, _) in enumerate(train_loader):\n",
  "\n",
  "        # don't need labels, only the images (features)\n",
  "        features = features.to(device)\n",
  "\n",
  "        ### FORWARD AND BACK PROP\n",
  "        logits, _ = model(features)\n",
  "        # the 0-1 scaled input images are their own reconstruction targets\n",
  "        cost = F.binary_cross_entropy_with_logits(logits, features)\n",
  "        optimizer.zero_grad()\n",
  "\n",
  "        cost.backward()\n",
  "\n",
  "        ### UPDATE MODEL PARAMETERS\n",
  "        optimizer.step()\n",
  "\n",
  "        ### LOGGING\n",
  "        if not batch_idx % 50:\n",
  "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'\n",
  "                   %(epoch+1, num_epochs, batch_idx,\n",
  "                     len(train_dataset)//batch_size, cost))\n",
  "\n",
  "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
  "\n",
  "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
 ]
},
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] },
{
 "cell_type": "code",
 "execution_count": 6,
 "metadata": {},
 "outputs": [
  {
   "data": {
    "image/png": "\n",
    "text/plain": [
     "<Figure size 1440x180 with 30 Axes>"
    ]
   },
   "metadata": { "needs_background": "light" },
   "output_type": "display_data"
  }
 ],
 "source": [
  "%matplotlib inline\n",
  "import matplotlib.pyplot as plt\n",
  "\n",
  "##########################\n",
  "### VISUALIZATION\n",
  "##########################\n",
  "\n",
  "n_images = 15\n",
  "image_width = 28\n",
  "\n",
  "# Reconstruct held-out *test* images instead of reusing `features`/`decoded`\n",
  "# left over from the last training batch (those are training images and\n",
  "# only exist because of state leaked from the training loop).\n",
  "model.eval()\n",
  "with torch.no_grad():\n",
  "    test_images, _ = next(iter(test_loader))\n",
  "    orig_images = test_images[:n_images].to(device)\n",
  "    _, decoded_images = model(orig_images)\n",
  "\n",
  "fig, axes = plt.subplots(nrows=2, ncols=n_images,\n",
  "                         sharex=True, sharey=True, figsize=(20, 2.5))\n",
  "fig.suptitle('Top: original test images | Bottom: reconstructions')\n",
  "\n",
  "for i in range(n_images):\n",
  "    for ax, img in zip(axes, [orig_images, decoded_images]):\n",
  "        curr_img = img[i].cpu()\n",
  "        # reshape (not view): the decoder output is a cropped slice\n",
  "        ax[i].imshow(curr_img.reshape(image_width, image_width), cmap='binary')"
 ]
},
{
 "cell_type": "code",
 "execution_count": 7,
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "numpy 1.15.4\n",
    "torch 1.0.0\n",
    "\n"
   ]
  }
 ],
 "source": [
  "%watermark -iv"
 ]
}
],
"metadata": {
 "kernelspec": {
  "display_name": "Python 3",
  "language": "python",
  "name": "python3"
 },
 "language_info": {
  "codemirror_mode": {
   "name": "ipython",
   "version": 3
  },
  "file_extension": ".py",
  "mimetype": "text/x-python",
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": "3.7.3"
 },
 "toc": {
  "nav_menu": {},
  "number_sections": true,
  "sideBar": true,
  "skip_h1_title": false,
  "title_cell": "Table of Contents",
  "title_sidebar": "Contents",
  "toc_cell": false,
  "toc_position": {},
  "toc_section_display": true,
  "toc_window_display": false
 }
},
"nbformat": 4,
"nbformat_minor": 4
}