{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "1czVdIlqnImH" }, "source": [ "# Pix2Pix" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "1KD3ZgLs80vY" }, "source": [ "### Goals\n", "In this notebook, you will write a generative model based on the paper [*Image-to-Image Translation with Conditional Adversarial Networks*](https://arxiv.org/abs/1611.07004) by Isola et al. 2017, also known as Pix2Pix.\n", "\n", "You will be training a model that can convert aerial satellite imagery (\"input\") into map routes (\"output\"), as was done in the original paper. Since the architecture for the generator is a U-Net, which you've already implemented (with minor changes), the emphasis of the assignment will be on the loss function. So that you can see outputs more quickly, you'll be able to see your model train starting from a pre-trained checkpoint - but feel free to train it from scratch on your own too.\n", "\n", "\n", "![pix2pix example](pix2pix_ex.png)\n", "\n", "\n", "\n", "\n", "### Learning Objectives\n", "1. Implement the loss of a Pix2Pix model that differentiates it from a supervised U-Net.\n", "2. Observe the change in generator priorities as the Pix2Pix generator trains, changing its emphasis from reconstruction to realism.\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "wU8DDM6l9rZb" }, "source": [ "## Getting Started\n", "You will start by importing libraries, defining a visualization function, and getting the pre-trained Pix2Pix checkpoint. You will also be provided with the U-Net code for the Pix2Pix generator." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "JfkorNJrnmNO" }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from tqdm.auto import tqdm\n", "from torchvision import transforms\n", "from torchvision.datasets import VOCSegmentation\n", "from torchvision.utils import make_grid\n", "from torch.utils.data import DataLoader\n", "import matplotlib.pyplot as plt\n", "torch.manual_seed(0)\n", "\n", "def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):\n", "    '''\n", "    Function for visualizing images: Given a tensor of images, number of images, and\n", "    size per image, plots and prints the images in a uniform grid.\n", "    '''\n", "    image_shifted = image_tensor\n", "    image_unflat = image_shifted.detach().cpu().view(-1, *size)\n", "    image_grid = make_grid(image_unflat[:num_images], nrow=5)\n", "    plt.imshow(image_grid.permute(1, 2, 0).squeeze())\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "NjFyvNTG1CqY" }, "source": [ "#### U-Net Code\n", "\n", "The U-Net code will be much like the code you wrote for the last assignment, but with optional dropout and batchnorm. The structure is changed slightly for Pix2Pix, so that the final image is closer in size to the input image. Feel free to investigate the code if you're interested!" 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "xvY4ZNyUviY9" }, "outputs": [], "source": [ "def crop(image, new_shape):\n", "    '''\n", "    Function for cropping an image tensor: Given an image tensor and the new shape,\n", "    crops to the center pixels (assumes that the input's size and the new size are\n", "    even numbers).\n", "    Parameters:\n", "        image: image tensor of shape (batch size, channels, height, width)\n", "        new_shape: a torch.Size object with the shape you want x to have\n", "    '''\n", "    middle_height = image.shape[2] // 2\n", "    middle_width = image.shape[3] // 2\n", "    starting_height = middle_height - new_shape[2] // 2\n", "    final_height = starting_height + new_shape[2]\n", "    starting_width = middle_width - new_shape[3] // 2\n", "    final_width = starting_width + new_shape[3]\n", "    cropped_image = image[:, :, starting_height:final_height, starting_width:final_width]\n", "    return cropped_image\n", "\n", "class ContractingBlock(nn.Module):\n", "    '''\n", "    ContractingBlock Class\n", "    Performs two convolutions followed by a max pool operation.\n", "    Values:\n", "        input_channels: the number of channels to expect from a given input\n", "    '''\n", "    def __init__(self, input_channels, use_dropout=False, use_bn=True):\n", "        super(ContractingBlock, self).__init__()\n", "        self.conv1 = nn.Conv2d(input_channels, input_channels * 2, kernel_size=3, padding=1)\n", "        self.conv2 = nn.Conv2d(input_channels * 2, input_channels * 2, kernel_size=3, padding=1)\n", "        self.activation = nn.LeakyReLU(0.2)\n", "        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)\n", "        if use_bn:\n", "            self.batchnorm = nn.BatchNorm2d(input_channels * 2)\n", "        self.use_bn = use_bn\n", "        if use_dropout:\n", "            self.dropout = nn.Dropout()\n", "        self.use_dropout = use_dropout\n", "\n", "    def forward(self, x):\n", "        '''\n", "        Function for completing a forward pass of ContractingBlock: \n", "        Given an image tensor, completes a contracting block and 
returns the transformed tensor.\n", "        Parameters:\n", "            x: image tensor of shape (batch size, channels, height, width)\n", "        '''\n", "        x = self.conv1(x)\n", "        if self.use_bn:\n", "            x = self.batchnorm(x)\n", "        if self.use_dropout:\n", "            x = self.dropout(x)\n", "        x = self.activation(x)\n", "        x = self.conv2(x)\n", "        if self.use_bn:\n", "            x = self.batchnorm(x)\n", "        if self.use_dropout:\n", "            x = self.dropout(x)\n", "        x = self.activation(x)\n", "        x = self.maxpool(x)\n", "        return x\n", "\n", "class ExpandingBlock(nn.Module):\n", "    '''\n", "    ExpandingBlock Class:\n", "    Performs an upsampling, a convolution, a concatenation of its two inputs,\n", "    followed by two more convolutions with optional dropout\n", "    Values:\n", "        input_channels: the number of channels to expect from a given input\n", "    '''\n", "    def __init__(self, input_channels, use_dropout=False, use_bn=True):\n", "        super(ExpandingBlock, self).__init__()\n", "        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n", "        self.conv1 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=2)\n", "        self.conv2 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=3, padding=1)\n", "        self.conv3 = nn.Conv2d(input_channels // 2, input_channels // 2, kernel_size=2, padding=1)\n", "        if use_bn:\n", "            self.batchnorm = nn.BatchNorm2d(input_channels // 2)\n", "        self.use_bn = use_bn\n", "        self.activation = nn.ReLU()\n", "        if use_dropout:\n", "            self.dropout = nn.Dropout()\n", "        self.use_dropout = use_dropout\n", "\n", "    def forward(self, x, skip_con_x):\n", "        '''\n", "        Function for completing a forward pass of ExpandingBlock: \n", "        Given an image tensor, completes an expanding block and returns the transformed tensor.\n", "        Parameters:\n", "            x: image tensor of shape (batch size, channels, height, width)\n", "            skip_con_x: the image tensor from the contracting path (from the opposing block of x)\n", "                    for the skip connection\n", "        '''\n", "        x = self.upsample(x)\n", "        x = self.conv1(x)\n", 
"        skip_con_x = crop(skip_con_x, x.shape)\n", "        x = torch.cat([x, skip_con_x], axis=1)\n", "        x = self.conv2(x)\n", "        if self.use_bn:\n", "            x = self.batchnorm(x)\n", "        if self.use_dropout:\n", "            x = self.dropout(x)\n", "        x = self.activation(x)\n", "        x = self.conv3(x)\n", "        if self.use_bn:\n", "            x = self.batchnorm(x)\n", "        if self.use_dropout:\n", "            x = self.dropout(x)\n", "        x = self.activation(x)\n", "        return x\n", "\n", "class FeatureMapBlock(nn.Module):\n", "    '''\n", "    FeatureMapBlock Class\n", "    The final layer of a U-Net - \n", "    maps each pixel to a pixel with the correct number of output dimensions\n", "    using a 1x1 convolution.\n", "    Values:\n", "        input_channels: the number of channels to expect from a given input\n", "        output_channels: the number of channels to expect for a given output\n", "    '''\n", "    def __init__(self, input_channels, output_channels):\n", "        super(FeatureMapBlock, self).__init__()\n", "        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)\n", "\n", "    def forward(self, x):\n", "        '''\n", "        Function for completing a forward pass of FeatureMapBlock: \n", "        Given an image tensor, returns it mapped to the desired number of channels.\n", "        Parameters:\n", "            x: image tensor of shape (batch size, channels, height, width)\n", "        '''\n", "        x = self.conv(x)\n", "        return x\n", "\n", "class UNet(nn.Module):\n", "    '''\n", "    UNet Class\n", "    A series of 6 contracting blocks followed by 6 expanding blocks to \n", "    transform an input image into the corresponding paired image, with an upfeature\n", "    layer at the start and a downfeature layer at the end.\n", "    Values:\n", "        input_channels: the number of channels to expect from a given input\n", "        output_channels: the number of channels to expect for a given output\n", "    '''\n", "    def __init__(self, input_channels, output_channels, hidden_channels=32):\n", "        super(UNet, self).__init__()\n", "        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)\n", "        self.contract1 = 
ContractingBlock(hidden_channels, use_dropout=True)\n", "        self.contract2 = ContractingBlock(hidden_channels * 2, use_dropout=True)\n", "        self.contract3 = ContractingBlock(hidden_channels * 4, use_dropout=True)\n", "        self.contract4 = ContractingBlock(hidden_channels * 8)\n", "        self.contract5 = ContractingBlock(hidden_channels * 16)\n", "        self.contract6 = ContractingBlock(hidden_channels * 32)\n", "        self.expand0 = ExpandingBlock(hidden_channels * 64)\n", "        self.expand1 = ExpandingBlock(hidden_channels * 32)\n", "        self.expand2 = ExpandingBlock(hidden_channels * 16)\n", "        self.expand3 = ExpandingBlock(hidden_channels * 8)\n", "        self.expand4 = ExpandingBlock(hidden_channels * 4)\n", "        self.expand5 = ExpandingBlock(hidden_channels * 2)\n", "        self.downfeature = FeatureMapBlock(hidden_channels, output_channels)\n", "        self.sigmoid = torch.nn.Sigmoid()\n", "\n", "    def forward(self, x):\n", "        '''\n", "        Function for completing a forward pass of UNet: \n", "        Given an image tensor, passes it through U-Net and returns the output.\n", "        Parameters:\n", "            x: image tensor of shape (batch size, channels, height, width)\n", "        '''\n", "        x0 = self.upfeature(x)\n", "        x1 = self.contract1(x0)\n", "        x2 = self.contract2(x1)\n", "        x3 = self.contract3(x2)\n", "        x4 = self.contract4(x3)\n", "        x5 = self.contract5(x4)\n", "        x6 = self.contract6(x5)\n", "        x7 = self.expand0(x6, x5)\n", "        x8 = self.expand1(x7, x4)\n", "        x9 = self.expand2(x8, x3)\n", "        x10 = self.expand3(x9, x2)\n", "        x11 = self.expand4(x10, x1)\n", "        x12 = self.expand5(x11, x0)\n", "        xn = self.downfeature(x12)\n", "        return self.sigmoid(xn)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "T6ndvjc_1KXx" }, "source": [ "## PatchGAN Discriminator\n", "\n", "Next, you will define a discriminator based on the contracting path of the U-Net to allow you to evaluate the realism of the generated images. Remember that the discriminator outputs a one-channel matrix of classifications instead of a single value. 
Your discriminator's final layer will simply map from the final number of hidden channels to a single prediction for every pixel of the layer before it." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "0nVuJPjV1f92" }, "outputs": [], "source": [ "# UNQ_C1 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n", "# GRADED CLASS: Discriminator\n", "class Discriminator(nn.Module):\n", "    '''\n", "    Discriminator Class\n", "    Structured like the contracting path of the U-Net, the discriminator will\n", "    output a matrix of values classifying corresponding portions of the image as real or fake. \n", "    Parameters:\n", "        input_channels: the number of image input channels\n", "        hidden_channels: the initial number of discriminator convolutional filters\n", "    '''\n", "    def __init__(self, input_channels, hidden_channels=8):\n", "        super(Discriminator, self).__init__()\n", "        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)\n", "        self.contract1 = ContractingBlock(hidden_channels, use_bn=False)\n", "        self.contract2 = ContractingBlock(hidden_channels * 2)\n", "        self.contract3 = ContractingBlock(hidden_channels * 4)\n", "        self.contract4 = ContractingBlock(hidden_channels * 8)\n", "        #### START CODE HERE ####\n", "        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)\n", "        #### END CODE HERE ####\n", "\n", "    def forward(self, x, y):\n", "        x = torch.cat([x, y], axis=1)\n", "        x0 = self.upfeature(x)\n", "        x1 = self.contract1(x0)\n", "        x2 = self.contract2(x1)\n", "        x3 = self.contract3(x2)\n", "        x4 = self.contract4(x3)\n", "        xn = self.final(x4)\n", "        return xn" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "AFZBTJ_4Ubld" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Success!\n" ] } ], "source": [ "# UNIT TEST\n", "test_discriminator = Discriminator(10, 1)\n", "assert tuple(test_discriminator(\n", "    torch.randn(1, 5, 256, 256), \n", "    
torch.randn(1, 5, 256, 256)\n", ").shape) == (1, 1, 16, 16)\n", "print(\"Success!\")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qRk_8azSq3tF" }, "source": [ "## Training Preparation\n", "\n", "\n", "Now you can begin putting everything together for training. Start by defining some new parameters alongside the ones you are already familiar with:\n", " * **real_dim**: the number of channels of the real image and the number expected in the output image\n", " * **adv_criterion**: an adversarial loss function to keep track of how well the generator is fooling the discriminator and how well the discriminator is catching the generator\n", " * **recon_criterion**: a loss function that rewards images that are similar to the ground truth, i.e. that \"reconstruct\" the image\n", " * **lambda_recon**: a parameter for how heavily the reconstruction loss should be weighted\n", " * **n_epochs**: the number of times you iterate through the entire dataset when training\n", " * **input_dim**: the number of channels of the input image\n", " * **display_step**: how often to display/visualize the images\n", " * **batch_size**: the number of images per forward/backward pass\n", " * **lr**: the learning rate\n", " * **target_shape**: the size of the output image (in pixels)\n", " * **device**: the device type" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", "id": "UXptQZcwrBrq" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "# New parameters\n", "adv_criterion = nn.BCEWithLogitsLoss() \n", "recon_criterion = nn.L1Loss() \n", "lambda_recon = 200\n", "\n", "n_epochs = 20\n", "input_dim = 3\n", "real_dim = 3\n", "display_step = 200\n", "batch_size = 4\n", "lr = 0.0002\n", "target_shape = 256\n", "device = 'cuda'" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "WPOUC6-nVDCv" }, "source": [ "You will then pre-process the images of the dataset to make sure they're all the same size and that 
the size change due to U-Net layers is accounted for. " ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": {}, "colab_type": "code", "id": "PNAK2XqMJ419" }, "outputs": [], "source": [ "transform = transforms.Compose([\n", "    transforms.ToTensor(),\n", "])\n", "\n", "import torchvision\n", "dataset = torchvision.datasets.ImageFolder(\"maps\", transform=transform)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "t7vKN1POUjud" }, "source": [ "Next, you can initialize your generator (U-Net) and discriminator, as well as their optimizers. Finally, you will also load your pre-trained model." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "vBY3Y9UrUgVX" }, "outputs": [], "source": [ "gen = UNet(input_dim, real_dim).to(device)\n", "gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)\n", "disc = Discriminator(input_dim + real_dim).to(device)\n", "disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)\n", "\n", "def weights_init(m):\n", "    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):\n", "        torch.nn.init.normal_(m.weight, 0.0, 0.02)\n", "    if isinstance(m, nn.BatchNorm2d):\n", "        torch.nn.init.normal_(m.weight, 0.0, 0.02)\n", "        torch.nn.init.constant_(m.bias, 0)\n", "\n", "# Feel free to change pretrained to False if you're training the model from scratch\n", "pretrained = True\n", "if pretrained:\n", "    loaded_state = torch.load(\"pix2pix_15000.pth\")\n", "    gen.load_state_dict(loaded_state[\"gen\"])\n", "    gen_opt.load_state_dict(loaded_state[\"gen_opt\"])\n", "    disc.load_state_dict(loaded_state[\"disc\"])\n", "    disc_opt.load_state_dict(loaded_state[\"disc_opt\"])\n", "else:\n", "    gen = gen.apply(weights_init)\n", "    disc = disc.apply(weights_init)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "YcpFbNDYzJrh" }, "source": [ "While there are some changes to the U-Net architecture for Pix2Pix, the most important distinguishing 
feature of Pix2Pix is its adversarial loss. You will be implementing that here!" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": {}, "colab_type": "code", "id": "YZE-Eyj0LOpm" }, "outputs": [], "source": [ "# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n", "# GRADED CLASS: get_gen_loss\n", "def get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon):\n", "    '''\n", "    Return the loss of the generator given inputs.\n", "    Parameters:\n", "        gen: the generator; takes the condition and returns potential images\n", "        disc: the discriminator; takes images and the condition and\n", "            returns real/fake prediction matrices\n", "        real: the real images (e.g. maps) to be used to evaluate the reconstruction\n", "        condition: the source images (e.g. satellite imagery) which are used to produce the real images\n", "        adv_criterion: the adversarial loss function; takes the discriminator \n", "            predictions and the true labels and returns an adversarial \n", "            loss (which you aim to minimize)\n", "        recon_criterion: the reconstruction loss function; takes the generator \n", "            outputs and the real images and returns a reconstruction \n", "            loss (which you aim to minimize)\n", "        lambda_recon: the degree to which the reconstruction loss should be weighted in the sum\n", "    '''\n", "    # Steps: 1) Generate the fake images, based on the conditions.\n", "    #        2) Evaluate the fake images and the condition with the discriminator.\n", "    #        3) Calculate the adversarial and reconstruction losses.\n", "    #        4) Add the two losses, weighting the reconstruction loss appropriately.\n", "    #### START CODE HERE ####\n", "    fake = gen(condition)\n", "    disc_fake_hat = disc(fake, condition)\n", "    gen_adv_loss = adv_criterion(disc_fake_hat, torch.ones_like(disc_fake_hat))\n", "    gen_rec_loss = recon_criterion(real, fake)\n", "    gen_loss = gen_adv_loss + lambda_recon * gen_rec_loss\n", "    #### END CODE HERE ####\n", "    return gen_loss" ] }, { "cell_type": "code", 
"execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "KLndbJ213hV5", "outputId": "a713cd00-2b7d-41da-a3cb-90f4dc0ca49f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Success!\n" ] } ], "source": [ "# UNIT TEST\n", "def test_gen_reasonable(num_images=10):\n", "    gen = torch.zeros_like\n", "    disc = lambda x, y: torch.ones(len(x), 1)\n", "    real = None\n", "    condition = torch.ones(num_images, 3, 10, 10)\n", "    adv_criterion = torch.mul\n", "    recon_criterion = lambda x, y: torch.tensor(0)\n", "    lambda_recon = 0\n", "    assert get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon).sum() == num_images\n", "\n", "    disc = lambda x, y: torch.zeros(len(x), 1)\n", "    assert torch.abs(get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon)).sum() == 0\n", "\n", "    adv_criterion = lambda x, y: torch.tensor(0)\n", "    recon_criterion = lambda x, y: torch.abs(x - y).max()\n", "    real = torch.randn(num_images, 3, 10, 10)\n", "    lambda_recon = 2\n", "    gen = lambda x: real + 1\n", "    assert torch.abs(get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon) - 2) < 1e-4\n", "\n", "    adv_criterion = lambda x, y: (x + y).max() + x.max()\n", "    assert torch.abs(get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon) - 3) < 1e-4\n", "test_gen_reasonable()\n", "print(\"Success!\")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "SMDZWZTz3ivA" }, "source": [ "## Pix2Pix Training\n", "\n", "Finally, you can train the model and see some of your maps!" 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 373, "referenced_widgets": [ "aa7565ec3f294fd6b9c592bd5fc0dfcb", "fe98210470c3421c9a39734dc1203817", "a47c53e27edd4ef2b5bc79b3d64c44d3", "54a72fb618d146babc5830644ff65992", "6bf1b15c1b8e42758c8e9768115d6f8e", "bc3d591d82414f86888f2512bb3eb02a", "8cdec6ea735847709cc610fef8dc5755", "5042b39eadc14d5ab8b310e23d9c7d96" ] }, "colab_type": "code", "id": "fy6UBV60HtnY", "outputId": "c174bb25-acbf-4507-c6e2-6bf7ef08661c" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5fea174d4b4e4fa3b5af867c6ecfcab4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=549.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Pretrained initial state\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 62\u001b[0m }, f\"pix2pix_{cur_step}.pth\")\n\u001b[1;32m 63\u001b[0m \u001b[0mcur_step\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 64\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(save_model)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_epochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;31m# Dataloader returns the batches\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mimage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0mimage_width\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mcondition\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mimage\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mimage_width\u001b[0m \u001b[0;34m//\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tqdm/notebook.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 215\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mobj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtqdm_notebook\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 216\u001b[0m \u001b[0;31m# return super(tqdm...) 
will not catch exception\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 217\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tqdm/std.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1102\u001b[0m fp_write=getattr(self.fp, 'write', sys.stderr.write))\n\u001b[1;32m 1103\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1104\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mobj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1105\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1106\u001b[0m \u001b[0;31m# Update and possibly print the progressbar.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 343\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 344\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__next__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 345\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 346\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 347\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 384\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 385\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 386\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 387\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, 
possibly_batched_index)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m 
\u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torchvision/datasets/folder.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0msample\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m \u001b[0msample\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 141\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_transform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m 
\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0mtarget\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_transform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, pic)\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mConverted\u001b[0m \u001b[0mimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 100\u001b[0m \"\"\"\n\u001b[0;32m--> 101\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 102\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py\u001b[0m in \u001b[0;36mto_tensor\u001b[0;34m(pic)\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontiguous\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mByteTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 100\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdiv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m255\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 101\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m 
\u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "from skimage import color\n", "import numpy as np\n", "\n", "def train(save_model=False):\n", "    mean_generator_loss = 0\n", "    mean_discriminator_loss = 0\n", "    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n", "    cur_step = 0\n", "\n", "    for epoch in range(n_epochs):\n", "        # Each image holds the condition (left half) and the target (right half) side by side\n", "        for image, _ in tqdm(dataloader):\n", "            image_width = image.shape[3]\n", "            condition = image[:, :, :, :image_width // 2]\n", "            condition = nn.functional.interpolate(condition, size=target_shape)\n", "            real = image[:, :, :, image_width // 2:]\n", "            real = nn.functional.interpolate(real, size=target_shape)\n", "            cur_batch_size = len(condition)\n", "            condition = condition.to(device)\n", "            real = real.to(device)\n", "\n", "            ### Update discriminator ###\n", "            disc_opt.zero_grad() # Zero out the gradient before backpropagation\n", "            with torch.no_grad():\n", "                fake = gen(condition) # No generator graph is needed for the discriminator update\n", "            disc_fake_hat = disc(fake.detach(), condition) # Detach generator\n", "            disc_fake_loss = adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))\n", "            disc_real_hat = disc(real, condition)\n", "            disc_real_loss = adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))\n", "            disc_loss = (disc_fake_loss + disc_real_loss) / 2\n", "            disc_loss.backward() # Compute gradients; get_gen_loss builds its own graph, so none is retained\n", "            disc_opt.step() # Update the discriminator's weights\n", "\n", "            ### Update generator ###\n", "            gen_opt.zero_grad()\n", "            gen_loss = get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon)\n", "            gen_loss.backward() # Compute gradients\n", "            gen_opt.step() # Update the generator's weights\n", "\n", "            # Keep track of the average discriminator loss\n", "            mean_discriminator_loss += disc_loss.item() / display_step\n",
"            # Keep track of the average generator loss\n", "            mean_generator_loss += gen_loss.item() / display_step\n", "\n", "            ### Visualization code ###\n", "            if cur_step % display_step == 0:\n", "                if cur_step > 0:\n", "                    print(f\"Epoch {epoch}: Step {cur_step}: Generator (U-Net) loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}\")\n", "                else:\n", "                    print(\"Pretrained initial state\")\n", "                show_tensor_images(condition, size=(input_dim, target_shape, target_shape))\n", "                show_tensor_images(real, size=(real_dim, target_shape, target_shape))\n", "                show_tensor_images(fake, size=(real_dim, target_shape, target_shape))\n", "                mean_generator_loss = 0\n", "                mean_discriminator_loss = 0\n", "                # You can change save_model to True if you'd like to save the model\n", "                if save_model:\n", "                    torch.save({'gen': gen.state_dict(),\n", "                        'gen_opt': gen_opt.state_dict(),\n", "                        'disc': disc.state_dict(),\n", "                        'disc_opt': disc_opt.state_dict()\n", "                    }, f\"pix2pix_{cur_step}.pth\")\n", "            cur_step += 1\n", "train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "C3W2: Pix2Pix.ipynb", "provenance": [] }, "coursera": { "schema_names": [ "GANSC3-2B" ] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "5042b39eadc14d5ab8b310e23d9c7d96": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null,
"align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "54a72fb618d146babc5830644ff65992": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_5042b39eadc14d5ab8b310e23d9c7d96", "placeholder": "​", "style": "IPY_MODEL_8cdec6ea735847709cc610fef8dc5755", "value": " 23/275 [00:15<02:49, 1.49it/s]" } }, "6bf1b15c1b8e42758c8e9768115d6f8e": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "8cdec6ea735847709cc610fef8dc5755": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "a47c53e27edd4ef2b5bc79b3d64c44d3": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "danger", "description": " 8%", "description_tooltip": null, "layout": "IPY_MODEL_bc3d591d82414f86888f2512bb3eb02a", "max": 275, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_6bf1b15c1b8e42758c8e9768115d6f8e", "value": 23 } }, "aa7565ec3f294fd6b9c592bd5fc0dfcb": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_a47c53e27edd4ef2b5bc79b3d64c44d3", "IPY_MODEL_54a72fb618d146babc5830644ff65992" ], "layout": "IPY_MODEL_fe98210470c3421c9a39734dc1203817" } }, "bc3d591d82414f86888f2512bb3eb02a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": 
null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "fe98210470c3421c9a39734dc1203817": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } } } } }, "nbformat": 4, "nbformat_minor": 1 }