{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# \"Triplet Training for Generative Adversarial Networks\"\n", "- toc: true\n", "- categories: [\"deep-learning\"]\n", "- image: images/copied_from_nb/images/gan/gan-triplet.png\n", "- comments: true\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What will I get from this blog?\n", "\n", "By the end of this post, we'll have a model that can generate digits like the ones below from random noise.\n", "\n", "![](images/gan/gan-diagram.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "\n", "> This is an implementation of the paper: Zieba, Maciej, and Lei Wang. \"Training triplet networks with gan.\" arXiv preprint arXiv:1704.02227 (2017). \n", "\n", "\n", "### What is triplet training?\n", "Triplet training learns embeddings from relative similarity: an anchor sample should end up closer to a positive sample (same class) than to a negative sample (different class). Read more about it [here](https://towardsdatascience.com/image-similarity-using-triplet-loss-3744c0f67973).\n", "\n", "The paper replaces the classification loss of the Improved GAN discriminator with a triplet loss and compares the results.\n", "\n", "## Approach\n", "\n", "I first pretrain the GAN with the original GAN objective for 50 epochs, and then train it with the Improved GAN objective for 10 epochs.\n", "\n", "The GAN's discriminator outputs M features per image, and these M features are the input to the classifier. For classification I use a 9-nearest-neighbour classifier."
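, "\n", "\n", "To make the two ingredients above concrete, here is a minimal sketch. It assumes PyTorch and scikit-learn and uses random vectors in place of real discriminator features. `triplet_softmax_loss` and `knn_accuracy` are hypothetical helper names used only for illustration. The `TripletLoss` class in this notebook computes `-log(exp(d_neg) / (exp(d_pos) + exp(d_neg)))` from the anchor-positive distance `d_pos` and the anchor-negative distance `d_neg`. That expression simplifies to `softplus(d_pos - d_neg)`. Evaluation then fits a 9-nearest-neighbour classifier on the M-dimensional features." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import torch.nn.functional as F\n", "from sklearn.neighbors import KNeighborsClassifier\n", "\n", "def triplet_softmax_loss(anchor, positive, negative):\n", "    # Euclidean distances anchor-positive and anchor-negative, shape (batch,)\n", "    d_pos = (anchor - positive).pow(2).sum(dim=1).sqrt()\n", "    d_neg = (anchor - negative).pow(2).sum(dim=1).sqrt()\n", "    # -log(exp(d_neg) / (exp(d_pos) + exp(d_neg))) == softplus(d_pos - d_neg)\n", "    return F.softplus(d_pos - d_neg).mean()\n", "\n", "def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=9):\n", "    # Fit k-NN on the embedding space and return test accuracy in [0, 1]\n", "    knn = KNeighborsClassifier(n_neighbors=k)\n", "    knn.fit(train_feats, train_labels)\n", "    return knn.score(test_feats, test_labels)\n", "\n", "torch.manual_seed(0)\n", "M = 16  # embedding size; this post uses M = 16 or 32\n", "anchor, positive, negative = torch.randn(3, 8, M).unbind(0)\n", "loss = triplet_softmax_loss(anchor, positive, negative)\n", "\n", "# Toy features: 10 well-separated Gaussian clusters, one per digit class\n", "rng = np.random.default_rng(0)\n", "centers = rng.normal(size=(10, M)) * 5\n", "train_y = rng.integers(0, 10, size=200)\n", "test_y = rng.integers(0, 10, size=50)\n", "train_X = centers[train_y] + rng.normal(size=(200, M))\n", "test_X = centers[test_y] + rng.normal(size=(50, M))\n", "acc = knn_accuracy(train_X, train_y, test_X, test_y)\n", "print(loss.item(), acc)"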
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "IKR-DnPTXcX7", "outputId": "5f495e89-55bf-46a5-b19e-e950398c1a15" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Archive: results.zip\n", " inflating: models/D_16_100.pkl \n", " inflating: models/D_16_200.pkl \n", " inflating: models/D_32_100.pkl \n", " inflating: models/G_16_100.pkl \n", " inflating: models/G_16_200.pkl \n", " inflating: models/G_32_100.pkl \n", " inflating: models/pretrain_D_16.pkl \n", " inflating: models/pretrain_D_32.pkl \n", " inflating: models/pretrain_G_16.pkl \n", " inflating: models/pretrain_G_32.pkl \n" ] } ], "source": [ "#hide\n", "# Pretrained models. download them later if you want to compare the results\n", "!wget -N --quiet --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1ufF0r_fs64wjCHITE7GsGdAh7BKMNfTk' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=1ufF0r_fs64wjCHITE7GsGdAh7BKMNfTk\" -O results.zip && rm -rf /tmp/cookies.txt > /tmp/xxy\n", "!unzip -o results.zip" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "Y4VmTy46XcX_" }, "outputs": [], "source": [ "#collapse-hide\n", "# Reference\n", "# https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn.neighbors.KNeighborsClassifier.score\n", "# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html\n", "# https://github.com/andreasveit/triplet-network-pytorch\n", "!pip install livelossplot > /tmp/xxy\n", "\n", "import math\n", "import os\n", "import sys\n", "from pathlib import Path\n", "from pprint import pprint\n", "\n", "import numpy as np\n", "from PIL import Image\n", "from scipy.spatial.distance import cdist\n", "# from torch 
import cdist\n", "from sklearn import preprocessing\n", "from sklearn.metrics import average_precision_score\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from tqdm.auto import tqdm as tq\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from livelossplot import PlotLosses\n", "from torch.nn import functional as F\n", "from torch.nn.parameter import Parameter\n", "from torch.utils.data import DataLoader, Dataset, TensorDataset\n", "from torchvision import datasets, transforms\n", "from torchvision.utils import save_image\n", "import matplotlib.pyplot as plt\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "hvHUT5S6XcYC" }, "outputs": [], "source": [ "#collapse-hide\n", "# https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py\n", "# https://www.kaggle.com/hirotaka0122/triplet-loss-with-pytorch\n", "# https://github.com/openai/improved-gan/blob/master/mnist_svhn_cifar10/train_mnist_feature_matching.py\n", "# https://github.com/adambielski/siamese-triplet/blob/master/Experiments_MNIST.ipynb\n", "# https://github.com/eladhoffer/TripletNet\n", "# https://stackoverflow.com/questions/26210471/scikit-learn-gridsearch-giving-valueerror-multiclass-format-is-not-supported" ] }, { "cell_type": "markdown", "metadata": { "id": "Zmp93XL5XcYF" }, "source": [ "## Utils\n", "\n", "I combined all the handy functions used throughout the notebook into one place, for easier access." 
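, "\n", "\n", "As an illustration of what `LinearWeightNorm` computes, here is a minimal sketch assuming plain PyTorch tensors. `weight_norm_linear` is a hypothetical helper, not part of this notebook. Each output unit's weight row `v` is rescaled to `g * v / ||v||`, so every row has norm `g` (the `weight_scale`). As a result, multiplying the raw weights by a positive constant leaves the layer's output unchanged." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "\n", "def weight_norm_linear(x, v, g, b):\n", "    # Weight normalization: effective weight W = g * v / ||v||,\n", "    # with the norm taken over each output unit's row of v\n", "    W = g * v / v.norm(dim=1, keepdim=True)\n", "    return F.linear(x, W, b)\n", "\n", "torch.manual_seed(0)\n", "v = torch.randn(5, 3)  # direction parameters, shape (out_features, in_features)\n", "g = torch.ones(5, 1)   # per-row scale, like weight_scale=1 in LinearWeightNorm\n", "b = torch.zeros(5)\n", "x = torch.randn(4, 3)\n", "y = weight_norm_linear(x, v, g, b)\n", "\n", "# Every effective weight row has norm g, however large v is\n", "row_norms = (g * v / v.norm(dim=1, keepdim=True)).norm(dim=1)\n", "print(y.shape, row_norms)"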
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "gdwBALC5XcYG" }, "outputs": [], "source": [ "#collapse-hide\n", "class Arguments():\n", "    def __init__(self):\n", "        self.batch_size=100\n", "        self.epochs=10\n", "        self.lr=0.003\n", "        self.momentum=0.5\n", "        self.cuda=torch.cuda.is_available()\n", "        self.seed=1\n", "        self.log_interval=100\n", "        self.save_interval=5\n", "        self.unlabel_weight=1\n", "        self.logdir='./logfile'\n", "        self.savedir='./models'\n", "        self.load_saved=True\n", "\n", "args = Arguments()\n", "np.random.seed(args.seed)\n", "torch.manual_seed(args.seed)\n", "\n", "results = {}\n", "def log_sum_exp(x, axis = 1):\n", "    # Numerically stable log(sum(exp(x))) along `axis`: subtract the max first\n", "    m = torch.max(x, dim = axis)[0]\n", "    return m + torch.log(torch.sum(torch.exp(x - m.unsqueeze(axis)), dim = axis))\n", "\n", "def reset_normal_param(L, stdv, weight_scale = 1.):\n", "    assert type(L) == torch.nn.Linear\n", "    torch.nn.init.normal_(L.weight, std=weight_scale / math.sqrt(L.weight.size()[0]))\n", "\n", "def show_gen_images(gan):\n", "    num_images=4\n", "    arr = gan.draw(num_images)\n", "    square_dim = num_images//2\n", "    f, axarr = plt.subplots(square_dim,square_dim)\n", "    # f.set_figheight(10)\n", "    # f.set_figwidth(10)\n", "    for i in range(square_dim):\n", "        for j in range(square_dim):\n", "            axarr[i,j].imshow(arr[i*square_dim+j],cmap='gray')\n", "            axarr[i,j].axis('off')\n", "\n", "# https://github.com/Sleepychord/ImprovedGAN-pytorch/blob/master/functional.py#L13\n", "class LinearWeightNorm(torch.nn.Module):\n", "    def __init__(self, in_features, out_features, bias=True, weight_scale=None, weight_init_stdv=0.1):\n", "        super(LinearWeightNorm, self).__init__()\n", "        self.in_features = in_features\n", "        self.out_features = out_features\n", "        self.weight = Parameter(torch.randn(out_features, in_features) * weight_init_stdv)\n", "        if bias:\n", "            self.bias = Parameter(torch.zeros(out_features))\n", "        else:\n", "            self.register_parameter('bias', None)\n", "        if weight_scale is not None:\n", "            assert type(weight_scale) == int\n", "            self.weight_scale = Parameter(torch.ones(out_features, 1) * weight_scale)\n", "        else:\n", "            self.weight_scale = 1\n", "    def forward(self, x):\n", "        # Weight normalization: each output row of the weight matrix is rescaled to norm weight_scale\n", "        W = self.weight * self.weight_scale / torch.sqrt(torch.sum(self.weight ** 2, dim = 1, keepdim = True))\n", "        return F.linear(x, W, self.bias)\n", "    def __repr__(self):\n", "        return self.__class__.__name__ + '(' \\\n", "            + 'in_features=' + str(self.in_features) \\\n", "            + ', out_features=' + str(self.out_features) \\\n", "            + ', weight_scale=' + str(self.weight_scale) + ')'" ] }, { "cell_type": "markdown", "metadata": { "id": "ghGmjdNJXcYI" }, "source": [ "## Datasets\n", "\n", "We use MNIST, which has 60,000 training and 10,000 test images. The labelled set contains N (100 or 200) samples, N/10 per class, repeated to match the size of the unlabelled set. The unlabelled set contains all 60,000 training images." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "RZCHKj8-XcYJ", "outputId": "d50ba8d1-1476-4ee9-cee3-ee977f8ecb41" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8f966c3006734a2aa86cf9960a3f3fad", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw\n", "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ec7c7b1b5bdf45ae96b43644309a7e94", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))" ] }, "metadata": {}, 
"output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "da991ed9f2cd44ffa15613ca4eccd0f0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n", "\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "50ea4da8a1a143ad8e5d7f115d8a48ac", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw\n", "Processing...\n", "Done!\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/conda-bld/pytorch_1587428398394/work/torch/csrc/utils/tensor_numpy.cpp:141: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. 
This type of warning will be suppressed for the rest of this program.\n" ] }, { "data": { "text/plain": [ "Dataset MNIST\n", " Number of datapoints: 10000\n", " Root location: ../data\n", " Split: Test\n", " StandardTransform\n", "Transform: Compose(\n", " ToTensor()\n", " )" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#collapse-hide\n", "# Reference https://github.com/Sleepychord/ImprovedGAN-pytorch/blob/master/Datasets.py\n", "def MNISTLabel(class_num):\n", "    # Sample class_num labelled images per digit, then tile them to at least 60,000 samples\n", "    raw_dataset = datasets.MNIST('../data', train=True, download=True,\n", "        transform=transforms.Compose([\n", "            transforms.ToTensor(),\n", "        ]))\n", "    class_tot = [0] * 10\n", "    data = []\n", "    labels = []\n", "    tot = 0\n", "    perm = np.random.permutation(raw_dataset.__len__())\n", "    for i in range(raw_dataset.__len__()):\n", "        datum, label = raw_dataset.__getitem__(perm[i])\n", "        if class_tot[label] < class_num:\n", "            data.append(datum.numpy())\n", "            labels.append(label)\n", "            class_tot[label] += 1\n", "            tot += 1\n", "            if tot >= 10 * class_num:\n", "                break\n", "\n", "    times = int(np.ceil(60_000 / len(data)))\n", "    return TensorDataset(torch.FloatTensor(np.array(data)).repeat(times,1,1,1), torch.LongTensor(np.array(labels)).repeat(times))\n", "\n", "def MNISTUnlabel():\n", "    raw_dataset = datasets.MNIST('../data', train=True, download=True,\n", "        transform=transforms.Compose([\n", "            transforms.ToTensor(),\n", "        ]))\n", "    return raw_dataset\n", "def MNISTTest():\n", "    return datasets.MNIST('../data', train=False, download=True,\n", "        transform=transforms.Compose([\n", "            transforms.ToTensor(),\n", "        ]))\n", "# Reference https://github.com/adambielski/siamese-triplet/blob/master/datasets.py#L79\n", "class MNISTTriplet(Dataset):\n", "    # Returns (anchor, positive, negative) image triplets from a labelled TensorDataset\n", "    def __init__(self, mnist_dataset):\n", "        self.mnist_dataset = mnist_dataset.tensors\n", "        self.train_labels = self.mnist_dataset[1]\n", "        self.train_data = self.mnist_dataset[0]\n", "        self.labels_set = set(self.train_labels.numpy())\n", "        self.label_to_indices = {}\n", "        for label in self.labels_set:\n", "            self.label_to_indices[label] = np.where(self.train_labels.numpy() == label)[0]\n", "\n", "    def __getitem__(self, index):\n", "        img1, label1 = self.train_data[index], self.train_labels[index].item()\n", "        positive_index = index\n", "        while positive_index == index:\n", "            positive_index = np.random.choice(self.label_to_indices[label1])\n", "        negative_label = np.random.choice(list(self.labels_set - set([label1])))\n", "        negative_index = np.random.choice(self.label_to_indices[negative_label])\n", "        img3 = self.train_data[negative_index]\n", "        img2 = self.train_data[positive_index]\n", "        return img1, img2, img3\n", "\n", "    def __len__(self):\n", "        return len(self.mnist_dataset[1])\n", "# Reference https://github.com/adambielski/siamese-triplet/blob/master/losses.py\n", "class TripletLoss(nn.Module):\n", "    def __init__(self):\n", "        super(TripletLoss, self).__init__()\n", "\n", "    def forward(self, anchor, positive, negative, size_average=True):\n", "        # Euclidean distances anchor-positive and anchor-negative\n", "        d_positive = torch.sqrt(torch.sum((anchor - positive).pow(2),axis=1))\n", "        d_negative = torch.sqrt(torch.sum((anchor - negative).pow(2),axis=1))\n", "        z = torch.cat((d_positive.unsqueeze(1),d_negative.unsqueeze(1)),axis=1)\n", "        z = log_sum_exp(z)\n", "        # Mean of -log(exp(d_negative) / (exp(d_positive) + exp(d_negative)))\n", "        return -torch.mean(d_negative) + torch.mean(z)\n", "\n", "MNISTUnlabel()\n", "MNISTTest()" ] }, { "cell_type": "markdown", "metadata": { "id": "fe8FDZbJXcYP" }, "source": [ "## GAN" ] }, { "cell_type": "markdown", "metadata": { "id": "bKsaR_b7XcYQ" }, "source": [ "### Architecture: Generator and Discriminator\n", "\n", "This notebook follows the architecture used in Improved GAN (Salimans et al. 2016). The discriminator outputs M (16 or 32) features. The hyperparameters are the same as in TripletGAN. 
" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "FWEjCPFiXcYQ" }, "outputs": [], "source": [ "#collapse-hide\n", "# https://github.com/Sleepychord/ImprovedGAN-pytorch/blob/master/Nets.py\n", "class Discriminator(nn.Module):\n", "    def __init__(self, input_dim = 28 ** 2, output_dim = 10):\n", "        super(Discriminator, self).__init__()\n", "        self.input_dim = input_dim\n", "        self.output_dim = output_dim\n", "        self.layers = torch.nn.ModuleList([\n", "            LinearWeightNorm(input_dim, 1000),\n", "            LinearWeightNorm(1000, 500),\n", "            LinearWeightNorm(500, 250),\n", "            LinearWeightNorm(250, 250),\n", "            LinearWeightNorm(250, 250)]\n", "        )\n", "        # Final layer outputs the M-dimensional feature vector\n", "        self.final = LinearWeightNorm(250, output_dim, weight_scale=1)\n", "        # Only used for pretraining: maps the M features to a real/fake probability\n", "        self.reduce = nn.Sequential(\n", "            nn.Linear(self.output_dim, 1),\n", "            nn.Sigmoid(),\n", "        )\n", "\n", "    def forward(self, x, feature = False, pretrain=False):\n", "        x = x.view(-1, self.input_dim)\n", "        # Gaussian noise on the input and after every hidden layer (training mode only)\n", "        noise = torch.randn(x.size()) * 0.3 if self.training else torch.Tensor([0])\n", "        if args.cuda:\n", "            noise = noise.cuda()\n", "        x = x + noise\n", "        for i in range(len(self.layers)):\n", "            m = self.layers[i]\n", "            x_f = F.relu(m(x))\n", "            noise = torch.randn(x_f.size()) * 0.5 if self.training else torch.Tensor([0])\n", "            if args.cuda:\n", "                noise = noise.cuda()\n", "            x = (x_f + noise)\n", "        if feature:\n", "            return x_f\n", "        out = self.final(x)\n", "        if pretrain:\n", "            out = self.reduce(out)\n", "        return out\n", "\n", "\n", "class Generator(nn.Module):\n", "    def __init__(self, z_dim, output_dim = 28 * 28):\n", "        super(Generator, self).__init__()\n", "        self.z_dim = z_dim\n", "        self.fc1 = nn.Linear(z_dim, 500, bias = False)\n", "        self.bn1 = nn.BatchNorm1d(500, affine = False, eps=1e-6, momentum = 0.5)\n", "        self.fc2 = nn.Linear(500, 500, bias = False)\n", "        self.bn2 = nn.BatchNorm1d(500, affine = False, eps=1e-6, momentum = 0.5)\n", "        self.fc3 = LinearWeightNorm(500, output_dim, weight_scale = 1)\n", "        self.bn1_b = Parameter(torch.zeros(500))\n", "        self.bn2_b = Parameter(torch.zeros(500))\n", "        nn.init.xavier_uniform_(self.fc1.weight)\n", "        nn.init.xavier_uniform_(self.fc2.weight)\n", "\n", "    def forward(self, batch_size, draw=None):\n", "        # Sample uniform noise z unless a fixed noise batch is passed via draw\n", "        if draw is None:\n", "            x = torch.rand(batch_size, self.z_dim)\n", "        else:\n", "            x = draw\n", "        if args.cuda:\n", "            x = x.cuda()\n", "        x = F.softplus(self.bn1(self.fc1(x)) + self.bn1_b)\n", "        x = F.softplus(self.bn2(self.fc2(x)) + self.bn2_b)\n", "        x = F.softplus(self.fc3(x))\n", "        return x" ] }, { "cell_type": "markdown", "metadata": { "id": "ValQ0-UuXcYU" }, "source": [ "## GAN Class\n", "\n", "To pretrain the GAN, I follow the standard GAN objective (Goodfellow et al., 2014). The `ImprovedGAN` class below holds the generator, the discriminator, their Adam optimizers and the k-NN evaluation code.\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "7SCqOSAdXcYV" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "#collapse-hide\n", "# Reference: https://github.com/Sleepychord/ImprovedGAN-pytorch/blob/master/ImprovedGAN.py\n", "class ImprovedGAN(object):\n", "    def __init__(self, G, D, labeled, unlabeled, test):\n", "        self.G = G\n", "        self.D = D\n", "        # if(args.mode == 'train'):\n", "        g_name = 'G_'+str(D.output_dim)+'_'+str(args.labeled)+'.pkl'\n", "        d_name = 'D_'+str(D.output_dim)+'_'+str(args.labeled)+'.pkl'\n", "        # else:\n", "        g_name_pretrain = 'pretrain' + '_G_'+str(D.output_dim)+'.pkl'\n", "        d_name_pretrain = 'pretrain' + '_D_'+str(D.output_dim)+'.pkl'\n", "\n", "        if args.mode == 'pretrain':\n", "            self.g_path = Path(args.savedir) / g_name_pretrain\n", "            self.d_path = Path(args.savedir) / d_name_pretrain\n", "        else:\n", "            self.g_path = Path(args.savedir) / g_name\n", "            self.d_path = Path(args.savedir) / d_name\n", "\n", "        if os.path.exists(args.savedir) and args.load_saved:\n", "            print('Loading model ' + args.savedir)\n", "            # map_location='cpu' lets GPU-trained checkpoints load on a CPU-only machine\n", "            if False and os.path.exists(self.g_path):\n", "                self.G.load_state_dict(torch.load(self.g_path, map_location='cpu'))\n", "                self.D.load_state_dict(torch.load(self.d_path, map_location='cpu'))\n", "            else:\n", "                print('Loaded 
pretrain')\n", "                self.G.load_state_dict(torch.load(Path(args.savedir) / g_name_pretrain, map_location='cpu'))\n", "                self.D.load_state_dict(torch.load(Path(args.savedir) / d_name_pretrain, map_location='cpu'))\n", "        else:\n", "            print('Creating model')\n", "            if not os.path.exists(args.savedir):\n", "                os.makedirs(args.savedir)\n", "            torch.save(self.G.state_dict(), self.g_path)\n", "            torch.save(self.D.state_dict(), self.d_path)\n", "        # self.writer = tensorboardX.SummaryWriter(log_dir=args.logdir)\n", "        if args.cuda:\n", "            self.G.cuda()\n", "            self.D.cuda()\n", "\n", "        self.Doptim = optim.Adam(self.D.parameters(), lr=args.lr, betas= (args.momentum, 0.999))\n", "        self.Goptim = optim.Adam(self.G.parameters(), lr=args.lr, betas = (args.momentum,0.999))\n", "        self.knn = KNeighborsClassifier(n_neighbors=9)\n", "        self.tripletloss = TripletLoss()\n", "        # Fixed noise batch, so show_gen_images always draws from the same z\n", "        self.drawnoise = torch.rand(4, self.G.z_dim)\n", "        self.labeled = labeled\n", "        self.unlabeled = unlabeled\n", "        self.test = test\n", "\n", "    def get_features(self,dataset):\n", "        # Run the discriminator over the dataset and collect its M-dimensional outputs\n", "        loader = DataLoader(dataset, batch_size = args.batch_size, shuffle=True, drop_last=True, num_workers = 4)\n", "        X = []\n", "        y = []\n", "        for (data,label) in loader:\n", "            if args.cuda:\n", "                data = data.cuda()\n", "            X += self.D(data)\n", "            y += label\n", "            # del data\n", "        del loader,data\n", "        X = torch.stack(X).data.cpu().numpy()\n", "        y = torch.LongTensor(y).data.cpu().numpy()\n", "        # X = torch.stack(X)\n", "        # y = torch.LongTensor(y)\n", "        return X,y\n", "\n", "    def trainknn(self):\n", "        X,y = self.get_features(self.unlabeled)\n", "        self.knn.fit(X,y)\n", "        print(\"Fit done\")\n", "        del X,y\n", "\n", "    def calc_mAP(self,test_features, testy, train_features, trainy):\n", "        # Rank all training features by Euclidean distance to each test feature\n", "        Y = cdist(test_features,train_features)\n", "        ind = np.argsort(Y,axis=1)\n", "        print(\"Done argsort\")\n", "        del Y,train_features\n", "        prec = 0.0\n", "        num_classes = 10\n", "        acc = [0.0] * num_classes\n", "        test_len = len(test_features)\n", "        # print(\"testlen\",test_len)\n", "        for k in range(test_len):\n", "            class_values = trainy[ind[k,:]]\n", "            y_true = (testy[k] == class_values)\n", "            # print(\"ylen\",y_true.shape[0])\n", "            # Closer neighbours get higher scores for average precision\n", "            y_scores = np.arange(y_true.shape[0],0,-1)\n", "            prec += average_precision_score(y_true, y_scores)\n", "\n", "            # Majority vote over the n+1 nearest neighbours (ties split evenly)\n", "            for n in range(num_classes):\n", "                a = class_values[0:(n+1)]\n", "                counts = np.bincount(a)\n", "                b = np.where(counts==np.max(counts))[0]\n", "                if testy[k] in b:\n", "                    acc[n] = acc[n] + (1.0/float(len(b)))\n", "        prec = prec/float(test_len)\n", "        acc= [x / float(test_len) for x in acc]\n", "        del ind,class_values,y_true,y_scores\n", "        return np.mean(acc)*100,prec\n", "\n", "    def evalknn(self,results):\n", "        test_features, testy = self.get_features(self.test)\n", "        train_features, trainy = self.get_features(self.unlabeled)\n", "        accuracy,mAP = self.calc_mAP(test_features, testy, train_features, trainy)\n", "        del test_features,testy,train_features,trainy\n", "        results[args.mode+'_'+str(args.features)+'_'+str(args.labeled)] = [accuracy,mAP]\n", "        return accuracy,mAP\n", "\n", "    def draw(self, batch_size):\n", "        self.G.eval()\n", "        return self.G(batch_size,draw=self.drawnoise).view((batch_size,28,28)).data.cpu().numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "IO_geDVUXcYY" }, "source": [ "### Pretrain GAN\n", "\n", "We'll first pretrain the GAN for 50 epochs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Original GAN objective" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "5LBVmBOHXcYY" }, "outputs": [], "source": [ "#collapse-hide\n", "def pretrain(self):\n", "    # Reference: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py\n", "    plotlosses = PlotLosses(groups={'loss': ['generator', 'discriminator']})\n", "    # Tensor = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor\n", "    bce_loss = torch.nn.BCELoss()\n", "    if args.cuda:\n", "        bce_loss.cuda()\n", "\n", "    dataloader = DataLoader(self.unlabeled, batch_size = args.batch_size, shuffle=True, drop_last=True, 
num_workers = 4)\n", "    for epoch in tq(range(args.epochs)):\n", "        losses = {'discriminator':0,'generator':0}\n", "        for i, (imgs, _) in enumerate(dataloader):\n", "            # Targets for the BCE loss: 1 = real, 0 = fake\n", "            valid = torch.ones((imgs.size(0), 1))\n", "            fake = torch.zeros((imgs.size(0), 1))\n", "\n", "            train_imgs = imgs\n", "\n", "            generated_images = self.G(args.batch_size)\n", "\n", "            if args.cuda:\n", "                valid, fake, train_imgs, generated_images = valid.cuda(), fake.cuda(), train_imgs.cuda(), generated_images.cuda()\n", "\n", "            # Generator step: try to make D label generated images as real\n", "            generator_loss = bce_loss(self.D(generated_images,pretrain=True), valid)\n", "\n", "            self.Goptim.zero_grad()\n", "            generator_loss.backward()\n", "            self.Goptim.step()\n", "\n", "            # Discriminator step: real images -> 1, detached generated images -> 0\n", "            real_loss = bce_loss(self.D(train_imgs,pretrain=True), valid)\n", "            fake_loss = bce_loss(self.D(generated_images.detach(),pretrain=True), fake)\n", "\n", "            discriminator_loss = (fake_loss + real_loss) / 2\n", "\n", "            self.Doptim.zero_grad()\n", "            discriminator_loss.backward()\n", "            self.Doptim.step()\n", "\n", "            losses['generator'] += generator_loss.item()\n", "            losses['discriminator'] += discriminator_loss.item()\n", "\n", "        num_batches = len(self.unlabeled) / args.batch_size\n", "        for key in losses:\n", "            losses[key] /= num_batches\n", "\n", "        plotlosses.update(losses)\n", "        plotlosses.send()\n", "\n", "        if (epoch + 1) % args.save_interval == 0:\n", "            torch.save(self.G.state_dict(), self.g_path)\n", "            torch.save(self.D.state_dict(), self.d_path)\n", "\n", "ImprovedGAN.pretrain = pretrain\n", "\n", "args.load_saved=False\n", "args.epochs=50\n", "args.mode = 'pretrain'\n", "\n", "# m=16 n=100\n", "args.labeled=100\n", "args.features=16\n", "gan = ImprovedGAN(Generator(z_dim=100), Discriminator(output_dim=args.features), MNISTLabel(args.labeled/10), MNISTUnlabel(), MNISTTest())\n", "gan.pretrain()\n", "# gan.trainknn()\n", "print(gan.evalknn(results))\n", "show_gen_images(gan)" ] }, { "cell_type": "markdown", "metadata": { "id": "6xUFQiiRXcYa" }, "source": [ "#### m=16\n", "Here we pretrain a model whose discriminator outputs 16 features for an input image. To visualise samples, I pass a fixed batch of random noise through the generator and plot the resulting images.\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "RxNHZS3cXcYb", "outputId": "6ba1db35-4eb4-44f7-f6ea-d1f41ab48bb9", "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Loss\n", "\tgenerator \t (min: 1.549, max: 2.075, cur: 1.775)\n", "\tdiscriminator \t (min: 0.391, max: 0.429, cur: 0.409)\n", "\n", "Done w cdist\n", "(30.023548015873082, 0.15303857492111891)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAScAAADnCAYAAABcxZBBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAUDUlEQVR4nO3dWYyVRdPA8T6yg+yLyCoIBmVTQBYJi0SWGDFGTMygGSMYUeONN8KFXgB6iShioiJREA0xmKAgKFEB2RTFwADKIgwwyDLsMuzLfBdf0laVzCPMe5Y64/93VU8KhmN8ptLdp7s6VV5eHgDAm5ty/QEA4FooTgBcojgBcIniBMAlihMAl6onJVOpFF/lOVFeXp7K9WeoSni3/ajo3WbkBMAlihMAlyhOAFyiOAFwieIEwCWKEwCXKE4AXKI4AXCJ4gTAJYoTAJcSj694lkr9veP9ppt0je3cuXOMS0pKVK56df2f3L59+xhv375d5WrXrh3jv/76S+Vo0gdkFiMnAC5RnAC4lEqanmT75PZtt92mng8ePBjjli1bqtykSZNivGvXLpU7c+ZMjMeOHatyTZs2Vc9Tp06N8eXLl1Vu6dKlMW7VqpXK7dmzp8K/J6ecV65cCelAV4L0oiuBH3QlAJBXKE4AXKI4AXDJ1VaCY8eOqeeaNWvG+PTp0yrXo0ePGH/99dcq17Nnzxhv3rxZ5U6dOqWejx8/HmO7riTXwIqLi1VOrjNdvXo1AKhYtWrVYny967CMnAC4RHEC4FLWtxLY3dxySlSnTh2Va9GiRYz79++vcnJnt92CMGHChBiPHz9e5bZt26aemzVrFmM7ratfv36MV65cqXJbtmyJ8YULF1QuE7vH2UqQXtneSiC3l4SQ/RMG9t+Xz7lelmArAYC8QnEC4BLFCYBLOT++UqNGjRh36NBB5Vq3bh3j0aNHq9zJkydjPHfuXJU7d+7cNX9+CP9c82rXrl2Mu3fvrnLDhg2L8YEDB1Ru+vTpMd6/f7/K2eMs6cCaU3pl492WHTDk+mUI+qt1+zso39+LFy+qXK1atWJ88803V/gzQwihbt26FeYaNWoUY7tmumPHjhjbd9l+nnRgzQlAXqE4AXAp5zvEk07w33HHHTF+6aWXVG748OExPnv2rMrJXeB2GGqHt0ePHr1mHEIIffr0iXGnTp1UTm5fkFPMaz0jv8npmZ3myFybNm1UbuTIkTHeuHFjhX/Pds54++23Yyy3uoQQwsCBA2P81VdfqVxhYaF6ltNDOY0LIYS1a9fG2P7+NG/ePMa244c8KZHp7RCMnAC4RHEC4BLFCYBLWVlzkutKdp566dKlGMsOASGE0K1btxg/99xzKifnzOfPn1e5pO34ds1APtuOBQsXLozx3XffrXLy+MyhQ4dUrrJrTnKbQ66PFOBv8h1JOoYiL8QIQb8/9ojVrFmzYmzXjuTalVxbDUEfsbK/L5Z8h+677z6Vk+/a6tWrVa5v374xtp1jd+/enfhvphMjJwAuUZwAuJSVaV3SV47yK1W5ozWEEJYsWRJj+3WnnA6mawp05MgR9SyH4var2N9++y3G9m68ymIq55OcAtlp3eDBg2N84sQJlZs4cWK
M7QmD7777LsZ2yUBuAfj9999V7vnnn4/xk08+qXKzZ89Wz/LffPnll1Xu2WefjbE9RfHJJ5/E2DZ5zCZGTgBcojgBcIniBMClrKw5JX1FLi8xaNiwocp17do1xl9++aXKyW39hw8fVrkbWbuRn8124mzcuHGM5VGaEPQln99++22FP5N1pPyX9P9w3bp1MbbrksuXL4+x7OoaQggFBQUxXrBggcr169cvxvbSjzFjxsTYrgeVlpaqZ9lxQ77LIYTQpUuXGNvtAbLLhj2+kk2MnAC4RHEC4FJWpnVyWGy/ij1z5kyM//zzT5WTTa/sLnDZbUBuRwhBbzOwzeVsVwI5rZSnyO3nsbvH5de9dqtEgwYNYkyHgqpNNmqzHQSSpnWykeLSpUtVTjZPtEsWSY0M7bstt8bMmTNH5b7//vsY220O8q5HO6VNOu2RboycALhEcQLgEsUJgEtZ74Rp56lyTcg2Wi8qKopx0hZ/uW4Vgt6SYL8KtXN/ebLbrmt17NgxxrfffrvKvfXWWzG2XyHbYzCouuT7LI80haA7StqjWZ9++mmM7bqOXAO6kXUd+2flv7l+/XqVGzFiRIxtB1i5Tmq702YTIycALlGcALiU8wsO5FDUTuvkKe89e/aonPy6fty4cSonuwR8/vnnKmef5b+/aNEilZPD60GDBqncG2+8EeMXX3xR5TJxbx38s9MzOSWynSvkaQS7C7yyX9HbbTMvvPBCjO0JB7m7fM2aNSontyQwrQMAg+IEwCWKEwCXEtec7Hb4TMw/5fzaHm2R8+I//vhD5eSparmtIAT9Far8yjaEEDZs2KCeV61aFWN7cabcWmAvRZQnwOX6Vwh0IsD/k0ee6tWrp3JlZWUxruzvlf19sZd6ygswJ02apHLvvPNOjO1FCfYi2lxh5ATAJYoTAJcoTgBcSlxzyvYeB7u/I2kP1Pbt22M8ZcoUlZMXD8q5fQj/vCFD/lzbWVDO4fv06aNyn332WYztHhJ5ZMb+TLkXhf1QVYtdo5X/r+3tQUmS2pLInynb/YQQwj333KOe5bEqedtLCCH8+OOPMbYtU7xg5ATAJYoTAJdSSVvlU6nUde+jl8PNpOmZJbtY2mmOHN7abpfyIkB5+jsEfapadsUM4Z9fk9aqVeua/14IenrYtm1blZPdDeRlByGE0L179xiPHz9e5Sp7NKG8vDz1738K1+tG3u3Kku+ofbfl+2wvtZR/1nYzkNtUbKeOnj17qmd55GrlypUqJ5+3bt2qcpnucGlV9G4zcgLgEsUJgEsUJwAupa1lipyn2tYNck3G3kYi59B27i3Xi+yREHkcwHaelFsg7Netlszfe++9Kie3IdgjKrLrYZMmTVROtlOx/01ejgYg8673FiC71tmtW7cY22Nb9evXj7G9BNb+2TvvvDPG9lJNeZFmtteYrhcjJwAuUZwAuJSRaZ0cloagO/117dq1wp9hp2BJO63lpQZ2WCyHzHIYHIK+/CAE3YnA7pQtLCys8N+XO9R/+eUXlZOXIdivaeVXwzeyaxj5zU6r5EkFu7N806ZNMbZbCeSJBjtVbN++vXqWyykzZsxQuUycTpCfJx2dORg5AXCJ4gTAJYoTAJfStuaUtM4jO1WOHDlS5eQa0F133aVyH3/8cYztxX+SXQ8aMGBAjG13QPtztm3bFmO7HibXCd58802Vk2sB9mYY221TkutMdq3M61e6+N/ZG1Ykuz4j3wu7Likvem3Xrp3KDRkyRD1/+OGHMT58+LDK2WNd6ZDuDrCMnAC4RHEC4FKlp3UFBQXqef78+TEuKipSOTnl++GHH1Ru3rx5MbbdBZYuXRpj2+xNTgftFgA5PbS7t19//XX1LLsN2CH05MmTYyy7EIQQws6dOyv8e3I3uWxybzGNw7+xJwzkFpoePXqo3K+//qqe5Xtou3rkw7vHyAmASxQnAC5RnAC4lLZOmJLdji+fZdP1EELo0KFDhTm5XiO7S4agT/fb4zKLFi2KsTx
KEkIIa9euVc9ya4Hd0i+3QNic/No0S5eP0gkzjbLRCbOy5DGQW2+9VeWaNm0a42nTpqnc5s2b1fNHH30UY9lFIwRfl2vQCRNAXqE4AXApbTvEJTutkc/yq9AQ9Kl9e8paPssOBSHoadWSJUtUTja0W7NmjcolTWNv5GIGIF3se2+/9pfkqYWSkhKVs3cyylMM6d69nQ2MnAC4RHEC4BLFCYBLGVlzSiK/ng8hP+fCVia2DuC/w65tyote7VaYXr16xVgeoQrhn9059u/fH2PbASMfMHIC4BLFCYBLWZ/WVYVpHJBOSScMSktLVa5fv34xtk0dFy9erJ5lg7t83BbDyAmASxQnAC5RnAC4lJGuBEg/uhKkl6d3237NLztwDB06VOWKi4tjLLcchPDPjhty205ZWZnKya4EuV6PoisBgLxCcQLgEtO6PMG0Lr14t/1gWgcgr1CcALhEcQLgUuKaEwDkCiMnAC5RnAC4RHEC4BLFCYBLFCcALlGcALhEcQLgEsUJgEsUJwAuUZwAuERxAuASxQmASxQnAC5RnAC4RHEC4BLFCYBLFCcALlGcALhEcQLgUvWkJHd7+cG9denFu+0H99YByCsUJwAuUZwAuERxAuASxQmASxQnAC5RnAC4RHEC4BLFCYBLFCcALlGcALhEcQLgEsUJgEsUJwAuUZwAuERxAuASxQmAS4mdMLMtldIN8crLK25WWKtWrRhfuHAhY58JQG4wcgLgEsUJgEtZn9bZqVv16n9/hCZNmqjcuXPnrhmHEEKbNm1i3LBhQ5WbOnVqjGfOnKlyq1atUs/nz5+P8eXLlxM/O5Br1apVi/HVq1cT/6z8XZN/z+bse5/0c+Xfs8suSbnKYOQEwCWKEwCXKE4AXEolzQ2zcfFgnTp1YjxmzBiVGzp0aIxLS0tV7qefforx5s2bVW7IkCExHjRokModPHhQPe/duzfG8+fPV7nTp0/HOB1z6P8Fl2qml+dLNeXajV2jrV+/foz79u2rcjVr1lTP9erVi7FdR5Lvtlz3DSGELVu2xPjSpUsqJ38P7e/ElStXQmVwqSaAvEJxAuBS1qd1dugph6Y7d+5UuWHDhsV49OjRKjdu3LgY161bV+XkMPXUqVMq1717d/X86quvxnjbtm0qN23atBjbaWW2Ma1LL0/TuqZNm6pnOZUbMGCAyj344IMxXr9+vcqNGjWqwn/j8OHD6rmkpCTGJ06cULlDhw7F+MiRIyp34MCBa/6MEEK46aa/xzr/ts1BYloHIK9QnAC4RHEC4FLW15zsNno5N7Vzb/m8b98+lbPHWSpLHpmR2xNCCGHu3Lkxfu2111Qu21sLWHNKr1yvOcmv+Tt06KByzZs3j3FBQYHKFRUVxXjFihUqZ4+hHD16NMatWrVSuQYNGsT47NmzKte7d+8Y2604zzzzTIwvXrwY0oE1JwB5heIEwKWc7xCXHQXs8FLuOL2RryYrq0+fPup52bJlMbbDYtnNIBuY1qVXrqd1csnilVdeUTk5HVu+fLnKya/v7TaZpGmWnEaGEML9998f4z179qjc7t27K/yZZWVlMb6R5pBJmNYByCsUJwAuUZwAuJTzCw4yMYetLHs6Wz7Xrl1b5bK95oT8ZrfQDBw4MMb2aIk8YrV69WqVk+9djRo1VM6+o3KbjLwQJIQQOnfuHOOTJ09W+LntOrCU6d9PRk4AXKI4AXAp69M6eXI5hOQGVXaal2kTJ05Uz48++miMbdMt4EbYKVDjxo1jvHXrVpWT07MnnnhC5b755psYP/XUUypnO2ds2LAhxo899pjKTZ48Ocbbt29XOdmIzm7hSfclBkkYOQFwieIEwCWKEwCXcn585XrXlTI1v5XbBezxlREjRsT4gw8+UDnZETAbOL6SXtk+vmLfc9kVoLCwUOUGDx4cY7tlZezYsTG2F3LYCw9mz54dY9txQ3a4tB1gs325LMdXAOQVihMAl3I+rZPsLlq57cAONSs7zbPD6wceeCD
Gt9xyi8q1bds2xjNnzlQ5ubM9GzvZmdalV667Esh3W07xQtC/B61bt1a5Y8eOxVhO/0IIoWXLlupZXgIyYcIElVu3bl2Mvd7JyMgJgEsUJwAuUZwAuJTzrgSSPTndqFGjGMs1nhD0Fvt/WTdTz7J5ewghFBcXx9geX5Ff1dpt/LmepyO/yfcpqSvA8ePH1bN87xYuXKhy7777rnpesGBBjG23y3x4fxk5AXCJ4gTAJYoTAJdyvuYk93TYtiRy34bdxi/Xkux82l5SKNmWLW3atImxvenCPgPZlrQ21KJFC/X88MMPq+enn346xrbbZj5g5ATAJYoTAJdyPq2TnTBtV8xNmzbF2G7xl0PYLVu2qNyUKVNi3KtXL5WbMWOGen7kkUdiPH36dJWTlxaeOXPm2v8BQBbVrFkzxosXL1Y529Fy48aNMU7qOOsVIycALlGcALhEcQLgUs7XnJLILf72cj85n5atTULQW/7tkRj71WydOnVi/NBDD6nce++9d4OfGEgve/xq1KhRMe7WrZvKDRo0SD0fPHgwcx8sCxg5AXCJ4gTAJVfTOtsJU379eeHCBZWT2wySnDt3Tj23atVKPcs74+fNm6dymTi53b59+xjv3bs37T8fVYudqskLDmbNmqVyJSUl6vnixYuZ+2BZwMgJgEsUJwAuUZwAuORqzSkTW+ztV7Fy60AIIXzxxRcxtl0RMnG5oFxnsp8tH7oTIvPkERW7btS8efMY79q1S+VKS0vVc76/T4ycALhEcQLgkqtpXSbYoa3sNBCC3mVrd5Pb5vLplu/D7qrKvgd2G8v1ktP2pCl89er611B20ujSpYvKyS00thmibciY7xg5AXCJ4gTAJYoTAJeq/JqTZeflTZo0ibH9alauE7A+9N9R2TWmJPZiDfk+2Qs55MUehYWFKjd58uQYb9iwocKfWRUwcgLgEsUJgEv/uWmd/dq2f//+MbZTPtmsy+4eB6yk0wh2yiWnebbzgNzCMmfOHJUrKiqKsW3AWNUwcgLgEsUJgEsUJwAuVfk1JzvXl50oQ9DrBI8//rjKLVu2LHMfDFWO7eQqOwo0a9ZM5eSakz0mJd/RFStWqFxZWVmM5QUgVREjJwAuUZwAuFTlp3VJQ+0QQti3b1+M33//fZXLx/vlkTu2OWGnTp1ibO+Q6927d4ynTZumcgUFBTEuLi5Wuaq2CzwJIycALlGcALhEcQLgUippDptKparcBNeeDq9bt26MO3bsqHI7duyIca67DJaXl6f+/U/hemX73bZrn/ISA/kOhhDC8OHDY/zzzz+rnO2cURVU9G4zcgLgEsUJgEv/uWldvmJal168234wrQOQVyhOAFyiOAFwKXHNCQByhZETAJcoTgBcojgBcIniBMAlihMAlyhOAFz6PwuSU3vim/+7AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "#collapse-hide\n", "\n", "args.load_saved=True\n", "args.epochs=50\n", "args.mode = 'pretrain'\n", "\n", "# m=16 n=100\n", "args.labeled=100\n", "args.features=16\n", "gan = ImprovedGAN(Generator(z_dim=100), Discriminator(output_dim=args.features), MNISTLabel(args.labeled/10), MNISTUnlabel(), MNISTTest())\n", "gan.pretrain()\n", "# gan.trainknn()\n", "print(gan.evalknn(results))\n", "show_gen_images(gan)" ] }, { "cell_type": "markdown", "metadata": { "id": "hMMrj6GwXcYd" }, "source": [ "#### m=32" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "a1nd5XpxXcYe", "outputId": "df5d9313-df4d-4dbc-aab9-a3aa33a621fd", "scrolled": false }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Loss\n", "\tgenerator \t (min: 1.724, max: 2.207, cur: 1.793)\n", "\tdiscriminator \t (min: 0.373, max: 0.414, cur: 0.406)\n", "\n", "Done w cdist\n", "(44.4227202380953, 0.1734273342974863)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAScAAADnCAYAAABcxZBBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAUT0lEQVR4nO3daaiV1ffA8XWcU8t5nnNIi7RCmzWVLMrUJKmgCdMiX4QRFBVSYWRBUWRF8UOECCKiwbJBcqo007Qszcwh5zG9zmOD9/9u/9daeY733s49Zx39fl6th1X3HOE5i733s569M+Xl5QIA0dQo9hcAgJOhOAEIieIEICSKE4CQKE4AQqqVK5nJZHiUF0R5eXmm2N/hdMK9HUe2e5uRE4CQKE4AQqI4AQgp55oTgLgyGbtUc7q97cHICUBIFCcAITGtA0pUqU7j/HQ0G0ZOAEKiOAEIieIEIKQzbs2pZs2a5vqff/4p0jcBYurbt2+KlyxZkve/X9G1MkZOAEKiOAEIKZNriFXsN7dr166d4r/++svkatT4/7ravn17k5s7d26KJ06caHLvvPOOufZ/Nyp2JcivYt/bFaXvcxGRxo0bp/jEiRMm9+eff5prfW/XqVPH5Jo1a5birVu3mpxe+vB/szqwKwGAkkJxAhASxQlASAVvJcj1JrWfX+vru+++2+TmzJmT4tdee83kFi1alOK1a9eanJ5ri4js3LnzpN8FqE763m7Tpo3JHTp0KMWXXHKJyTVt2jTFbdu2Nbn77rvPXG/ZsiXFx48fN7nx48dn/fw9e/ak2LfaFLL1hpETgJAoTgBCKkgrQY8ePVK8evVq/xkpbtiwock1atQo69984YUXUuxbCZo3b57i22+/3eQ2bNhgrvUQOnK3OK0E+eXvbX0fVmZ6n6vdRed0LCLSu3fvFLdo0cLkJk+enGLf+qLNnDnTXO/du9dcN2jQIMVDhw41ud27d6d427ZtJqendV9//bXJ6X9jvpZBaCUAUFIoTgBCojgBCKkga076san/vFy74tWvXz/F/jHp77//nmK/VrRq1aoUHzx40OSOHTtmrnV7/tGjR7N+l2JjzSm/8nVv57p/mzRpkjU3adKkFO/atcvkPvvssxT79aAdO3ak+O+//zY5/9vSvx+/rjVgwIAUt2rVyuRGjx6d4rFjx5rcggULJN9YcwJQUihOAEIqSIe4f3ta00NRP0TWb1L7FoB169alWHfCitipWuvWrbP+TRH7SHfGjBkmF7m1ADHkWhY5cOBAivv3729yeiq3YsUKk9P3up/yVeae1EsY+ruIiAwZMuSknyci0q5du6yfp5docv2u84GRE4CQKE4AQqI4AQgp1AEH9erVy3rdp08fkysrK0uxbsUXsfNk30rwyCOPmOvXX389xf369TO5hQsXVuRrAyLy7zXTrl27pvimm24yuS+++CLFHTp0MDn9ikhl1nVq1bI/Z/3KjP69iIh8+umnKb7oootM7sILL0zx4cOHTa6i38f/ln0LT0UwcgIQEsUJQEihpnV+QyztqaeeMtd6yuWH03oTrv3795ucf6Q6bty4FL/99tsmp9sOCrHRO04veqND3xKgDyoYPHiwyU2fPj
3FustbxLYu+GncVVddZa6fffbZFC9evNjkOnXqlOKBAweanN6pIxf/+bpjvSrTOI+RE4CQKE4AQqI4AQgp1JqTfxVA7+z3/vvvm9zw4cNTvH79epPTuwv49gC9O6CIfXXA72TI6yuoDH//6p1dZ82aZXL6UAF/UMHUqVNT/O2335qcfrTvH9c//vjj5vqTTz45aey/z5EjR0yuove93xUh3xg5AQiJ4gQgpIJsNldV+lFlt27dTE6f37VmzRqT69KlS4q7d+9ucn4orDte/WZzeipZ3UPYU2Gzufwq9L3td8PQ07pcSw9LliwxuVtuuSXFvr1F3/citjVm+fLlJudbbIqJzeYAlBSKE4CQKE4AQgrVSuDpR5p+zq7Pen/uuedMrnPnzilu2bKlydWtW9dc79u3L8XfffedyelHwf7tbN92AOTi75etW7em2K8d9ezZM8X+sb4+/ODmm282OX+o5h9//JFi3y5QChg5AQiJ4gQgpNDTOr2Zuh6iitgz5PVQV0TkuuuuS/GyZctMzm+stX379hT7wxB69OiR4qVLl5rc5s2bc353QPMtO/rab5b466+/pvjhhx82Ob15YrNmzUzO726gz2/0uyLov1PsNplsGDkBCIniBCAkihOAkEKvOelXS2rWrGly559/for9jgW6PcC36a9evdpc6xaFkSNHmtw111yTYv8oWJ9ZT1vB6cXvrJrrFa+q0ve2f6VKrx1df/31Jvfuu++m2K+D6t0MRESGDRuW4hEjRpjcV199lWK/fhplNw5GTgBCojgBCIniBCCk0GtO+lDASy+91OR0+78/WUL3dJzq1BR9SsTnn39ucmeffXaKb7vtNpPTc3bWnE4v1bHG5OkePr19iojtt9uzZ4/Jvffeeyn2vX/elVdemWJ/Mkvv3r1TrA+WFRH5/fffc/7dQmHkBCAkihOAkEJP6/Tj1nvvvdfkdOv+oEGDTK6qB2D6Azf11M23MgD/hW5X0G0xInY3DL+coNtk/GsneqooYlsLpkyZYnJ33XVXin/44QeT27hxY9bPKCRGTgBCojgBCIniBCCk0GtOer7rH5tee+21Ke7QoYPJrVu3LsV63epU/CPkbdu2pXjLli0m17Vr1xTrrSlEaC3Aqel7rUmTJib36KOPpnj69Okmp7dX8ferf+1Et8n8/PPPJperTWfhwoUpLmZbASMnACFRnACEFHpap02YMMFcHzp0KMX67WsRkVdffTXFlZnWecePH0+x30lQ72bg32IHTkXvROB3YP3iiy9SXFZWZnKV6V7X7S9z5841uU6dOqXY78Spp3zFxMgJQEgUJwAhUZwAhFQya07+7exffvklxf7ATd/GX1F+7UhfHz16NGsu6ukVqBp/8Kpee8wX3W7SqFEjk9NrTpXZldLfv3rtyLcrrF+/PsW+zUC/vlJMjJwAhERxAhBSyUzr/NBaP4odMmSIyem3rPUBhSL/np7pVgO9uZyIbVFo3769yem3w/2wGKWtOqZxnl4K8G8YjBo1KsX60FcRe2CHfxPBH5Rw7rnnprhWLftTP+ecc1Ks34QQEalfv36K/e+lkBg5AQiJ4gQgJIoTgJAyudrha9SoYZKF2Pi9ovSjUX0IoYjdML5fv34mp3csELG7aN56660mp9eZ9BqTiH3j28/Zq0N5eTnvyORRgwYNzM185MiRYn0VadWqlbl+8sknU7xo0SKTW7FiRYr9Dpq+BeKee+5J8Z133mly+t/r1670q2GFaJPJdm8zcgIQEsUJQEg5p3WZTCbOPC4HfdiBiMjll1+e4qFDh5rcggULzLXuqtXDWRGRZcuWpXjTpk0mV+gN5ZjW5Veke9tPz/SjfH/Age769odu+JaWtWvXnjQWEdm7d2+KDx8+XMlvnF9M6wCUFIoTgJAoTgBCqvKak38DOlKbgeZ3LGjevLm5Pnjw4EnjaFhzyq9i39v6M/zakV5X6tWrl8npg1/9K1V+zVSvM+nXXrxi/3ZZcwJQUi
hOAEI6LVoJCsEPvSuzCVg+MK3LL+7tOJjWASgpFCcAIVGcAIRUMjthFluh15iAMx0jJwAhUZwAhJRzWufPf9OHAQCl7KyzzjLXVd3Iv2HDhin2HdqlIurvnJETgJAoTgBCojgBCCnn6ysAUCyMnACERHECEBLFCUBIFCcAIVGcAIREcQIQEsUJQEgUJwAhUZwAhERxAhASxQlASBQnACFRnACERHECEBLFCUBIFCcAIVGcAIREcQIQEsUJQEg5z63LZDJn9AbjmUwmxcXea728vDxz6v8KFXWm39uRZLu3GTkBCIniBCAkihOAkHKuOZ3pir3OBJzJGDkBCIniBCAkihOAkChOAEKiOAEIieIEICRaCQAYNWvWTPE///xTtO/ByAlASBQnACExrQNgFHMqpzFyAhASxQlASBQnACGx5gQUmd5x1atVy/5Ezz777BR369bN5B544AFzfezYsRTv2LHD5N54442sn7lv374Kfbc///wzay4fGDkBCIniBCCkkp3W1a9fP8WtW7c2uT59+qRYD1FFRI4cOWKud+/eneLatWub3Pr161N8/Pjxqn9ZwNHTs7Zt25rc/fffn+KDBw+a3A033JDiadOmmdyGDRvM9Q8//JDiCRMmmNyCBQtSvGXLFpPTHeL16tUzuYYNG6Z4586dJrdnz56T/g2Rqk0BGTkBCIniBCAkihOAkDK5NvGPdPBgy5YtzXXPnj1T/PLLL5vcjBkzUrxkyRKTW7VqlbkeOnRoirdt25b175SVlVXyG+cXh2rmVyHubd0GcOLECZN78cUXU6zXT0VE5s+fn+KLLrrI5N58880U6zUeEZFDhw5l/Xy/Zqp/9zVq2DGK/j5+HVb/Hb9+q/l/by4cqgmgpFCcAIRUMtM6P7zt1atXiv/66y+T049UdTuAyL+Hok888USK/VB03rx5KZ49e7bJFfpMO6Z1+VWIe1t3Vzdu3NjkOnbsmGJ9L4vYpYjNmzebnH4kn6970HeBt2rVKsXt2rUzub///jvFK1euzPrdKoNpHYCSQnECEBLFCUBIoV9f0Y9C9dqQiMh3332X4rlz55rcL7/8kmI9Rxb592PbKVOmpPjDDz80Ob3O5OflhV5zQunRj+G7du1qcvqVkebNm5ucfuXKr6dWxzqT/03otdf//e9/Jvfcc8+lWP/OvA4dOphrv3ZWEYycAIREcQIQUqhpne9U7dy5c4oHDRpkcg899FCK//jjD5PzUznNd8rqt6w//vhjk9Nd6UzjcCp+6q+XJYYMGWJyy5cvT/Fll11mcrqFpTKd1rnUqVPHXN94440pXrduncndc889Kda7J4iIbNq0KevfPHr0aIqrMo3zGDkBCIniBCAkihOAkEKtOZ1zzjnmWs/TDxw4YHL6NZRca0yeX9e6+uqrU7xr1y6TW7p0aYr9joB6fg2cjH6c3qZNG5PT7QJ6V0oRkUaNGqV47969JqfXoPw6qL9HzzvvvBTrdSQRkeHDh6fY78Yxbty4FE+cONHk9u/fL4XCyAlASBQnACEVfVqnH7/6TtlmzZqleNiwYSZX0eGlf7zrh9dXXHFFinXrgojI559/nuIo58cjLj/N0o/Tn3/+eZPTUz7/uF5P6/wuGhdeeGGK165da3IDBw4013rJYsCAASY3adKkFL///vsmV8ipWy6MnACERHECEBLFCUBIRV9z0m9ujxkzxuR++umnrP+fbgnw60E616JFC5PTj1BF7FvfixYtMjndouAf01b3OfEoPX59U98//n757bffUtypUyeT0y00a9asMTm9PuVfifG7xU6ePDnF+kAFEZHt27enOOp6KiMnACFRnACEVPRpne54nTVrlskNHjw4xT///LPJ6WmW38hL7zwwfvx4k9PdryL2Ee/q1atNTr9V7s8EAzzfSqCncv6MubZt26bY76qhO8afeeYZk9PLIF9++WXWzxOxh3voaZxI3KmcxsgJQEgUJwAhUZwAhFTlNSe9HiNSuZ0Bsv
EbretN4PVrJiJ2B4NRo0aZnG4PKCsrM7mRI0eaa/06gN8VQa8zsRMm/gu/o+XOnTtT3Lp1a5PT60H6gFgRkffeey/Fet1KROSxxx4z13oNqmbNmlk/IypGTgBCojgBCIniBCCkTK61lEwmU+0LLfpVE79L5ejRo1OsezZERO68884Ur1ixwuSWLVuWYn8KRPv27c31iBEjUvzqq6+anN6SIh9rav9FeXl55tT/FSqqEPd2Lu3atTtpLCJywQUXpHjHjh0mt2TJkhQ3bdrU5PwBnH379k2x/40cPHgwxfpElWLIdm8zcgIQEsUJQEhFn9ZpujVfxD6+920GdevWTXGuTeD9bgL+dZbvv/8+xX4K6F9nKSamdflV6Hu7cePG5nrx4sUp9q9mvfbaayn2hx/oFgC/C4L/jAYNGqT4gw8+MDl9b48dO9bk/O6b1Y1pHYCSQnECEBLFCUBIRd8yRfOP6/Wakz9Us6L8WpU+0UXEPsb95ptvqvQZwMnodVG/5tOtW7cU33XXXSb3448/prgyO67u3r3bXOvfjF7jEhFp2bJlivUrXCK27aCYWwUxcgIQEsUJQEihpnXV8eZ/q1atzLU+zFDEdp4Xuwscp5fu3bun2B9qqQ8cWLp0qcnl6/AM/Xfeeustk+vXr1+K77jjDpPTOx/Mnz8/L9+lKhg5AQiJ4gQgJIoTgJCqvObkW+cj7RSpd+nU834RkTZt2pjrjz76KMWR/g0oHt9+UtXXOV555ZUUP/300yb30ksvpVifFlRdVq5caa47d+6c4o0bN5qcPvCzmBg5AQiJ4gQgpCpP6yJPgfTuBsOGDTM53ym7a9eugnwnlI6qTuP8UseDDz6YYr+RYq7/rxC/rauuuirF8+bNM7mLL744xTNnzqz275INIycAIVGcAIREcQIQUqjXV/JFHyDYsWNHk9u6dau51rsF7tu3r3q/GMLIV7uA5teK9AEZY8aMMTl9r/lDDPRul5VZf/JrV126dMn6306bNi3F/fv3N7nJkydX+DOrEyMnACFRnACEdFpM6/xwVm+k1bx5c5ObPXu2uS7mZloonkJs4q93BZg6darJ3X777SkePHiwyb3wwgsn/RsiImVlZSn2Uz7frrBnz54U+9+I/rtz5swxuShtQoycAIREcQIQEsUJQEinxZqT3oVAxD4mXrhwock1adLEXOvHtkB18TsPbNq0KcV+FwD9+ohePxUR+fjjj1PcqFEjkzvrrLPMtW5X2LBhg8kdPXo0xX6tKspvgpETgJAoTgBCyuR6bFjo8+Sryg9L9bSuZ8+eJqfP5BKxw9vIsp0nj6oplXu7Mqq6u0GxN47Mdm8zcgIQEsUJQEgUJwAhnRZrTp6fQ+cSpVX/VFhzyq9SvbdPR6w5ASgpFCcAIZ0WHeJeqUzVUDz5enyu/w73XX4xcgIQEsUJQEgUJwAh5WwlAIBiYeQEICSKE4CQKE4AQqI4AQiJ4gQgJIoTgJD+D5c7MoABbAbVAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "#collapse-hide\n", "# m=32 n=100\n", "args.features=32\n", "gan = ImprovedGAN(Generator(z_dim=100), Discriminator(output_dim=args.features), MNISTLabel(args.labeled/10), MNISTUnlabel(), MNISTTest())\n", "gan.pretrain()\n", "# gan.trainknn()\n", "print(gan.evalknn(results))\n", "show_gen_images(gan)" ] }, { "cell_type": "markdown", "metadata": { "id": "c36f8R-VXcYi" }, "source": [ "## Main Training" ] }, { "cell_type": "markdown", "metadata": { "id": "wH8z24czXcYi" }, "source": [ "#### Improved GAN Objective + Triplet loss\n", "\n", "To train the Improved GAN with triplet loss, I follow the Improved GAN objective (Salimans et al., 2016), except that the supervised term of the discriminator loss is a triplet loss on the discriminator's M-dimensional outputs instead of a classification loss. The unsupervised discriminator loss and the feature-matching generator loss are unchanged." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "eBc-UErZXcYj", "outputId": "235512a0-56e0-421f-a68e-b5db4e99d756" }, "outputs": [], "source": [ "#collapse-hide\n", "\n", "# Reference: https://github.com/Sleepychord/ImprovedGAN-pytorch/blob/master/ImprovedGAN.py\n", "def train_Discriminator(self, x1, x2, x3, x_unlabel):\n", "    output_unlabel, output_fake = self.D(x_unlabel), self.D(self.G(x_unlabel.size()[0]).view(x_unlabel.size()).detach())\n", "    # supervised term: triplet loss on the discriminator outputs of the triplet (x1, x2, x3)\n", "    loss_supervised = self.tripletloss(self.D(x1), self.D(x2), self.D(x3))\n", "    # unsupervised term of the Improved GAN objective (Salimans et al., 2016)\n", "    logz_unlabel, logz_fake = log_sum_exp(output_unlabel), log_sum_exp(output_fake)\n", "    loss_unsupervised = 0.5 * torch.mean(F.softplus(logz_fake)) + 0.5 * -torch.mean(logz_unlabel) + 0.5 * torch.mean(F.softplus(logz_unlabel))\n", "    loss = args.unlabel_weight * loss_unsupervised + loss_supervised\n", "    self.Doptim.zero_grad()\n", "    loss.backward()\n", "    self.Doptim.step()\n", "    return loss_supervised.item(), loss_unsupervised.item()\n", "\n", "def train_Generator(self, x_unlabel):\n", "    # feature matching: match the mean discriminator features of generated and real batches\n", "    fake = self.G(args.batch_size).view(x_unlabel.size())\n", "    mom_gen = self.D(fake, feature=True).mean(dim=0)\n", "    mom_unlabel = self.D(x_unlabel, feature=True).mean(dim=0)\n", "    loss_feature_matching = torch.mean((mom_gen - mom_unlabel).pow(2))\n", "    self.Goptim.zero_grad()\n", "    self.Doptim.zero_grad()\n", "    loss_feature_matching.backward()\n", "    self.Goptim.step()\n", "    return loss_feature_matching.item()\n", "\n", "def train(self):\n", "    plotlosses = PlotLosses(groups={'loss': ['supervised', 'unsupervised', 'generator']})\n", "    for epoch in tq(range(args.epochs)):\n", "        self.G.train()\n", "        self.D.train()\n", "        unlabel_loader1 = DataLoader(self.unlabeled, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=4)\n", "        unlabel_loader2 = iter(DataLoader(self.unlabeled, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=4))\n", "        label_loader = iter(DataLoader(self.labeled, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=4))\n", "\n", "        losses = {'supervised': 0, 'unsupervised': 0, 'generator': 0}\n", "        for (unlabel1, _) in unlabel_loader1:\n", "            unlabel2, _ = next(unlabel_loader2)\n", "            x1, x2, x3 = next(label_loader)\n", "            if args.cuda:\n", "                x1, x2, x3, unlabel1, unlabel2 = x1.cuda(), x2.cuda(), x3.cuda(), unlabel1.cuda(), unlabel2.cuda()\n", "\n", "            l_supervised, l_unsupervised = self.train_Discriminator(x1, x2, x3, unlabel1)\n", "            losses['unsupervised'] += l_unsupervised\n", "            losses['supervised'] += l_supervised\n", "\n", "            generator_loss = self.train_Generator(unlabel2)\n", "            # after the first two epochs, take an extra generator step while its loss is high\n", "            if epoch > 1 and generator_loss > 1:\n", "                generator_loss = self.train_Generator(unlabel2)\n", "            losses['generator'] += generator_loss\n", "\n", "        # average each loss over the number of batches in the epoch\n", "        batch_num = len(self.unlabeled) // args.batch_size\n", "        for key in losses:\n", "            losses[key] /= batch_num\n", "\n", "        plotlosses.update(losses)\n", "        plotlosses.send()\n", "        if (epoch + 1) % args.save_interval == 0:\n", "            torch.save(self.D.state_dict(), self.d_path)\n", "            torch.save(self.G.state_dict(), self.g_path)\n", "\n", 
"ImprovedGAN.train = train\n", "ImprovedGAN.train_Discriminator = train_Discriminator\n", "ImprovedGAN.train_Generator = train_Generator" ] }, { "cell_type": "markdown", "metadata": { "id": "Aq-wRoLEXcYl" }, "source": [ "#### Triplet Loss and M=16 N=100\n", "\n", "Here M is the number of features output by the discriminator (`args.features`) and N is the number of labeled samples used for supervised triplet training (`args.labeled`); the rest of the samples are used for unsupervised training only." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "PTajOMULXcYl", "outputId": "b1cdd47e-1a4b-41b0-ecd6-207b420d90b8" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Loss\n", "\tsupervised \t (min: 0.000, max: 0.079, cur: 0.000)\n", "\tunsupervised \t (min: 0.351, max: 0.453, cur: 0.422)\n", "\tgenerator \t (min: 0.511, max: 1.353, cur: 0.511)\n", "\n", "Done argsort\n", "(96.17309999999998, 0.9155653700006265)\n", "CPU times: user 53min, sys: 2min 10s, total: 55min 11s\n", "Wall time: 1h 5min 43s\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAScAAADnCAYAAABcxZBBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAUJ0lEQVR4nO3dWWyV1RbA8V2m0jLP8wwyi6IBDCFBEw0zBqNpfJBIJMTgg8KTxGic4/RgNMYHYwhCVCBGwQERaUUQKLPMM4KMZZ7n3oebu11ryTm3hdNz1mn/v6f1ZVF60O+sfHt/a++dU1paGgDAm2qZ/gAAcCsUJwAuUZwAuERxAuASxQmASzWSJXNycniV50RpaWlOpj9DZZIt93ZOjv7fXhnfrie6t3lyAuASxQmASxQnAC4lnXMCkFmVcY6prHhyAuASxQmAS5VyWCdfv3bq1Enl9u7dq67lY3O1atUS5qry4zWQCTw5AXCJ4gTAJYoTAJdyks2lZLrFX84d5efnq9ylS5cS/lyPHj1ifPnyZZXr1q2buv79999v+ftCCOHq1asxvnbtWhk+ccVh+UpqZfrexj9YvgIgq1CcALjkupVADjkvXLigcl26dIlxs2bNVG748OEx/uWXX1SudevW6vruu++OcXFxscrZ1gIA6cO3D4BLFCcALlGcALiU8Tkn+fo+2a5/7dq1Uzk5z9SnTx+V69+/f4xlO0AIIaxdu1ZdHzp0KMZt27ZNmKtRQ/+nun79egBQcXhyAuASxQmAS646xGvWrKmuGzRoEONRo0apXL169WJ85coVlTt37lyMb968qXIbN25U1xcvXkz4eYYNGxbjGTNmqJzsUE/HjgV0iKdWZegQt9MgtWrVUtfyvszNzVU52Yqzbds2lZPfGbsyoiLudTrEAWQVihMAlyhOAFxKeyuBHSdLPXv2VNfnz5+PcUFBgcrNnz8/xtu3b1e5HTt2xPjEiRMqZ+egpLp166rrM2fOxLhRo0YJ/2xJSYnKsWtm5ZWXl6eu5dyjXe4k77VkbTJ2rjXZ75MtNA0bNlQ52wozbty4GK9atUrlunbtGuOioiKVk+02+/fvT/jZKhpPTgBcojgBcCntwzr76CtbAiZNmqRy8nHTbi63aNGiGMtO7hD0cLA8Qyw5jAshhAULFsR46NChKidbEk6dOqVysnucIV72k0Myex82b948xvY+kPe63Q1DTgW0aNFC5eRwrXv37gl/34ABA1Sufv366loO8+QQL4QQfv311xg/++yzKvfpp5/G+PDhwyqXznubJycALlGcALhEcQLgUoXMOdlWebm8RC5JCSGEXr16xVguFwlBtwTMmzdP5Y4cORJju0vm7Y6F7c/JZTBLly5VuTFjxsR43bp1Krdz585b/h3ITnXq1ImxnXOSczDVq1dXudq1a8f4iSeeULndu3fHWO7GGoKe47LfFzkvO3v2bJWzc17ycI/P
PvtM5e67774Y21aYfv36xXjhwoUql87dOHhyAuASxQmASxUyrLO7BMjHYtsNKx9Fv/nmG5Xbt29fjO1BBfKRNVnXd6rcuHFDXRcWFsbYdvHK7nGGddlPThvYVpiBAwfGeMWKFSon2wB69+6tcmPHjo3x3LlzVU5OGdiphldeeSXGduNE20rQsWPHGNvhoTy/cevWrSq3efPmGNtzH9PZGsOTEwCXKE4AXKI4AXApLctX5G6Tth1/8uTJMZZLUkLQY2F7UEG655latWqlcnIHBTsul8sP7E6bdokM/JPzLPagCzkXapeoyPmpDRs2qFynTp1ifO+996rc9OnTY2zvFznPZOdB7Wt+2dogf18IIRw9ejTGffv2VbmVK1fGOB3fs0R4cgLgEsUJgEtpGdbJjnG7oZskX2+GEMK3334b40ys7peP8AcPHlQ5Oayzm3zJ9gHbLS+7f9mxwCe7MZwcHtnpBZmzPyd3y7CbHr766qsxlm0FIYTw888/x9gO65J1aNvN52T7gDzLMQS9E8Lrr7+ucmvWrIlxJu9RnpwAuERxAuASxQmASxUy52QP95NjcdmaH4I+qCDZa3b72rQi2FXlcqlN06ZNVU4eNvjYY4+pnFwt/vfff6vcyZMnY5zOFd4oOzvPIv8/2XkleV+ePn1a5eSypmPHjqnc2bNnY/zBBx+onLy37ByXlJ+fr67tIRzy77n//vtVbs+ePTGWO3yEoL+/dheGdOLJCYBLFCcALlXIsM6ery4fKf/66y+Vk2e2r169WuXkI3Oyc7+sZK937a4I8nfY7l/5s7aVQP7ZqVOnqpzsBpbng4Xw71XtyC7J7ju7GkDeP3bKQl7bLmx5QEeyM+3s9MmIESPUdcuWLWO8fv16lZPfEftvkkPOTOKbAsAlihMAlyhOAFyqkDknO88jz2m342s552Rb/O2YWpKvd+1OlHLnA7tjgD0MQb4qtWPtZAcIylXdBQUFKvfUU0/FeNOmTSpnX/ei8kjWGiLnkULQy7g6d+6scnL5k52jbNKkSYx79OihcvYAzMcffzzGds6pXbt2Md67d6/KeVlWxZMTAJcoTgBcojgBcKlC5pzsvJJss7c7Sso5pzZt2qic3BFQjrVD0AcW2kMJ5RIRO/+1a9cudS3H9PZzy7G3XdoiT5ix7f8TJkyI8ZNPPqlychmMXdqC7GbnamQPnTyBKAQ992n7o+T9LLc2sT/XrFkzlSvPrpVffvlljO13xAuenAC4RHEC4FJadsKUr1jtwX9yp0i5c2AI+nWnbeOXj7vytX4IugVh//79KpdsJ077O+Rjss3JpQF2Y/vi4uIYDxo0SOU2btyY8PejcpEtAdu2bVM5Obxv3LixyslhnjxYNgS966o9vHbIkCHqWg4rO3TooHLvvfdewr/HC56cALhEcQLgEsUJgEtpmXOSpkyZoq5nzJgR44EDB6qcfM0v53FCCKF9+/YxtjsCymUoDzzwgMotWLBAXctXvHY5gGzrnzhxosotW7Ysxnau7JlnnonxyJEjVc4uMUDVYLcRKikpibFdtiXnMJMt4erXr5+63rJli7p+9NFHY/zRRx+pnPx7vW7j4/NTAajyKE4AXMpJtgI5JycnJcuT5W6Qsns7BL1y2naqyhX8dneB48ePx9huLC+7x+2hluPGjVPX9erVi/GLL76ocrJdwJ41L1/32kf2oqKiGDdv3lzl1q1bF+PyHNpQWlrqs403S6Xq3r5dcsWBvF9D0PeFvX9lC8Jdd92lcs8//7y6lm0z06ZNUzm5W6v9/qRbonubJycALlGcALhEcQLgUlpaCeQSEvvacvny5TG2r90PHDgQ49GjR6ucXBZiT0aRcz5yCUwIIcyaNUtdP/TQQzH+5JNPVO67776LsW1BkDsb2h00T506FWO784BcSW6X3aDqkPNKdleCsp4eZOczbbvNsGHDYiznT0MIYeXKleX8xOnHkxMAlyhOAFxKSyuBfDS1j6lymGM3y5KbbtnN
uuSjsFz9HYLuHpeduPbnQtBn2Nvubfl5Bg8erHJLly6Nsd2ITp5vb//7yn9/eTaSp5UgtTLdSpCMvEdkq0sIerPG5557TuXs0E22IYwfP17lZPtApg80oJUAQFahOAFwieIEwKW0tBLIMW2ywymtxYsXx9i2IMhruwxEHuJp57jsn012wIG0ZMkSdS3/bLLDFK1Mj+/hn7wn7f0il6x069ZN5exum3KXD3t4iGx/Kc8yqnTiyQmASxQnAC6lfbO58pCPtPbRM1WPomU966s8Z4IBd0JOReTl5amcbA/4+uuvVe61115T13KDRHu2Yjbczzw5AXCJ4gTAJYoTAJdczzkBVZFsJbCHd8gdWeUusiGE8NJLL6lruWSlsLAw4e+glQAAyoHiBMAlhnVAhtlVDHLFwfDhw1VO7lJgh2NDhgxR1wUFBTF+4YUXVE62FtjVF17aDHhyAuASxQmASxQnAC6lZSdM3Dl2wkwtz/d2rVq1YmxbCeShmv369VM5O3e0ZcuWW8Yh+Nodg50wAWQVihMAlxjWZQmGdalVGe9t25LgaeiWDMM6AFmF4gTAJYoTAJeSzjkBQKbw5ATAJYoTAJcoTgBcojgBcIniBMAlihMAlyhOAFyiOAFwieIEwCWKEwCXKE4AXKI4AXCJ4gTAJYoTAJcoTgBcojgBcIniBMAlihMAlyhOAFyqkSxZGc/2ylacW5da3Nt+cG4dgKxCcQLgEsUJgEtJ55wqg+rVq6vrGzduZOiTACgPnpwAuERxAuBSpR/WMYwDshNPTgBcojgBcIniBMAlihMAlyhOAFyiOAFwieIEwCWKEwCXKE4AXKI4AXApa5av5OTozfLkbgN2iUppKZscAtmOJycALlGcALjkalhnh24jR46M8W+//aZyTZo0ifHNmzdVbtq0aTGeP3++yq1Zs0ZdHzly5PY+LFBJySkTO0WSbMok1dMpPDkBcIniBMAlihMAl3KSjRNv9+DBatV0zZNzQslyzZs3V7lLly7FuH///irXqFGjGPfo0UPl6tevH+OSkhKVs20Hs2fPjrHn+ScO1Uwtz4dqyrlX+32R13au1X6Xa9T4Z0q5Zs2aKie/P/YQkAYNGsT45MmTKpefnx/j0aNHq9z3338f4z179qjctWvXQiIcqgkgq1CcALhUIa0E9nGzrDnZOhCCbgOQj5ohhLBz584YL1u2TOXOnDkT4549e6rclClT1PWHH34Y4+7du6vcjh07En5WIFXy8vLUdW5ubowbNmyocm3atImxnL4IIYS2bduq68GDB8e4oKBA5V5++eUYJxvWTZ8+XeUmT54c4/fffz8kcv369YS5suLJCYBLFCcALlGcALhUIa0E5foA4rXpxIkTVW7t2rUxPnv2rMrt2rUrxsnmseyr2MaNG6vrd955J8aff/65yi1fvrxMvyMdaCVIrXS3EtilWbVr145xr169VO7ChQsxbtq0qcpNnTo1xrINJgT9mj+EELZt2xZj2yZjP4/Url27GO/du1fl5HVxcbHKjRgxIsYnTpxI+PdbtBIAyCoUJwAuZXxYJ19j2iGXfG36559/qlyqhlnydeybb76pcm+//XaMDx06lJLfd7sY1qVWpjvE5aqGunXrqtyDDz6YMFdUVBTjAwcOqNzx48fV9cWLF2Nsh3H16tWLsVyJEYLuLJed5CGEMH78+Bjb70utWrVifPnyZZVL9n1lWAcgq1CcALhEcQLgUsZ3wpS7BMhXqCGE0LVr1xjbV5pyicqdkHNuP/zwg8rJZQV2zM4hCrgTcn5o+PDhCXOPPPKIysl79ODBgypn53nk3FHr1q1VTrYdtGjRQuXkfFhhYaHKffzxxzG2S8rkDgap+H7w5ATAJYoTAJcyPqyT7Ops2aFtO2VTNayTm2DZXQnkK9bdu3en5PcBIegO6pUrV6qcbG+ZM2eOysnX9XZY1bt3b3Ut82PHjk2YW716tcrNnDkzxrbNQLIbyHHAAYAqgeIEwCWKEwCX
XM05nTp1Sl136tQpxqnYWe9W5Dj5nnvuUbn169fHWK4iD+Hfr22B8pD3nV3dL+/D06dPq5ycV5o0aZLKLV68WF3LuSu5lCWEEM6dOxfjL774QuXkn7UHI8h2H9teI5ei2YNEbgdPTgBcojgBcMnVsM6uXJaPkEOGDFG5/fv3x7g8rzDt5nPNmjWLsd35QG4CZg9RAFLFDoHkwRq2e3zQoEExltMOIfx7WmLTpk0x/uqrr1Ruy5YtMb569WrCz5YsZ793qRjKSTw5AXCJ4gTAJYoTAJdczTlZsrVAbp4eQgjbt2+P8datW1VOth3YFgC5u2YIIbz77rsxXrhwocrJ895TPZ4G/sfO3cg2lc2bN6vcsWPHYty+fXuVs4fAnj9/PsYVtZNsReLJCYBLFCcALlGcALiU8dNXysqeAiH7PWwvyJIlS2Jcp04dlZP9USHorVjswZ3yUM+SkpJyfuLU4vSV1PJ0b9veO3kySpMmTVSuZ8+eMe7cubPK2ROC5HyVPGDTG05fAZBVKE4AXHLdSiDZHQs2bNgQY/vo+8Ybb8S4oKBA5ewQUK66tocU2tYCIFXkfdeqVSuVk20rNicPP7DtAHbqwfNQrix4cgLgEsUJgEsUJwAuZc2ckyUP8LNbR4wZMybGdh5JLkkJIYTJkyfH2C6DAVKlbt266loeZGmXRj399NMxtrtNytNX7KGaixYtuuPP6QlPTgBcojgBcClrh3VXrlyJsdzx7/+xBxG2bds2xvv27bvjz4WqRQ677GqL3NzcGHfs2FHl5G4Zffr0Ubm5c+fGeMCAASonf4dtdbGHXGY7npwAuERxAuASxQmAS1mzK0GqyHmAEEKYN29ejPPz81Vu6NChMc70TpjsSpBaye7t2z1AVR4qGUIIDRs2jLFdNiV3ppTtASGE0Lhx4xgnO/1k1apV6lqeVpRN2JUAQFahOAFwKWtbCW6XPPwgBL0x3axZs1SuPId1ovIo6zAuBD1NIHcaCCGELl26xPjo0aMqJ1sQRo0apXI//fRTjDt06KByf/zxR4wvXbpU5s+ZjXhyAuASxQmASxQnAC5VuTknO48kN5Pv27evyslV39lwCCHST95PEyZMULnFixfHWB4CG0IIb731VsLcww8/HOPZs2ernDzEoLLfkzw5AXCJ4gTApSo3rLOve3/88ccYyw3AQtAbhNkDFoAQ9NDKHjAgz1YcP368ysmdB+w9WVRUFONkHeKVHU9OAFyiOAFwieIEwKUqN+dkWwnmzJkT45kzZ6qcPBxhxYoVFfvBkBXs/JCcc7Kv9uV85okTJ1ROtg/YXKZ3wPCCJycALlGcALh028O6atV0XcuWblV5oEEI+jHdDvmKi4tjnK3/XqSW/f8uNyi0r/3z8vJivGbNGpU7fPhwjO29xbDuv3hyAuASxQmASxQnAC5VuQMO7NnzLVu2jLGcB/CGAw5SK1X3tpwvsvNRcpfM+vXrq9zp06djXNkOwywvDjgAkFUoTgBcqnLDumzFsC61uLf9YFgHIKtQnAC4RHEC4FLSOScAyBSenAC4RHEC4BLFCYBLFCcALlGcALhEcQLg0n8AWZspQFi0SwQAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "#collapse-hide\n", "\n", "%%time\n", "args.load_saved=True\n", "args.epochs=70\n", "args.labeled=100\n", "args.features=16\n", "args.mode = 'train'\n", "gan = ImprovedGAN(Generator(z_dim=100), Discriminator(output_dim=args.features), MNISTTriplet(MNISTLabel(args.labeled/10)), MNISTUnlabel(), MNISTTest())\n", "\n", "gan.train()\n", "# gan.trainknn()\n", "print(gan.evalknn(results))\n", "show_gen_images(gan)" ] }, { "cell_type": "markdown", "metadata": { "id": "lfSL_f39XcYo" }, "source": [ "#### Triplet Loss and M=16 N=200" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "drYCnOpaXcYo", "outputId": "3641dd80-91ed-4d5d-fdb8-4264ccd894a4" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Loss\n", "\tsupervised \t (min: 0.000, max: 0.106, cur: 0.000)\n", "\tunsupervised \t (min: 0.368, max: 0.463, cur: 0.413)\n", "\tgenerator \t (min: 0.487, max: 1.185, cur: 0.487)\n", "\n", "Done argsort\n", "(96.44481666666665, 0.931573062438397)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAScAAADnCAYAAABcxZBBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAUIUlEQVR4nO3da7DV4xfA8XWk+z11RFJJqdRQiVySUW5hMJiSMcIMXhjXGNNoMINxG1PTDI0XlYk0DSqpKDGiCY1KF6Wr7qc7pfvl/F/9H2st9tY59tln7dP382r9Zp2zzy+eveb3PL/nUlRaWioAEM0plX0DAPBPKE4AQqI4AQiJ4gQgJIoTgJBOzZYsKiriVV4QpaWlRZV9D1UJbTuOTG2bJycAIVGcAIREcQIQEsUJQEgUJwAhUZwAhERxAhASxQlASBQnACFRnACERHECEBLFCUBIFCcAIWXdlQBAxTjllL+eC44fP57x54qKijJeZ/u9svB/40TPFahVq5a5PnjwYE7u5/94cgIQEsUJQEgUJwAhMeaURbZxgWrVqqXY99FPPfWv/6xHjhwxOc4JPDnpNiFi21O9evVMTo/lNGzY0OSaN2+e4h07dpjc7t27zXWdOnVSrNuyiEhxcXGK//zzT5PTY0eHDh0yud9//z3F+/btMzk9dpWLds6TE4CQKE4AQirK9vhVFTeBr1mzZsbcsWPHTvhnDxw4kGL/WK4fd313sLyPuxxwkFv5btu+LXXs2DHFvjs2bNiwFC9ZssTk9LX+DJG/tzXdJfOv+Z944okUf/vttyY3a9asFO/fv9/kpk2blvEz9d8vyzQHDjgAUFAoTgBCojgBCKlgphL4V6HZ+rT6Z/3v+Vez2t69e821Hlf6l7E5c12jRo2MOd+HR9Wl/9/7dqDb4ezZs01uwYIFKe7WrZvJ6Vf5up2JiLRu3dpc6zEh3ZZF7Dip/5xBgwaleNGiRSbXvn37FI8cOdLk9BiX/y6VB09OAEKiOAEIKVS3zs+i1bOw/Wv+bPTM2LPOOsvk/Epq/dhaltef+t5817FXr14p3rx5s8n5V8OounS78NNNdLfuxx9/NDk9tWD+/Pkmd+GFF6Z43LhxJud/9pdffklx7dq1TW7u3Lkprl+/vsnpmed+9nijRo1S7LtufjXEf8WTE4CQKE4AQqI4AQgp72NO/pVqgwYNUuz7rHrl9J49e0xOvyb1Y1X6tal/db9y5UpzXd7dBPXf7NKli8nppQpr1qwp1+ejavGv69etW5fiqVOnmtxFF12U4m3btpnclClTUrx9+3aT88tg9HfNj3npsdDTTjvN5PS4bNeuXU1Ot3s9tvtPf/+/4skJQEgUJwAh5b1b16JFC3sD6jHRv7bctGlTiv2r0KNHj6Y426zr9evXl+s+/41+TN65c6fJ6a5jrjahR+HR0198F0zPpl66dKnJ6fa0fPlyk9Pt3g91+J0P9KoGn9PTAHyXU9PTEf7tZ3ONJycAIVGcAIREcQIQUs7GnPRyjtNPP93kdP/aL/Xo3LlzivUOfCK2z+z71/nmp0Doafx+p
wP97y0pKanYG0NB8Muv9PfFj1m2bds2xT169DA5PeakYxGRpk2bmmu9VMr/bPXq1TPemz5EwY/Z6qUu3bt3N7mvvvoqxbkYa+XJCUBIFCcAIeXlgIOrrroqxRs2bDC5LVu2pNhviBXpjDffrbv66qtTfOONN5rck08+mfH3OOAghnwfcODbge5WXXbZZSb31FNPpdjPwtZ69+5trteuXWuu9dDDnDlzTE53wXy3Tt+b7mKKiAwfPjzFfvZ6eYdeOOAAQEGhOAEIieIEIKScTSXQUwT8uMqKFStSrHcaELGr9iONMYnYf9Oll15qcl9++WWK+/fvn/Ezov2bUDn8FBq9/Mkfjqm/E/7gSr1LwIgRI0xO72YgYr93fkfWSy65JMVXXnmlyU2cODHFemxKROTw4cMp9jth6iVmfvy4PHhyAhASxQlASDnr1ukZoX4FdJ8+fVLsdxDQM1B37dp1wn9Pv5r1Xadshw9kOyjBv+7Vm3Dp88JERO66664Uz5gxw+T0jPE//vgj49/DycO30SZNmqTYv4LX7WfSpEkm9+GHH6bY7+LhN43Ts7l929aHcOhunIidIqA3xROxG9H5z9QrI3KBJycAIVGcAIREcQIQUoXshOn7otOmTUvxm2++aXIzZ85Msd9lT6+k9hu0n3322Sn2G6vrldy+H3zzzTeb6xdffDHFjRs3Njm9e6HfiXPr1q0p3rdvX8b7xslLfw/84a49e/ZMsV5mIiIycuTIFPtdAfS4rB8/1W1SxI4P+SUyeuz36aefNrkxY8akePHixSbnv9tarnd95ckJQEgUJwAhUZwAhFQhY05+zKVly5Ypfumll0xOb8/gD8fUc0Eeeughk9N9+BdeeMHkdL/Yb+swatQoc33eeeeleMGCBSbXrl27FPs5JBs3bkyxPhhU5O87G6Lq0m3Nz6nTY6h+7FOPJfmTWfTyFT9PTh886+dO+TmEep6TXwajtzHy86X096BVq1Ymt3DhQskXnpwAhERxAhBShXTr/GvLM844I8X+sVBPA+jUqZPJ6WUoH330kcnpV5z6UVfErsb2j77+0VvzUxn0td+RUE8tqOzDF5A/2ZY46YMBRETq1q2bYv+d+OKLL1LspwDo9uTbdjZ+WKR9+/Ypvu+++0xO/zsWLVpkctm6lfnEkxOAkChOAEKiOAEIqULGnGbPnm2u9WF/eic9ETsmtHTp0oy5XMk2xV5PaxCxuw7qaQ0idswr23KVXJ2+ghj8/89sS6z0UhO9hEtEZPXq1Sn2bfJEl4H48dMOHTqYa71Dq9/tcurUqSl+9dVXTU6PA5dlzCvXeHICEBLFCUBIFdKt83bs2JGPP/Of+VXe06dPT7GeLS5idz5YtWpVxs+kG1d4snXFs00p8XSXz++kqttaWdqIHnrws7dvuOEGc627mRMmTDC5N954I8V+B9rK7MppPDkBCIniBCAkihOAkPIy5hSZHl/wUwkGDRqUYv3qV0Rk/PjxFXpfqDzZxoD8yUIDBgxIsV/dr5el6B1fRexSEz8VRS/b8m3yggsuSLHf4cPfd0lJSYrfffddk9O7t0YZY/J4cgIQEsUJQEhF2R5hi4qKqvx7cN2t8xvK6Q3i/QZy+jWt3yysIpSWlmbeWR5lVpa2rbtZjz76qMnpnSv0QasiIp9//nmK9SoJEZFnn302xX6md3FxcYr79u1rcnqTOP1zIiKjR4821/PmzUux3qlDROTAgQMpruzpLpnaNk9OAEKiOAEIieIEIKSTfsxJv9L1SxP0WIDfPF6vHM/HIZqMOeVWWdq2Hpf0bWTgwIEp9odj3n///SkePny4yekpCf7gzMGDB6f4nHPOMTk9vjlu3DiTmzFjhrnW01/8bq2VPc6kMeYEoKBQnACEdNLNEPcrznXXbe/evSanpwv4rluuz4VHXLoLlG2DN39+od500U9T0bO7x4wZY3J6Nrk//EDvZjBlyhST81Na8jHcUJF4cgIQEsUJQEgUJwAhnXRTC
fzBmXppgp7SHw1TCXKrvG3bj1nqJSS6LYmI9O7dO8W+3bVt2zbF9evXN7mJEyemePny5SandxM4ePCgyfmdXAsFUwkAFBSKE4CQqmS3Tr/u9Y/hjz/+uLkeNmxYxp+N9CqWbl1uVUTb9mcb6u+WHzI4dOhQxs+JNHs7H+jWASgoFCcAIVGcAIRUJcecqiLGnHKLth0HY04ACgrFCUBIJ92uBEAh09NkqvrOGDw5AQiJ4gQgJIoTgJCyTiUAgMrCkxOAkChOAEKiOAEIieIEICSKE4CQKE4AQqI4AQiJ4gQgJIoTgJAoTgBCojgBCIniBCAkihOAkChOAEKiOAEIieIEICSKE4CQKE4AQqI4AQgp67l1HNkcB8eR5xZtOw6OIwdQUChOAEKiOAEIieIEICSKE4CQKE4AQqI4AQiJ4gQgJIoTgJAoTgBCojgBCIniBCAkihOAkChOAEKiOAEIieIEICSKE4CQsu6EmZcbOPWvW+jQoYPJrVmzJuPvVa9ePcU1a9Y0uTp16qS4WrVqJrdhw4aMn3nkyBFzfcopf9Xu0lK7ceLx48czfg6A/44nJwAhUZwAhFTkuysmWQGbwOtunIjIsWPHUlxcXGxyAwYMSPH06dNNbsSIESmeMGGCyS1btizF/fv3N7lVq1aZ65kzZ6b4t99+MzndzdP3WRk44CC38n3AgR9eqOz2FAkHHAAoKBQnACFRnACElJepBPq1/5lnnmlyW7ZsSXGbNm1MrmXLlin+9ddfTe7tt9/O+JkHDx5McY0aNUzunnvuMdfz5s1L8aFDh0wu23gc4L388svm+pVXXkmxH2Nq3bp1ivX3Q0Rk+/btKfZtULdt/5lFRXbopnbt2inet2+fyekxMD+FRk/F2bt3r8nl8zvBkxOAkChOAELKy1SC5s2bp/jyyy83Of1qv3Pnzianu3n6UVNE5OOPP06x74716NEjxfv37ze5Vq1amet33nknxf4RNhKmEuRWedu2nwrTs2fPFJ9xxhkmp1cu+La1cePGFB84cCDj3/Of+cgjj6R4/PjxJrd7925zrbt1e/bsMTn9XfN/X0+pWbRokcnp7qH/vfKummAqAYCCQnECEBLFCUBIORtz0v1r/2pSTwkoKSkxuX79+qVYjxWJiPz+++8pnjx5ssnp16iHDx82OT0+dd5555mc75d/9913KY680wBjTrmVq/HU119/PcVjxowxuQYNGmT8PT32unTpUpPT3yU/vWb+/Pkpvuaaa0xu8+bN5nrXrl0p1uO+IiJ//PFHiv24bMeOHVN87bXXmpyeHjF16lTJBcacABQUihOAkHLWrdOzU7N9Zt26dc310aNHU9yuXTuT091DP11Av9L1r2n161ef86+CV6xYcUL3Xdno1uVWWdr22LFjU3zvvfeanG7Po0ePNrkhQ4akeMeOHSanNzLUs8VFRJo1a5Zi31XTqyEWLlxocv5vZGvPelb6TTfdZHK6y9elSxeT0xtAfvbZZybnh3NOFN06AAWF4gQgJIoTgJDysnzFr5bWatWqleIWLVqYnO5f636wiEiTJk1S3LBhQ5PT0wX8oQl6FwQRkZ9//jnFejW4iF0BXtnTDBhzyq3ytu1BgwaZ69mzZ6d427ZtJteoUaMUb9q0yeT0bhl+rEaPi/ppMrmix7zOPvtskxs6dGiK33//fZPT3zW9i6zI33c+OFGMOQEoKBQnACHl/YCDbPwm8Po17eDBg02ucePGKdazvEVElixZkmK9ilvEPmqLiKxcuTLFfqbsBx98kGL/yO6nNlQ0unW5VRFDFpGnonj6u9a1a1eTO//881Psdx7Qs8L1sIdI+Q9toFsHoKBQnACERHECEFKoMSdP7+Q3cOBAk9NTBH766SeT01Pz9fIUkb9PO9DTFfQ4lojtb+t7EREZNWpUiv00h4rAmFNuVXbbzjc/nUe39dtvv93krrjiihQ///zzJqfrxbp163Jyb4w5ASgoFCcAIYXu1ml6RquIyLnnn
ptif/iBXp3tdyXwrz91F7Bp06Ym179//xTPmjXL5PRs8lw93mZDty63IrXtiqK/M/6gBL2JnD67UUSkXr16Kf7+++8r6O7+QrcOQEGhOAEIieIEIKSCGXPKJtuuB162f68fu7rzzjtTfMcdd5icPtTzvffeO+G/UV6MOeVWobTtsvDjsvpQg+uuu87k9G4gnt7R04/RVgTGnAAUFIoTgJAoTgBCOvXffyS+XI3x+P61ni911llnmZzeHuLiiy82uR9++CEn9wNk48eY/K6veisUP8akTzp67rnnTC4f40wngicnACFRnACEVCW6dbniu4eLFy9Osd8Jc8GCBSn2O3gW6u6IiE8fjOCXpOjlViL2AMy+ffua3N13353iKN04jycnACFRnACERHECEBJjTln06tUrxW+99ZbJ6de2q1atMjnGmZArfmmWPgCzVatWJqenB4jYcdIhQ4aYXL5PDyoPnpwAhERxAhAS3TrFTwnQuwDqldoiIiNGjEjxokWLKvbGcFLRXbk2bdqYXM+ePVNcs2ZNk1u4cKG51gdgbtiwweQKYeiBJycAIVGcAIREcQIQ0kk/5qT79/7gzMceeyzFM2bMMLmOHTv+42cA/5WeIqAPfRWxu2EUFxeb3Ndff22ut2zZ8o+/Vyh4cgIQEsUJQEgF063Tq7FF7EZbvlulZ7/6nL9u1qxZinv37m1yy5YtS/HMmTNNTm/kVYiPzIhDH2IpItKnT58U79q1y+T+/PPPFH/66acmt3r1anNdCLPAs+HJCUBIFCcAIVGcAIQUesxJb8pet25dk+vXr1/G39NT8+fOnWty3bp1M9cPPPBAiufMmWNy3333XYpLSkpM7ptvvkmxH8cqhKUBqFy6zfj2o6cI6PYpIvLwww+neOXKlSZ3+PDhXN5ipePJCUBIFCcAIYXq1vnHWz1j23fr9GvSHj16mJx+vG3ZsqXJbd682Vw3aNAgxWPHjjW57t27p9h31fyBB0BZPPjggyleu3atyel2N3ToUJNbt25dio8cOVJBdxcDT04AQqI4AQiJ4gQgpKJsr72Liooq9Z24nkrwzDPPmJw+8HLr1q0m16JFixSvX7/e5G677TZzPW3atBT7vv/GjRtTXNnTA0pLS9n6IIfy3bb9eOqtt96a4k8++cTkrr/++hQvXbrU5Ap9p4F/kqlt8+QEICSKE4CQQnfrTnQTN71DgYhIw4YNU7x7926Tq+zuWXnRrcutimjbnTp1MtcrVqxIcePGjU1uypQpKX7ttddMTg9Z7Ny50+T279+f4kLfdeD/6NYBKCgUJwAhUZwAhBR6zAl/Ycwpt/LRtvV0gUmTJplc9erVU6ynzIiI1K9fP8V+uVVVxJgTgIJCcQIQUpXs1lWrVi3FVX0WLconUtsuy2aFkydPTvEtt9xSYfeUT3TrABQUihOAkChOAELKOuYEAJWFJycAIVGcAIREcQIQEsUJQEgUJwAhUZwAhPQ/ntwN9hGVbkoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "#collapse-hide\n", "\n", "args.load_saved=True\n", "args.epochs=70\n", "args.labeled=200\n", "args.features=16\n", "\n", "gan = ImprovedGAN(Generator(z_dim=100), Discriminator(output_dim=args.features), MNISTTriplet(MNISTLabel(args.labeled/10)), MNISTUnlabel(), MNISTTest())\n", "gan.train()\n", "# gan.trainknn()\n", "print(gan.evalknn(results))\n", "show_gen_images(gan)" ] }, { "cell_type": "markdown", "metadata": { "id": "7ttRXSclXcYt" }, "source": [ "#### Triplet Loss and M=32 N=100" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "9Cs8tvz_XcYt", "outputId": "7399675c-6580-486d-98e7-854938e046a0" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Loss\n", "\tsupervised \t (min: 0.000, max: 0.098, cur: 0.000)\n", "\tunsupervised \t (min: 0.355, max: 0.485, cur: 0.430)\n", "\tgenerator \t (min: 0.680, max: 2.149, cur: 0.699)\n", "\n", "Done argsort\n", "(95.97931666666668, 0.9031129058147167)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAScAAADnCAYAAABcxZBBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAARyklEQVR4nO3daWxV1RbA8VXGAjIJMiNQKaDMgog1CoK2SvxgREMiMcYh0ZgoGhONBk3wmzYmisSBqDHGCJqogCIqRgUNULSCiMwzCCJTAYEy9n3brLXh3tfi6e065f/7tE5WH/f6cu7K2fvsvXZeVVWVAIA3Der6CwDA+VCcALhEcQLgEsUJgEsUJwAuNcqWzMvL41WeE1VVVXl1/R3qE+5tPzLd2zw5AXCJ4gTAJYoTAJcoTgBcojgBcIniBMClrEsJYnl5Z9/4sWEY9Qn3tj88OQFwieIEwCWKEwCXajTnxFgc9VVt39uNGzc21+3btw9xy5YtTa5Bg7PPDH369DG5OXPmhFjPk4nUv98nT04AXKI4AXApL9ujIDu3/aArQbJq495evny5uR42bFiIu3fvbnL5+fkh7ty5s8kVFBSE+IUXXjC53r17h/jkyZMX/mUdoSsBgFShOAFwieIEwKUaLSUAkNnIkSPN9ZkzZ0I8YMAAk2vYsGGIf/75Z5Nbv359iE+fPm1yemnBpk2bTO7YsWM1/Ma+8eQEwCWKEwCX6nwpQaNGZ0eW+jFYRKRNmzYhrqioMLn4b+s7lhIkK6l7W9+/p06dMjm9CjxeIb5q1aoQv//++ya3ePHiEB8+fNjk+vXrl/G7vPXWW+Y6LcM8lhIASBWKEwCXKE4AXMrJnJMeb7/yyismp69nzJhhcuvWrQvxd999Z3J6J/dvv/1mcpWVlSHes2ePyR04cCDj33re1c2cU7JqYz417hKguwtcd911Jjdu3LgQl5aWmpyeT9XbXEREioqKQjx58mST27x5s7m+5557QnzixIms370uMecEIFUoTgBcysmwbsGCBSFu1qxZxr+Ld2Dr3dpjxowxuZ9++inEJSUlJrdo0aLz/hsiIl9//bW51qts40ffQ4cOhTheuqD/O/Tfidjd4kkteWBYl6zaGNYNGTLEXOsphF27dsWfH+L4HtH3T/x76dGjR4jjFelxY7oPPvggxGvXrjU5T1MYDOsApArFCYBLFCcALtXKnJMeF4vYsXg8r/PVV1+FuFu3bibXunXrEPfv39/ktm3bFuLi4mKT27t3b4hbtWplcvGyg/Hjx4c43h2+YsWKEB89etTk9Cte/XpXROTXX38N8Zo1a0wu/neqizmnZCU156TnjuKlBO3atQvxvn37TC6Jucj48zp27Giun3322RC/++67Jqfv7brGnBOAVKE4AXCpVoZ1o0ePNtd6J/XBgwdNbtmyZSGOH3W3b98e4hYtWpicHh42b97c5PRu8IEDB5rcqFGjzLVexbthwwaT0690f/jhB5N77rnnQjx37lyTe/DBB0N80003mVzcPKy6GNYl62I4vEP/7lavXm1yw4cPD3F5eXnO
vtP5MKwDkCoUJwAuUZwAuFQrBxzorSUiIoMHDw7xn3/+aXL6tX/cMUDPz+juASK200E8b6a7ZsbbBubNm2eus73S1Z04x44da3K628Edd9xhctOnTw9x/Hp3586dGT8PSJI+KEEv2RGx81F1PeeUCU9OAFyiOAFwKbGlBHr39PHjx01Or9LWr+5F7I7+uEF8rumzxETsf1PcoP7hhx8O8YgRI0xOD12nTJlichf638hSgmRdDEsJ9Ary66+/PuPfxTsjco2lBABSheIEwCWKEwCXEltKkO0AP52LuxLEO6s90fNjN954o8np7gL33XefycXzU0Bd0PPJelmBiN3WFc+1XugWq6Tx5ATAJYoTAJdqZYV4TO/uj5cueGq0Hnc3KCgoCHG8el03xuvVq5fJeWrkhYuXnjK55pprTE53DokbMO7fv79Wv1d18eQEwCWKEwCXKE4AXMrJnFNSB0vWhiZNmoRYN6QXsd/7+eefNzl9vn18YGESPC+xQO1q2bJliOMDMWryml8vaSksLDQ53ZG2S5cuJsecEwBkQXEC4BLFCYBLOZlz8kTPMYmI9OnTJ8TxySy6W2DcUVO3RYlbxCTB0/ov/Hf6EFYR21l19+7dJqcPodUHtIqIrFy5MsTxXG68bUrfv3F32kmTJoU4buPTtGnTENfGvV1dPDkBcIniBMClej+s090sRexhCyIid955Z4ivuOIKk5s2bVqIN27caHLxMA+INWp09uf16quvmtzrr78e4qlTp5qc3u4Vb4V68sknQ/zZZ5+ZXDw8e+yxx0KslyeIiHz77bch1gd5iPiZUuDJCYBLFCcALlGcALiU2OkrnujXtvp1qohIcXGxudZLCZYsWWJyS5cuDfEff/xhcrkel3P6SrJycW/rFjx6/lLE3k/6BCIRuyQgvie3bNkS4q5du5pcfG/rOa/58+eb3L333hvi77//PuPnz5o1S2obp68ASBWKEwCX6sWwLn5NqjtT6mGbyLmrwPUjbdxdQDeF16936wLDumTl4t7Ww674dzZgwIAQx4ey6qHc0KFDTW7ZsmUhvuqqq0yuY8eO5nrDhg0h7tSpk8lVVFSEOP796OmMhQsXSiZJdRthWAcgVShOAFyiOAFwKbXbV/Rr0s6dO5ucHuvH4/I1a9aYaz2G37lzp8nV9TwT0k3fT9dee63JjRkzJsTdu3c3Od3x4ptvvjG5kpKSELdq1crk4qUF27dvD/G8efNMTm9Z0csT4lzPnj1NbvPmzRk//8iRIyFO4mBOnpwAuERxAuBSapYSNGhg66he+d2jRw+T04/M8WPx8uXLzbVu5u75IAaWEiSrru/tvn37hvi2224zOX24RVFRkcnp5QFvv/22yTVs2NBc60MMKisrTU7/7uP/nW7IuG/fvvP/B0hyuyRYSgAgVShOAFyiOAFwKTVLCdq2bWuu9fKBeDytlwCsW7fO5OIDAz3PM6H+0lul9Ct/EXtPvvHGGyanfwfxfFB8L+t52jin55lOnDhhcvqAhQkTJpjce++9F+J4/rZ169YhTuJgTp6cALhEcQLgkuulBPrRc+LEiSanX7/Gj6xTpkwJsV7RKpLeVd8sJUhWXd/bnullOnGnDr3MIdu5eTX5nbGUAECqUJwAuERxAuCSq6UE8RYVvTu6sLDQ5PTO7XhcvGPHjhDHBw0CyC7u3KFlm6NOej6XJycALlGcALjkalind0OL2IZVixYtMrkOHTqEuKyszOSOHTsWYi/nvgOoGZ6cALhEcQLgEsUJgEs5374SLxfI1hWgoKAgxHqOSUSkvLw8xPn5+SanuxSkdbtKjO0ryWL7ih9sXwGQKhQnAC657kqg6d3QInbZwfHjx3P9dRLx2muvmetJkyZl/FuGdcnydG9f7BjWAUgVihMAlyhOAFzKOucEAHWFJycALlGcALhEcQLgEsUJgEsUJwAuUZwAuERxAuASxQmASxQnAC5RnAC4RHEC4BLFCYBLFCcALlGcALhEcQLg
EsUJgEsUJwAuUZwAuERxAuBSo2xJzvbyg3PrksW97Qfn1gFIFYoTAJcoTgBcojgBcIniBMAlihMAlyhOAFyiOAFwieIEwCWKEwCXKE4AXKI4AXCJ4gTApaxdCVA9jRrZ/xtPnz4d4qoqNr8DF4InJwAuUZwAuERxAuASc07V1LBhQ3Pdpk2bEI8ZM8bkPv300xAz5wRcGJ6cALhEcQLgUr0c1jVocLbm5uXZ3un6Nf//06RJkxCPHTs2Y27WrFkmd+bMmWp/BlATenohnjKIr/W9n8Z7kicnAC5RnAC4RHEC4FK9mHOKX/NfeeWVId64caPJnThxwlxnm4Nq2rRpiAsLC03u8ssvD/GcOXOq/2VxUWrcuLG51ktRjhw5YnL6Xjt48KDJ9e/fP8QnT540uU6dOpnrsrKyEK9fv97k0jAHxZMTAJcoTgBcSu2wTncCmDBhgsnt378/xFu2bDG5miwl0J8xcOBAk/vkk09CzCpwiJw7vTB69OgQ//XXXybXo0ePEN9+++0mV1FREeKbb77Z5FasWBHinTt3mtzatWvNdefOnUO8fft2kzt69Og5398bnpwAuERxAuASxQmAS6mdc+rWrVuIn376aZMbP358iOPXtNnEcwb639HzWCIiCxcurPa/i/oj3g6lxR1Rhw0bFmJ9v4qI3HrrrSFu1aqVyc2cOTPE+/btM7nff/89xHE3jF69epnrdu3ahXjp0qUZv7dXPDkBcIniBMCl1AzrBg0aZK5feumlEM+dO9fk9PKBmrzmv+WWWzJel5aWmly80hwXh2w7/+Nh3UcffRTiBx54wOTefPPNjJ/xyy+/ZMzplebHjh0zufvvv99cf/HFFyGOh44sJQCAC0RxAuASxQmAS67nnHRHyyFDhpjc7NmzQ/zOO++Y3KlTp6r9GS1atAjx448/bnJbt24N8Z49e0yOLSsQsfdBPAd02WWXhVjPP4nY7STx/ao7BsT32fHjx8/7dyIie/fuNdclJSUhjrsS7N69O+NneMGTEwCXKE4AXHI9rOvTp0+Ie/fubXJTp04N8YUO40REnnrqqRDrR2YRu8tbD/GA84lXj+uuAXE3DL3sIB6eZRtm6V0Ml156qcnFwzrdfC5ulrhhw4aM380LnpwAuERxAuASxQmAS67mnJo1a2au9cGVcRcAPU6Ptw3oXNu2bU3u7rvvNtdFRUUhjrfBzJs3rzpfGxCRc+eO9P0bz3Xqbhn6IA2R7FujdEfW+N6Ou2Z26NAhxB9++KHJeV0+oPHkBMAlihMAl+p8WKdfv3bs2NHkRo4cGWLd+E1EZMmSJSEeN26cyenH6VWrVplcfH6Y/nf++ecfk9u1a1fW7w7o+zceKmVb4qKHZC1btjS59u3bh1g3jBOxy2s2b95scvEK9R9//DHEXbt2NTn9W/v7779NzsuQjycnAC5RnAC4RHEC4FJO5pyyjcv1HFB8LvyXX34Z4nj7iu4SoA8aFBH5/PPPQzx8+HCTW7Bggbk+fPhwiP/991+Ti8fwQCzb/IzeFnLy5EmTKygoCPFdd91lciNGjAjxtm3bTG7Hjh0hjuecdOdLEZG+ffuGOD5gQXf8iA/2qMl2sNrEkxMAlyhOAFyiOAFwKS/bmDkvLy+nCx7ibSh6ripuR9G8efMQx8v99Xj6hhtuMLknnnjCXOstK3FHTU8nVFRVVWU+zRE1lot7O9v9q7tk6rY9Iva+Gzp0qMlNmzYtxJdcconJxQe/6q0u8VoqPU+r10PFn5+LNU+Z7m2enAC4RHEC4JKrYV1S9PaVF1980eTmz59vrhcvXhxiT8O4GMO6ZOX63o6Hdfo67mipl9fE96QeDh44cCDrZ44dOzbEH3/8scmNGjUqxGVlZSYXd4StbQzrAKQKxQmASxQnAC7VecuU2qBfmz7zzDMmt3r1anNdWVmZk++Ei1s8t6uv41NTstFbvPSSGZFzlxa0bt06xKWlpSanW6hw+goA1ADFCYBL9WJYp5cOiIhM
njw5xPHq1xkzZpjruCk9kBb/b/X2pk2bQty/f3+T0x034t8PXQkAIAuKEwCXKE4AXKoXc06FhYXmWr9CfeSRR0wu24GFQJrEc075+fnmWm+L0fNPIrYDbNyl0wuenAC4RHEC4FJqh3X6EXbmzJkm9/LLL4d43bp1OftOQF2qqKgw17ohY9xpYOvWrTn5Tv8FT04AXKI4AXCJ4gTApdTOOenOA/qATRGRnj17hjgXDdoBD+LDMQcPHhzioqIikysvLw9xvMzAC56cALhEcQLgUmoOOIgbaz300EMhjndVT58+PcT1ZUU4Bxwky9O9nZQWLVqY6+Li4hB36dLF5HTnAf17EbGHL+SiawcHHABIFYoTAJcoTgBcSs1SgnjO6dFHHw3xhAkTTM5rw3YgSfFBnXEHS32IQfwbmThxYoj1shwRkUOHDiX1Ff8TnpwAuERxAuBSaoZ18VDt6quvDnH8eMuwDheDeKojbjY3ZMiQEM+ePdvk9BIB3XhOxP6e6nKHBU9OAFyiOAFwieIEwKXUbF+52LF9JVk1ubf1bv80zWcOGjQoxCtXrjQ5/buPt3/FXTNrG9tXAKQKxQmASwzrUoJhXbK4t2tXv379zPWaNWsy/i3DOgCpQnEC4BLFCYBLWeecAKCu8OQEwCWKEwCXKE4AXKI4AXCJ4gTAJYoTAJf+B+/e6o+G1gzoAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "#collapse-hide\n", "\n", "args.load_saved=True\n", "args.epochs=70\n", "args.labeled=100\n", "args.features=32\n", "\n", "gan = ImprovedGAN(Generator(z_dim=100), Discriminator(output_dim=args.features), MNISTTriplet(MNISTLabel(args.labeled/10)), MNISTUnlabel(), MNISTTest())\n", "gan.train()\n", "# gan.trainknn()\n", "print(gan.evalknn(results))\n", "show_gen_images(gan)" ] }, { "cell_type": "markdown", "metadata": { "id": "NF8mUdp4XcYy" }, "source": [ "### Results" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "TeLdoPVbXcYy", "tags": [] }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": "Features Supervised samples Accuracy mAP\n---------- -------------------- ---------- --------\n 16 100 96.1731 0.915565\n 16 200 96.4448 0.931573\n 32 100 95.9793 0.903113\n" } ], "source": [ "#collapse-hide\n", "from tabulate import tabulate\n", "\n", "results = [[16, 100, 96.17309999999998, 0.9155653700006265],\n", " [16, 200, 96.44481666666665, 0.931573062438397],\n", " [32, 100, 95.97931666666668, 0.9031129058147167]]\n", "print(tabulate(results,headers=['Features','Supervised samples','Accuracy','mAP']))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": { "083137ae2e4e4df89372b3f7312b0df6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "108cbf55deb34806be80a9c19e50ec5f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_451b542c4b364e2fb5d063a3a8c80510", "placeholder": "​", "style": "IPY_MODEL_fdb3cbac77c045259da0e546a4f343ea", "value": " 8192/? [00:00<00:00, 18073.84it/s]" } }, "127dd236058b49479d99e5a2491683cc": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "15e2ed7ab25f4fa091d4911962c67539": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_1b8f808ffab0449482d1b09c4d7c14ea", "placeholder": "​", "style": "IPY_MODEL_2b39760025b541ba9a9b60927d23ea19", "value": " 70/70 [1:04:14<00:00, 55.07s/it]" } }, "19f81b3816dc4f73a7a8f07ee05f3015": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": 
"LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "1b8f808ffab0449482d1b09c4d7c14ea": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, 
"order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "1e8467600e7a46faa63233df862f00b1": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "20b7c0530f3840ea9a862ea39f6eae03": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_cf79fd4a1bfc45bca317205cfe18f5f8", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_083137ae2e4e4df89372b3f7312b0df6", "value": 1 } }, 
"230d03b4d8d14ae895a790c6f5d4df3a": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "244e2280b23a4ae8b34a8153b56b0647": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "2b39760025b541ba9a9b60927d23ea19": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", 
"description_width": "" } }, "2cdba883945b4638b308203cdcc249de": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_73787f08d8d04b198006b130b4e1cbd4", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_127dd236058b49479d99e5a2491683cc", "value": 1 } }, "33277bfc9ae6499b9e308a522e40d56a": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "350763325d50438881e8ca475fed4f87": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", 
"state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "39a2d4a4e54a4363840f80db86620b37": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "3bb696c3138f40a5867b3ffb531bd3dd": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "info", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_73067c77fa614cf1bab5bcb68ae4f388", 
"max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_c7c3c828e8ce4056b7dc3e72560379c0", "value": 1 } }, "3bcf6dad2c1c4fe59964c4f2abc66b54": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_5a888895301a41ba84b4b6659e6e60ce", "placeholder": "​", "style": "IPY_MODEL_244e2280b23a4ae8b34a8153b56b0647", "value": " 70/70 [1:03:33<00:00, 54.47s/it]" } }, "43630c9c1bac43e9ad394abdf93a6d9d": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "451b542c4b364e2fb5d063a3a8c80510": { "model_module": "@jupyter-widgets/base", "model_module_version": 
"1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "4c2a62cad3e745c48c44758757c3b5ae": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "4d253d23fd9c4a158c72ba3e1b67ea18": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": 
"IPY_MODEL_85fd1071d78f4562b80f502c11f81f7d", "placeholder": "​", "style": "IPY_MODEL_8a43ec739f9746cbbf83c2fc32e29dff", "value": " 9920512/? [00:20<00:00, 1137523.28it/s]" } }, "50ea4da8a1a143ad8e5d7f115d8a48ac": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_a5c62cb711074c1284015ce876350855", "IPY_MODEL_108cbf55deb34806be80a9c19e50ec5f" ], "layout": "IPY_MODEL_cb7d6f6d5ade44c1b06b5476f09a87b6" } }, "5809f284371a491bb50597eb6642df24": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5a888895301a41ba84b4b6659e6e60ce": { "model_module": "@jupyter-widgets/base", 
"model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5f777d764f764d7d865fcb865d954160": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "607f34bdf0b14b1a91caf6d9844bad4e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "65f5dda52bd146f0b2f78f96c8c2ff79": { "model_module": 
"@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6f409f71dd4449f39d15276b855ad1fb": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_93ba603c2e0146539d0154a6e5198170", "IPY_MODEL_3bcf6dad2c1c4fe59964c4f2abc66b54" ], "layout": "IPY_MODEL_bdebf5da52914d3499569f23b8351ad5" } }, "73067c77fa614cf1bab5bcb68ae4f388": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", 
"_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "73787f08d8d04b198006b130b4e1cbd4": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "78d3878825004a538a00f7a3937d796a": { 
"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "796ea965633e48659f231a6a824c0bd6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_43630c9c1bac43e9ad394abdf93a6d9d", "placeholder": "​", "style": "IPY_MODEL_dca919422dde4dbbae20de762b3a38b0", "value": " 70/70 [1:12:05<00:00, 61.79s/it]" } }, "7d4525322a514a09a22cabb40b0da51a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": 
"ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "84ab96b85f8c4d1b9a83ec346ccd64d0": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "85fd1071d78f4562b80f502c11f81f7d": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, 
"grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "87db4a7ff49c449b92f64b2c04fbe707": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_ecadb2ef3d7d44f189b0d9c4285c477b", "IPY_MODEL_15e2ed7ab25f4fa091d4911962c67539" ], "layout": "IPY_MODEL_65f5dda52bd146f0b2f78f96c8c2ff79" } }, "8a43ec739f9746cbbf83c2fc32e29dff": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "8f966c3006734a2aa86cf9960a3f3fad": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_3bb696c3138f40a5867b3ffb531bd3dd", 
"IPY_MODEL_4d253d23fd9c4a158c72ba3e1b67ea18" ], "layout": "IPY_MODEL_1e8467600e7a46faa63233df862f00b1" } }, "93ba603c2e0146539d0154a6e5198170": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "100%", "description_tooltip": null, "layout": "IPY_MODEL_b8f6540af8624f298f182a953f239daa", "max": 70, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_f2a105a8e0a64dc9bf149434a943ef43", "value": 70 } }, "98b310c7caa046aba9a0980ac3408b96": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "100%", "description_tooltip": null, "layout": "IPY_MODEL_f22799009ed94f1a81522d40e38832a2", "max": 70, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_350763325d50438881e8ca475fed4f87", "value": 70 } }, "a5c62cb711074c1284015ce876350855": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": 
"IPY_MODEL_19f81b3816dc4f73a7a8f07ee05f3015", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_7d4525322a514a09a22cabb40b0da51a", "value": 1 } }, "b8f6540af8624f298f182a953f239daa": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bdebf5da52914d3499569f23b8351ad5": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, 
"grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c4eeacff33494712900584e14e2c4b26": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_98b310c7caa046aba9a0980ac3408b96", "IPY_MODEL_796ea965633e48659f231a6a824c0bd6" ], "layout": "IPY_MODEL_84ab96b85f8c4d1b9a83ec346ccd64d0" } }, "c7c3c828e8ce4056b7dc3e72560379c0": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "caa5754b2ab24b759ea04e0f5635016c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": 
"IPY_MODEL_33277bfc9ae6499b9e308a522e40d56a", "placeholder": "​", "style": "IPY_MODEL_5f777d764f764d7d865fcb865d954160", "value": " 32768/? [00:01<00:00, 32486.74it/s]" } }, "cb7d6f6d5ade44c1b06b5476f09a87b6": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "cdd6461270d6459c8d0110dbd66a0dac": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_230d03b4d8d14ae895a790c6f5d4df3a", "placeholder": "​", "style": "IPY_MODEL_607f34bdf0b14b1a91caf6d9844bad4e", "value": " 1654784/? 
[00:00<00:00, 2238998.56it/s]" } }, "cf79fd4a1bfc45bca317205cfe18f5f8": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "da991ed9f2cd44ffa15613ca4eccd0f0": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_2cdba883945b4638b308203cdcc249de", "IPY_MODEL_cdd6461270d6459c8d0110dbd66a0dac" ], "layout": "IPY_MODEL_78d3878825004a538a00f7a3937d796a" } }, "dca919422dde4dbbae20de762b3a38b0": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", 
"_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ec7c7b1b5bdf45ae96b43644309a7e94": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_20b7c0530f3840ea9a862ea39f6eae03", "IPY_MODEL_caa5754b2ab24b759ea04e0f5635016c" ], "layout": "IPY_MODEL_39a2d4a4e54a4363840f80db86620b37" } }, "ecadb2ef3d7d44f189b0d9c4285c477b": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "100%", "description_tooltip": null, "layout": "IPY_MODEL_5809f284371a491bb50597eb6642df24", "max": 70, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_4c2a62cad3e745c48c44758757c3b5ae", "value": 70 } }, "f22799009ed94f1a81522d40e38832a2": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, 
"flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "f2a105a8e0a64dc9bf149434a943ef43": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "initial" } }, "fdb3cbac77c045259da0e546a4f343ea": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } }, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 4 }