{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "$$\n", "\\newcommand{\\mat}[1]{\\boldsymbol {#1}}\n", "\\newcommand{\\mattr}[1]{\\boldsymbol {#1}^\\top}\n", "\\newcommand{\\matinv}[1]{\\boldsymbol {#1}^{-1}}\n", "\\newcommand{\\vec}[1]{\\boldsymbol {#1}}\n", "\\newcommand{\\vectr}[1]{\\boldsymbol {#1}^\\top}\n", "\\newcommand{\\rvar}[1]{\\mathrm {#1}}\n", "\\newcommand{\\rvec}[1]{\\boldsymbol{\\mathrm{#1}}}\n", "\\newcommand{\\diag}{\\mathop{\\mathrm {diag}}}\n", "\\newcommand{\\set}[1]{\\mathbb {#1}}\n", "\\newcommand{\\norm}[1]{\\left\\lVert#1\\right\\rVert}\n", "\\newcommand{\\pderiv}[2]{\\frac{\\partial #1}{\\partial #2}}\n", "\\newcommand{\\bb}[1]{\\boldsymbol{#1}}\n", "\\newcommand{\\Tr}[0]{^\\top}\n", "\\newcommand{\\softmax}[1]{\\mathrm{softmax}\\left({#1}\\right)}\n", "$$\n", "\n", "# CS236781: Deep Learning\n", "# Tutorial 11: Variational AutoEncoders\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## Introduction\n", "\n", "In this tutorial, we will cover:\n", "\n", "- Discriminative Vs. 
Generative\n", "- Starting simple: KDE and GMM\n", "\n", "- VAE\n", " - KL divergence\n", " - reparameterization trick\n", " - VAE loss\n", " - code example\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2022-03-24T07:25:15.707910Z", "iopub.status.busy": "2022-03-24T07:25:15.707373Z", "iopub.status.idle": "2022-03-24T07:25:16.993157Z", "shell.execute_reply": "2022-03-24T07:25:16.992767Z" }, "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# Setup\n", "%matplotlib inline\n", "import numpy as np\n", "import pandas as pd\n", "import os\n", "import sys\n", "import time\n", "import matplotlib.pyplot as plt\n", "from mpl_toolkits.mplot3d import Axes3D\n", "import seaborn as sns; sns.set()\n", "\n", "\n", "\n", "# pytorch\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torchvision\n", "from torchvision import datasets, transforms\n", "from torch.utils.data import DataLoader,Dataset \n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", " \n", "# sklearn imports\n", "from sklearn import mixture\n", "from sklearn.manifold import TSNE" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-03-24T07:25:16.995346Z", "iopub.status.busy": "2022-03-24T07:25:16.995233Z", "iopub.status.idle": "2022-03-24T07:25:17.008877Z", "shell.execute_reply": "2022-03-24T07:25:17.008540Z" }, "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "plt.rcParams['font.size'] = 20\n", "data_dir = os.path.expanduser('~/.pytorch-datasets')\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Discriminative Vs. 
Generative" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "Recall our probabilistic notation:\n", "\n", "- Domain: $\\vec{x}^i \\in \\set{R}^D$\n", "- Target: $y^i \\in \\set{Y}$; for classification, typically a finite set of classes.\n", "\n", "* When did we solve a regression problem in the course?\n", "* What was the target space there?" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**Discriminative models** are most of the models we have seen in the course so far: they learn $P(Y|X)$ directly.\n", "Training them requires labels, so we work in a supervised learning setup.\n", "\n", "**Generative models** learn $P(X)$, either explicitly (today) or implicitly (next week). If the target space is known, they can also learn $P(X|Y)$.\n", "\n", "They serve two purposes:\n", "\n", "1. Classification via Bayes' rule: $P(Y|X) = \\frac{P(X|Y) P(Y)}{P(X)}$. Since $P(X)$ does not depend on $Y$, it is enough to pick the $y$ that maximizes $P(X|Y=y) P(Y=y)$, and thus we can classify the example!\n", "\n", "2. Generation: we can sample from the learned distribution $P(X)$ or $P(X|Y)$ and generate new instances!\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "
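Here is a minimal sketch of purpose 1, classifying with a generative model. It assumes toy 1-D data with two classes and models each class-conditional density $P(X|Y)$ as a Gaussian. It estimates $P(Y)$ from class frequencies and predicts the class that maximizes $P(X|Y)P(Y)$. The data and the helper names (fit_generative, log_gaussian, predict) are illustrative assumptions, not course code:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 1-D data (assumed for illustration): 300 points of class 0, 100 of class 1
x0 = rng.normal(loc=-2.0, scale=1.0, size=300)
x1 = rng.normal(loc=3.0, scale=1.5, size=100)
X = np.concatenate([x0, x1])
Y = np.concatenate([np.zeros(300, dtype=int), np.ones(100, dtype=int)])


def fit_generative(X, Y):
    # For every class c, estimate the prior P(Y=c) and a Gaussian P(X|Y=c)
    params = []
    for c in np.unique(Y):
        xc = X[Y == c]
        params.append((len(xc) / len(X), xc.mean(), xc.std()))
    return params


def log_gaussian(x, mean, std):
    # log N(x; mean, std^2)
    return -0.5 * np.log(2 * np.pi * std ** 2) - (x - mean) ** 2 / (2 * std ** 2)


def predict(x, params):
    # argmax_c [log P(X|Y=c) + log P(Y=c)]; P(X) is shared by all classes, so it is dropped
    scores = np.stack([log_gaussian(x, m, s) + np.log(prior) for prior, m, s in params], axis=-1)
    return scores.argmax(axis=-1)


params = fit_generative(X, Y)
preds = predict(np.array([-2.0, 0.0, 3.0]), params)
train_acc = (predict(X, params) == Y).mean()
```

Working in log space avoids underflow when multiplying small densities. Because $P(X)$ is the same for every class, dropping it does not change the argmax.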
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## KDE" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Before we dive into the deep (learning/water), let's see how we might build a generative model for low-dimensional data.\n", "\n", "We have $N$ samples $x_i \\in \\set{R}$ (1-D in the example below), where each $x_i \\sim D$, and we want to estimate $D$.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "def make_data(N, f=0.3, rseed=1):\n", " rand = np.random.RandomState(rseed)\n", " x = rand.randn(N)\n", " # the first f*N samples stay around 0; the rest are shifted to be around 5\n", " x[int(f * N):] += 5\n", " return x\n", "\n", "x = make_data(1000)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "The simplest (though crude) way to estimate $D$ is a histogram of the samples.\n", "\n", "Problem:\n", "\n", "* When the values are continuous, almost every value is unique, so a bar per distinct value would give bars of height one.\n", " * Instead, we choose the number of bins in advance,\n", " * and assign each data point to the bin it falls into." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.0" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# density=True normalizes the bar heights; check that the total area is 1\n", "hist = plt.hist(x, bins=30, density=True)\n", "density, bins, patches = hist\n", "widths = bins[1:] - bins[:-1]\n", "(density * widths).sum()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "* With few data points, the estimate can differ substantially depending on where the bin edges fall. Below, the same 20 points are binned with two different offsets:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x = make_data(20)\n", "bins = np.linspace(-5, 10, 10)\n", "fig, ax = plt.subplots(1, 2, figsize=(12, 4),\n", " sharex=True, sharey=True,\n", " subplot_kw={'xlim':(-4, 9),\n", " 'ylim':(-0.02, 0.3)})\n", "fig.subplots_adjust(wspace=0.05)\n", "# same data, bin edges shifted by 0.0 (left) and 0.6 (right)\n", "for i, offset in enumerate([0.0, 0.6]):\n", " ax[i].hist(x, bins=bins + offset, density=True)\n", " # mark the individual data points below the axis\n", " ax[i].plot(x, np.full_like(x, -0.01), '|k',\n", " markeredgewidth=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To avoid depending on where the bins happen to be placed, instead of stacking blocks into fixed bins we can center a block (of small width) on each point and sum the blocks:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x_d = np.linspace(-4, 8, 2000)\n", "# each point contributes a block of height 1 and width 1 centered on it (not normalized)\n", "density = sum((abs(xi - x_d) < 0.5) for xi in x)\n", "\n", "plt.fill_between(x_d, density, alpha=0.5)\n", "plt.plot(x, np.full_like(x, -0.1), '|k', markeredgewidth=1)\n", "\n", "plt.axis([-4, 8, -0.2, 8]);" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "This is starting to look more like a density function, yet we still need to choose the block width,\n", "\n", "and, more importantly, the estimate is still not smooth (it is piecewise constant).\n", "\n", "Instead, we can place a Gaussian around each point and sum them. This is kernel density estimation (KDE) with a Gaussian kernel, where the bandwidth (the width of each Gaussian) replaces the bin width:\n", "\n", " \n", "\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "(-0.02, 0.22)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.neighbors import KernelDensity\n", "\n", "# instantiate and fit the KDE model\n", "# (sklearn expects a 2-D array of shape (n_samples, n_features), hence x[:, None])\n", "kde = KernelDensity(bandwidth=1.0, kernel='gaussian')\n", "kde.fit(x[:, None])\n", "\n", "# score_samples returns the log of the probability density\n", "logprob = kde.score_samples(x_d[:, None])\n", "\n", "plt.fill_between(x_d, np.exp(logprob), alpha=0.5)\n", "plt.plot(x, np.full_like(x, -0.01), '|k', markeredgewidth=1)\n", "plt.ylim(-0.02, 0.22)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Another option is a Gaussian Mixture Model (GMM), which fits a small number of Gaussians to the data instead of placing one on every point:\n", "\n", "
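A GMM models the density as a weighted sum of $K$ Gaussians, $p(x) = \\sum_{k=1}^{K} \\pi_k \\mathcal{N}(x; \\mu_k, \\Sigma_k)$, where the mixture weights $\\pi_k$ sum to 1. Sampling from it takes two steps (ancestral sampling). First, pick a component $k$ with probability $\\pi_k$. Then draw $x$ from that component's Gaussian. Below is a minimal NumPy sketch. The 2-D parameters are hand-picked illustrative assumptions, not fitted values:

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative parameters (assumed, not fitted): K = 2 components in 2-D
weights = np.array([0.5, 0.5])                      # pi_k, must sum to 1
means = np.array([[20.0, -20.0], [0.0, 0.0]])       # mu_k
covs = np.array([[[1.0, 0.0], [0.0, 1.0]],          # Sigma_k, symmetric positive definite
                 [[1.0, 0.5], [0.5, 2.0]]])


def sample_gmm(n, weights, means, covs, rng):
    # Step 1: pick a component index for every sample, with probabilities pi_k
    ks = rng.choice(len(weights), size=n, p=weights)
    # Step 2: draw each sample from the Gaussian of its chosen component
    xs = np.array([rng.multivariate_normal(means[k], covs[k]) for k in ks])
    return xs, ks


def gmm_density(x, weights, means, covs):
    # p(x) = sum_k pi_k * N(x; mu_k, Sigma_k) for a batch x of shape (n, d)
    d = x.shape[1]
    total = np.zeros(len(x))
    for w, m, c in zip(weights, means, covs):
        diff = x - m
        maha = np.einsum('ni,ij,nj->n', diff, np.linalg.inv(c), diff)  # squared Mahalanobis distance
        norm = np.sqrt((2 * np.pi) ** d * np.linalg.det(c))
        total += w * np.exp(-0.5 * maha) / norm
    return total


samples, ks = sample_gmm(1000, weights, means, covs, rng)
dens = gmm_density(samples, weights, means, covs)
```

In practice, sklearn's GaussianMixture (used below) fits the weights, means and covariances with EM. Its sample method draws new points in this same two-step way.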
\n", "\n", "\n", "While fitting a GMM is an unsupervised algorithm, with labels we can adapt it to model $P(Y|X)$: for each new data point, we decide which distribution it most likely came from." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "If you are interested, you can [watch a lecture](https://www.youtube.com/watch?v=XLKoTqGao7U&ab_channel=VictorLavrenko) on the EM algorithm, which is used to fit GMMs.\n", "\n", "To experiment with code, run the cell below:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from matplotlib.colors import LogNorm\n", "\n", "n_samples = 500\n", "\n", "# generate random sample, two components\n", "np.random.seed(0)\n", "\n", "# generate spherical data around (20, -20)\n", "shifted_gaussian = np.random.randn(n_samples, 2) + np.array([20, -20])\n", "\n", "# generate zero centered stretched Gaussian data\n", "C = np.array([[-1.0, -0.7], [3.5, 0.7]])\n", "stretched_gaussian = np.dot(np.random.randn(n_samples, 2), C)\n", "\n", "# concatenate the two datasets into the final training set\n", "X_train = np.vstack([shifted_gaussian, stretched_gaussian])\n", "\n", "# fit a Gaussian Mixture Model with two components\n", "clf = mixture.GaussianMixture(n_components=2, covariance_type=\"full\")\n", "clf.fit(X_train)\n", "\n", "\n", "# display predicted scores by the model as a contour plot\n", "x = np.linspace(-50.0, 50.0)\n", "y = np.linspace(-50.0, 50.0)\n", "X, Y = np.meshgrid(x, y)\n", "XX = np.array([X.ravel(), Y.ravel()]).T\n", "# score_samples returns log p(x) on the 50x50 grid; Z is the negative log-likelihood\n", "Z = -clf.score_samples(XX)\n", "Z = Z.reshape(X.shape)\n", "\n", "plt.figure(figsize=(15, 8))\n", "CS = plt.contour(\n", " X, Y, Z, norm=LogNorm(vmin=1.0, vmax=1000.0), levels=np.logspace(0, 3, 10)\n", ")\n", "CB = plt.colorbar(CS, shrink=0.8)\n", "plt.scatter(X_train[:, 0], X_train[:, 1], 0.8)\n", "\n", "\n", "plt.title(\"Negative log-likelihood predicted by a GMM\")\n", "plt.axis(\"tight\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also sample points from the learned $P(X)$ (no labels involved). Here we approximate it by sampling points of the plotting grid, each with probability proportional to its density:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "samled point:[-3.06122449 -1.02040816] with prob:0.12\n", "samled point:[ 19.3877551 
-19.3877551] with prob:0.24\n", "samled point:[ 21.42857143 -21.42857143] with prob:0.04\n", "samled point:[ 19.3877551 -19.3877551] with prob:0.24\n", "samled point:[ 19.3877551 -19.3877551] with prob:0.24\n", "samled 
point:[-1.02040816 -1.02040816] with prob:0.05\n", "samled point:[1.02040816 1.02040816] with prob:0.05\n", "samled point:[ 19.3877551 -21.42857143] with prob:0.11\n", "samled point:[1.02040816 1.02040816] with prob:0.05\n", "samled point:[ 19.3877551 -19.3877551] with prob:0.24\n" ] } ], "source": [ "\n", "# Z holds -log p(x) on the grid, so exp(-Z) is the density at each grid point\n", "flat = np.exp(-1*Z.flatten())\n", "# normalize over the grid points to get a discrete distribution\n", "flat /= np.sum(flat)\n", "for _ in range(10):\n", " idx = np.random.choice(a=flat.size, p=flat)\n", " print(f\"samled point:{XX[idx]} with prob:{flat[idx]:.2f}\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Variational Autoencoder (VAE)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "The idea is very similar to the autoencoder we have already seen.\n", "\n", "Like it, a VAE is trained in an unsupervised fashion, meaning we don't need labels.\n", "\n", "The idea is to use an encoder that maps the input to a latent space whose distribution we constrain to a known prior, and a decoder that maps latent vectors back to the image space, reconstructing the input as faithfully as possible." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "
\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "* **Reminder:**\n", " * Our goal is to learn **distributions**. \n", " * Similarly to GANs, the trick is to model the desired distribution by learning a deterministic function $F_\\theta$ which maps samples from a simple distribution $Z$ (e.g., a standard normal) to the target distribution $X$ (a generalization of the inverse CDF trick).\n", " * Then, to generate new samples, one can sample $Z \\sim \\mathcal{N}(0,I)$ and compute $F_\\theta(Z)$.\n", " * The true conditional distribution is given by $F_{\\theta^*}(z) = p_{\\theta^*}(X | Z=z)$ (assumed to be deterministic).\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Unlike GANs, which learn the distribution through a minimax game, in VAEs we directly maximize the log-likelihood of the given data.\n", "Namely, $LL = \\frac{1}{N} \\sum_i \\log p_{\\theta}(X_i)$ where each $X_i$ is sampled from the **true** distribution $p_{\\theta^*}(X_i)$.\n", "\n", "\n", "* **How can we do this?**\n", " * Note that $p_{\\theta}(X)$ can be marginalized as follows: $p_{\\theta}(X) = \\int p_\\theta(X|Z=z)f_Z(z) dz$.\n", " * **Idea 1:** sample many $z$ from the prior $p_\\theta(z)$ and average $F_\\theta(z) = p_\\theta(X|z)$ to approximate $p_{\\theta}(x)$.\n", " * **Problem:** for most $z$, $p_\\theta(x|z)$ is close to zero, so we need too many samples...\n", " * **Idea 2:** if we knew the posterior $p_\\theta(z|x)$, it would be useful...\n", " * We could exploit algorithms such as Expectation-Maximization (EM).\n", " * **Problem:** the posterior $p_\\theta(z|x) = \\frac{p_\\theta(x|z) p_\\theta(z)}{p_\\theta(x)}$ is intractable in our case, since the decoder is a NN and $p_\\theta(x)$ is exactly the integral we cannot compute." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "* **Instead**: we use the *variational inference (VI)* technique:\n", " 1. 
Approximate $p_\\theta(z|x)$ by a model $q_\\phi(z|x)$.\n", " 2. Use this approximation to lower bound $\\log p_{\\theta}(x)$ (ELBO loss).\n", " 3. 
Update $\\theta,\\phi$ based on this lower bound.\n", " * Q: How to fit a good $q_\\phi(z|x)$?\n", " * A: By minimizing $D_{KL}(q_\\phi(z|x) || p_\\theta(z|x))$!" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Kullback-Leibler divergence" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "The Kullback-Leibler divergence measures how much one distribution differs from another. It is not a true metric (it is not symmetric and does not satisfy the triangle inequality), hence a quasi-distance:\n", "\n", "$$ D_{KL}(P || Q) = \\sum_{x \\in \\mathcal{X}} P(x) \\log(\\frac{P(x)}{Q(x)}) $$\n", "\n", "(for continuous distributions, the sum becomes an integral over densities). In other words, it is the expectation, under $P$, of the logarithmic difference between the probabilities $P$ and $Q$.\n", "\n", "* Some properties of the KL divergence:\n", " * In general, $D_{KL}(P || Q) \\ne D_{KL}(Q || P)$.\n", " * $D_{KL}(P || Q) \\geq 0$ and $D_{KL}(P || Q) = 0$ if and only if $P=Q$ almost everywhere.\n", " * $D_{KL}(P || Q) = \\mathbb{E}_P[\\log(\\frac{P}{Q})]$.\n", "\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Bounding the marginal log-likelihood" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "
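The derivation below relies on the KL properties listed above. As a quick sanity check, here is a minimal sketch of the discrete KL formula applied to two hand-picked distributions. The values of P and Q are arbitrary illustrative choices:

```python
import numpy as np


def kl_divergence(p, q):
    # D_KL(P || Q) = sum_x P(x) log(P(x) / Q(x)) for discrete distributions.
    # Terms with P(x) = 0 contribute 0; we assume Q(x) > 0 wherever P(x) > 0.
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


P = np.array([0.7, 0.2, 0.1])
Q = np.array([0.4, 0.4, 0.2])

kl_pq = kl_divergence(P, Q)  # D_KL(P || Q)
kl_qp = kl_divergence(Q, P)  # D_KL(Q || P): a different value, KL is not symmetric
kl_pp = kl_divergence(P, P)  # zero, since the distributions are identical
```

Both directions are positive but differ here (about 0.184 vs. 0.192 nats), which illustrates the first two properties. The sum is an expectation under P, matching the third.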
\n", "By Bayes' theorem:\n", "
\n", "$$ \\log p_\\theta(x) = - \\log p_\\theta(z|x) + \\log p_\\theta(x,z) $$\n", "
\n", "By adding and subtracting $\\log q_\\phi(z|x)$ from the RHS:\n", "
\n", "$$ \\log p_\\theta(x) = \\log( \\frac{q_\\phi(z|x)}{p_\\theta(z|x)} ) + [\\log p_\\theta(x,z) - \\log q_\\phi(z|x)] $$\n", "
\n", "Since the above holds for every $z$, and the LHS does not depend on $z$, we can take expectations w.r.t. $z \\sim q_\\phi(z|x)$:\n", "
\n", "$$ \\log p_\\theta(x) = \\mathbb{E}_{q_\\phi(z|x)}[\\log( \\frac{q_\\phi(z|x)}{p_\\theta(z|x)} )] +\n", "\\mathbb{E}_{q_\\phi(z|x)}[\\log p_\\theta(x,z) - \\log q_\\phi(z|x)] $$\n", "
\n", "By using the third property of KL-divergence:\n", "
\n", "$$ \\log p_\\theta(x) = D_{KL}(q_\\phi(z|x) || p_\\theta(z|x)) +\n", "\\mathbb{E}_{q_\\phi(z|x)}[\\log p_\\theta(x,z) - \\log q_\\phi(z|x)] $$\n", "
\n", "\n", "* Look carefully at what we have:\n", " * The log-likelihood is on the LHS.\n", " * The RHS is composed of two terms, where the first term is the KL divergence we wanted to minimize to fit a good encoder.\n", " * The RHS depends on both $\\theta,\\phi$ whereas the LHS depends solely on $\\theta$.\n", " * So for a fixed $\\theta$...\n", " * maximizing the second term w.r.t. $\\phi$...\n", " * is equivalent to minimizing the KL divergence w.r.t. $\\phi$!\n", " " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Let us refer to the second term as the ELBO (evidence lower bound); its negative is used as the loss:\n", "$$ \\mathcal{L}(\\theta,\\phi;x) = \\mathbb{E}_{q_\\phi(z|x)}[\\log p_\\theta(x,z) - \\log q_\\phi(z|x)]$$\n", "It is called ELBO because it lower-bounds the log-likelihood (recall that the KL divergence is always non-negative):\n", "$$ \\log p_\\theta(x) \\geq \\mathcal{L}(\\theta,\\phi;x) $$\n", "
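To see the bound in action, here is a minimal sketch on a toy model where everything is tractable. We assume p(z) = N(0, 1) and p(x|z) = N(z, 1); this is our own choice for illustration, not the VAE itself. Under these assumptions, p(x) = N(0, 2) and the true posterior is N(x/2, 1/2). We can therefore compare the ELBO of a Gaussian q(z|x) = N(m, s2) against the exact log p(x). The function names are hypothetical helpers:

```python
import numpy as np

# Toy model (assumed for illustration): p(z) = N(0, 1), p(x|z) = N(z, 1).
# Then p(x) = N(0, 2) and the true posterior is p(z|x) = N(x/2, 1/2).


def log_normal(v, mean, var):
    # log N(v; mean, var)
    return -0.5 * np.log(2 * np.pi * var) - (v - mean) ** 2 / (2 * var)


def log_evidence(x):
    # exact log p(x) for the toy model
    return log_normal(x, 0.0, 2.0)


def elbo(x, m, s2):
    # Closed-form ELBO for q(z|x) = N(m, s2):
    # E_q[log p(z)] + E_q[log p(x|z)] + entropy of q
    e_log_prior = -0.5 * np.log(2 * np.pi) - 0.5 * (m ** 2 + s2)
    e_log_lik = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - m) ** 2 + s2)
    entropy = 0.5 * np.log(2 * np.pi * np.e * s2)
    return e_log_prior + e_log_lik + entropy


def elbo_mc(x, m, s2, n=200000, seed=0):
    # Monte Carlo estimate of E_q[log p(x, z) - log q(z|x)] using samples z ~ q
    z = np.random.default_rng(seed).normal(m, np.sqrt(s2), size=n)
    return np.mean(log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0) - log_normal(z, m, s2))


x = 1.5
log_px = log_evidence(x)
elbo_true_post = elbo(x, x / 2, 0.5)  # q equals the true posterior: the bound is tight
elbo_worse_q = elbo(x, 0.0, 1.0)      # q equals the prior: strictly below log p(x)
elbo_sampled = elbo_mc(x, 0.0, 1.0)   # sampling-based estimate of elbo_worse_q
```

The gap between log p(x) and the ELBO is exactly the KL divergence between q and the true posterior, so it vanishes only when q matches the posterior. The function elbo_mc shows the sampling-based estimate, which is the kind of estimate used when no closed form is available.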
\n", "Note that, since $p_\\theta(x,z) = p_\\theta(z) p_\\theta(x|z)$:\n", "$$ \\mathcal{L}(\\theta,\\phi;x) = \\mathbb{E}_{q_\\phi(z|x)}[\\log p_\\theta(z) + \\log p_\\theta(x|z) - \\log q_\\phi(z|x)]$$\n", "
\n", "And so:\n", "$$ \\mathcal{L}(\\theta,\\phi;x) = -D_{KL}(q_\\phi(z|x) || p_\\theta(z)) + \\mathbb{E}_{q_\\phi(z|x)}[\\log p_\\theta(x|z)]$$\n", "\n", "* The first term is called the **latent space regularization** and the second term is the (negative) **reconstruction error**.\n", "* In VAEs we have:\n", " 1. $p_\\theta(z) = \\mathcal{N}(\\bar{0},\\mathbf{I})$.\n", " 2. $q_\\phi(z|x) = \\mathcal{N}(\\bar{\\mu}(x), \\diag(\\bar{\\sigma}^2(x)))$ where both $\\bar{\\mu}(x)$ and $\\bar{\\sigma}^2(x)$ are learned by an MLP. \n", " 3. The decoder $p_\\theta(x|z)$ is modeled as a Bernoulli distribution per pixel, parameterized by an MLP:\n", "$$ \n", "\\log p_\\theta(x|z) = x \\log g(z) + (1 - x) \\log(1 - g(z))\n", "$$\n", "\n", "where $g(\\cdot)$ is an MLP with outputs in $(0,1)$, and the above equality is interpreted pixelwise (the log-likelihood of the image is the sum over pixels)." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### The Reparameterization Trick\n", "---\n", "As you recall, in deep neural networks we use **backpropagation** of the gradients to update the weights. In the training process we need to **sample** $z$'s and forward them through the decoder, and they are sampled from $\\mathcal{N}(\\mu(X), \\Sigma(X))$. 
So normally, code-wise it would look something like this: `z = torch.normal(mu_x, sigma_x)` or `z = np.random.normal(mu_x, sigma_x)`.\n", "\n", "* What is the problem with that operation?\n", " * The sampling operation **does not have a gradient!** So we cannot update the encoder with respect to the loss function!\n", "* Solution - **The Reparameterization Trick**:\n", " * It makes the network differentiable!\n", " * The trick is as follows:\n", " * Recall that if $x \\sim \\mathcal{N}(\\mu, \\Sigma)$ and $x_{std}$ is its standardized version, i.e. $x_{std} \\sim \\mathcal{N}(0, \\mathbf{I})$, then you can revert back to the original distribution by: $x = \\mu +\\Sigma^{\\frac{1}{2}} x_{std}$.\n", " * In our case, let $\\epsilon \\sim \\mathcal{N}(0,\\mathbf{I})$: $$ z = \\mu(X) + \\Sigma(X)^{\\frac{1}{2}}\\epsilon $$\n", " * The randomness now comes only from $\\epsilon$, which does not depend on the network parameters, so we can take the derivative w.r.t. $\\mu(X), \\Sigma(X)$ and backpropagate it through the network!" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "# reparametrization trick\n", "def reparameterize(mu, logvar, device=torch.device(\"cpu\")):\n", " \"\"\"\n", " This function applies the reparameterization trick:\n", " z = mu(X) + Sigma(X)^0.5 * epsilon, where epsilon ~ N(0,I)\n", " :param mu: mean of q(z|x)\n", " :param logvar: log-variance of q(z|x)\n", " :param device: kept for compatibility; epsilon is created on the same device as mu\n", " :return z: the sampled latent variable\n", " \"\"\"\n", " std = torch.exp(0.5 * logvar)  # sigma = exp(0.5 * log(sigma^2))\n", " eps = torch.randn_like(std)  # same shape, dtype and device as std\n", " return mu + eps * std" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### The VAE loss\n", "\n", "So far we have focused on the KL term over the latent variable $z$.\n", "\n", "\n", "Eventually, we want to reconstruct the input image $X$ from this representation.\n", "\n", "The VAE loss (the negative ELBO) is:\n", "\n", "$L(\\phi,\\theta,x) = R(\\hat{X}_{\\theta},X) + D_{KL}(q_{\\phi}(z|X) || p_\\theta(z))$\n", "\n", "where 
the first term is the reconstruction loss between the input and the reconstructed image $\\hat{X}_{\\theta}$, which the decoder infers from the latent vector $z$.\n", "\n", "For images with pixel values in $[0,1]$, we can simply use pixel-wise BCE (binary cross-entropy), which is exactly the negative of the Bernoulli log-likelihood above." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def loss_function(recon_x, x, mu, logvar, loss_type='bce'):\n", " \"\"\"\n", " This function calculates the loss of the VAE.\n", " loss = reconstruction_loss - 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n", " :param recon_x: the reconstruction from the decoder\n", " :param x: the original input\n", " :param mu: the mean given X, from the encoder\n", " :param logvar: the log-variance given X, from the encoder\n", " :param loss_type: type of loss function - 'mse', 'l1', 'bce'\n", " :return: VAE loss\n", " \"\"\"\n", " if loss_type == 'mse':\n", " recon_error = F.mse_loss(recon_x, x, reduction='sum')\n", " elif loss_type == 'l1':\n", " recon_error = F.l1_loss(recon_x, x, reduction='sum')\n", " elif loss_type == 'bce':\n", " recon_error = F.binary_cross_entropy(recon_x, x, reduction='sum')\n", " else:\n", " raise NotImplementedError\n", "\n", " # see Appendix B from VAE paper:\n", " # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014\n", " # https://arxiv.org/abs/1312.6114\n", " # -KL(q(z|x) || N(0, I)) = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n", " kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n", " return (recon_error + kl) / x.size(0) # normalize by batch_size" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "The encoder: we simply use a 2-layer MLP.\n", "\n", "The encoder takes the high-dimensional data, $X \\in \\mathcal{R}^D$, and encodes it into a lower-dimensional latent vector, $z \\in \\mathcal{R}^d$; that is, we model $q_{\\phi}(z|X)$. 
\n", "\n", "\n", "\n", "Since we are in a *variational* environment, and we model a distrubution $q_{\\phi}$, the outputs of the encoder are the mean, $\\mu(X) \\in \\mathcal{R}^d$ and the co-variance, $\\Sigma(X) \\in \\mathcal{R}^d$. \n", " * Remember that since we assume independce between the latent variables, the co-variance matrix is diagonal and we can represent it as a vector in $\\mathcal{R}^d$, where each value represents the variance (the $ii^{th}$ element in the co-variance matrix)." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# encoder - q_{\\phi}(z|X)\n", "class VaeEncoder(torch.nn.Module):\n", " \"\"\"\n", " This class builds the encoder for the VAE\n", " :param x_dim: input dimensions\n", " :param hidden_size: hidden layer size\n", " :param z_dim: latent dimensions\n", " :param device: cpu or gpu\n", " \"\"\"\n", "\n", " def __init__(self, x_dim=28*28, hidden_size=256, z_dim=10, device=torch.device(\"cpu\")):\n", " super(VaeEncoder, self).__init__()\n", " self.x_dim = x_dim\n", " self.hidden_size = hidden_size\n", " self.z_dim = z_dim\n", " self.device = device\n", " \n", " self.features = nn.Sequential(nn.Linear(x_dim, self.hidden_size),\n", " nn.ReLU())\n", " \n", " self.fc1 = nn.Linear(self.hidden_size, self.z_dim, bias=True) # fully-connected to output mu\n", " self.fc2 = nn.Linear(self.hidden_size, self.z_dim, bias=True) # fully-connected to output logvar\n", "\n", "\n", " def bottleneck(self, h):\n", " \"\"\"\n", " This function takes features from the encoder and outputs mu, log-var and a latent space vector z\n", " :param h: features from the encoder\n", " :return: z, mu, log-variance\n", " \"\"\"\n", " mu, logvar = self.fc1(h), self.fc2(h)\n", " # use the reparametrization trick as torch.normal(mu, logvar.exp()) is not differentiable\n", " z = reparameterize(mu, logvar, device=self.device)\n", " return z, mu, logvar\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " This is the function 
called when doing the forward pass:\n", " z, mu, logvar = VaeEncoder(X)\n", " \"\"\"\n", " h = self.features(x)\n", " z, mu, logvar = self.bottleneck(h)\n", " return z, mu, logvar" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The decoder takes a lower-dimensional latent vector, $z \\in \\mathcal{R}^d$, and decodes it into high-dimensional *reconstructed* data, $\\tilde{X} \\in \\mathcal{R}^D$; that is, we model $p_{\\theta}(X|z)$. " ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# decoder - p_{\\theta}(x|z)\n", "class VaeDecoder(torch.nn.Module):\n", " \"\"\"\n", " This class builds the decoder for the VAE\n", " :param x_dim: input dimensions\n", " :param hidden_size: hidden layer size\n", " :param z_dim: latent dimensions\n", " \"\"\"\n", "\n", " def __init__(self, x_dim=28*28, hidden_size=256, z_dim=10):\n", " super(VaeDecoder, self).__init__()\n", " self.x_dim = x_dim\n", " self.hidden_size = hidden_size\n", " self.z_dim = z_dim\n", " \n", " self.decoder = nn.Sequential(nn.Linear(self.z_dim, self.hidden_size),\n", " nn.ReLU(),\n", " nn.Linear(self.hidden_size, self.x_dim),\n", " nn.Sigmoid())\n", " # Sigmoid maps the outputs to [0, 1], the same range as the input images (ToTensor scales pixels to [0, 1]).\n", " # Alternatively, you can remove it and normalize with a function, as in tutorial 3.\n", "\n", " def forward(self, z):\n", " \"\"\"\n", " This is the function called when doing the forward pass:\n", " x_reconstruction = VaeDecoder(z)\n", " \"\"\"\n", " x = self.decoder(z)\n", " return x" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "class Vae(torch.nn.Module):\n", " def __init__(self, x_dim=28*28, z_dim=10, hidden_size=256, device=torch.device(\"cpu\")):\n", " super(Vae, self).__init__()\n", " self.device = device\n", " self.z_dim = z_dim\n", "\n", " self.encoder = VaeEncoder(x_dim, hidden_size, z_dim=z_dim, device=device)\n", " self.decoder = VaeDecoder(x_dim, hidden_size, z_dim=z_dim)\n", 
"\n", " def 
encode(self, x):\n", " return self.encoder(x)\n", "\n", " def decode(self, z):\n", " return self.decoder(z)\n", "\n", " def sample(self, num_samples=1):\n", " \"\"\"\n", " This function generates new data by sampling latent variables from the prior and decoding them.\n", " Vae.sample() actually generates new data!\n", " Sample z ~ N(0,I)\n", " \"\"\"\n", " z = torch.randn(num_samples, self.z_dim).to(self.device)\n", " return self.decode(z)\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " This is the function called when doing the forward pass:\n", " return x_recon, mu, logvar, z = Vae(X)\n", " \"\"\"\n", " z, mu, logvar = self.encode(x)\n", " x_recon = self.decode(z)\n", " return x_recon, mu, logvar, z" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# define hyper-parameters\n", "BATCH_SIZE = 128 # usually 32/64/128/256\n", "LEARNING_RATE = 1e-3 # for the gradient optimizer\n", "NUM_EPOCHS = 15 # how many epochs to run? (alternatively 150)\n", "HIDDEN_SIZE = 256 # size of the hidden layers in the networks\n", "X_DIM = 28 * 28 # size of the input dimension\n", "Z_DIM = 10 # size of the latent dimension" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Due to computational constraints, we won't generate large images; instead, we go back to MNIST." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading 
http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data\\MNIST\\raw\\train-labels-idx1-ubyte.gz\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8a8021bc618a436ba8859e3a16348dda", "version_major": 2, "version_minor": 0 }, "text/plain": [ "0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting data\\MNIST\\raw\\train-labels-idx1-ubyte.gz to data\\MNIST\\raw\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data\\MNIST\\raw\\t10k-images-idx3-ubyte.gz\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0c03076fa55b4531aef26476776e81b4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting data\\MNIST\\raw\\t10k-images-idx3-ubyte.gz to data\\MNIST\\raw\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "689bff12828a4be2acac368d4c5cea51", "version_major": 2, "version_minor": 0 }, "text/plain": [ "0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting data\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz to data\\MNIST\\raw\n", "Processing...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\moshe\\anaconda3\\envs\\cs3600-tut\\lib\\site-packages\\torchvision\\datasets\\mnist.py:469: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. 
This type of warning will be suppressed for the rest of this program. (Triggered internally at ..\\torch\\csrc\\utils\\tensor_numpy.cpp:141.)\n", " return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Done!\n", "Image batch dimensions: torch.Size([128, 1, 28, 28])\n", "Image label dimensions: torch.Size([128])\n" ] } ], "source": [ "transform = torchvision.transforms.ToTensor()\n", "\n", "train_dataset = datasets.MNIST(root='data', \n", " train=True, \n", " transform=transform,\n", " download=True)\n", "\n", "test_dataset = datasets.MNIST(root='data', \n", " train=False, \n", " transform=transform)\n", "\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=BATCH_SIZE, \n", " shuffle=True)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " batch_size=BATCH_SIZE, \n", " shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = plt.figure(figsize=(8 ,5))\n", "samples, labels = next(iter(train_loader))\n", "for i in range(6):\n", " ax = fig.add_subplot(2, 3, i + 1)\n", " ax.imshow(samples[i][0, :, :].data.cpu().numpy(), cmap='gray')\n", " title = \"digit: \" + str(labels[i].data.cpu().item())\n", " ax.set_title(title)\n", " ax.set_axis_off()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0 training loss: 171.57769 epoch time: 10.520 sec\n", "epoch: 1 training loss: 128.90011 epoch time: 11.004 sec\n", "epoch: 2 training loss: 122.67897 epoch time: 11.892 sec\n", "epoch: 3 training loss: 119.56529 epoch time: 13.966 sec\n", "epoch: 4 training loss: 117.53008 epoch time: 13.197 sec\n", "epoch: 5 training loss: 116.11026 epoch time: 13.531 sec\n", "epoch: 6 training loss: 114.98391 epoch time: 13.474 sec\n", "epoch: 7 training loss: 114.12774 epoch time: 14.079 sec\n", "epoch: 8 training loss: 113.46270 epoch time: 12.615 sec\n", "epoch: 9 training loss: 112.87070 epoch time: 12.998 sec\n", "epoch: 10 training loss: 112.36000 epoch time: 12.948 sec\n", "epoch: 11 training loss: 111.94424 epoch time: 12.560 sec\n", "epoch: 12 training loss: 111.51269 epoch time: 12.668 sec\n", "epoch: 13 training loss: 111.15255 epoch time: 12.881 sec\n", "epoch: 14 training loss: 110.82297 epoch time: 13.313 sec\n" ] }, { "data": { "text/plain": [ "'\\n# save\\nfname = \"./vae_mnist_\" + str(NUM_EPOCHS) + \"_epochs.pth\"\\ntorch.save(vae.state_dict(), fname)\\nprint(\"saved checkpoint @\", fname)\\n\\n# load\\nvae = Vae(x_dim=X_DIM, z_dim=Z_DIM, hidden_size=HIDDEN_SIZE, device=device).to(device)\\nvae.load_state_dict(torch.load(fname))\\nprint(\"loaded checkpoint from\", fname)\\n'" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# create our model\n", "vae = Vae(x_dim=X_DIM, z_dim=Z_DIM, 
hidden_size=HIDDEN_SIZE, device=device).to(device)\n", "\n", "# optimizer \n", "vae_optim = torch.optim.Adam(params=vae.parameters(), lr=LEARNING_RATE)\n", "\n", "# save the mean loss of each epoch, we might want to plot them later\n", "train_losses = []\n", "\n", "for epoch in range(NUM_EPOCHS):\n", " epoch_start_time = time.time()\n", " batch_losses = []\n", " for batch_i, batch in enumerate(train_loader):\n", " # forward pass\n", " x = batch[0].to(device).view(-1, X_DIM) # just the images\n", " x_recon, mu, logvar, z = vae(x)\n", " # calculate the loss\n", " loss = loss_function(x_recon, x, mu, logvar, loss_type='bce')\n", " # optimization (same 3 steps every time)\n", " vae_optim.zero_grad()\n", " loss.backward()\n", " vae_optim.step()\n", " # save loss\n", " batch_losses.append(loss.data.cpu().item())\n", " train_losses.append(np.mean(batch_losses))\n", " print(\"epoch: {} training loss: {:.5f} epoch time: {:.3f} sec\".format(epoch, train_losses[-1],\n", " time.time() - epoch_start_time))\n", " \n", " \n", "\"\"\"\n", "# save\n", "fname = \"./vae_mnist_\" + str(NUM_EPOCHS) + \"_epochs.pth\"\n", "torch.save(vae.state_dict(), fname)\n", "print(\"saved checkpoint @\", fname)\n", "\n", "# load\n", "vae = Vae(x_dim=X_DIM, z_dim=Z_DIM, hidden_size=HIDDEN_SIZE, device=device).to(device)\n", "vae.load_state_dict(torch.load(fname))\n", "print(\"loaded checkpoint from\", fname)\n", "\"\"\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can sample from the model." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# now let's sample from the vae\n", "n_samples = 6\n", "vae_samples = vae.sample(num_samples=n_samples).view(n_samples, 28, 28).data.cpu().numpy()\n", "fig = plt.figure(figsize=(8 ,5))\n", "for i in range(vae_samples.shape[0]):\n", " ax = fig.add_subplot(2, 3, i + 1)\n", " ax.imshow(vae_samples[i], cmap='gray')\n", " ax.set_axis_off()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Investigate $z$ with interpolation\n", "\n", "Given the latent codes $z_1$ and $z_2$ of two images, we can create:\n", "$$ z_{new} = \\alpha z_1 + (1-\\alpha) z_2, \\quad \\alpha \\in [0,1] $$ \n", "\n", "and decode $z_{new}$ to see the transition between the two images." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# interpolation weights, from 0 (pure z_2) to 1 (pure z_1)\n", "alphas = np.linspace(0, 1, 10)\n", "# take 2 samples\n", "sample_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=True, drop_last=True)\n", "it = iter(sample_dataloader)\n", "samples, labels = next(it)\n", "while labels[0] == labels[1]:\n", " # make sure they are different digits\n", " samples, labels = next(it)\n", "x_1, x_2 = samples\n", "\n", "# get their latent representation\n", "_,_, _, z_1 = vae(x_1.view(-1, X_DIM).to(device))\n", "_,_, _, z_2 = vae(x_2.view(-1, X_DIM).to(device))\n", "\n", "# let's see the result\n", "fig = plt.figure(figsize=(15 ,8))\n", "for i, alpha in enumerate(alphas):\n", " z_new = alpha * z_1 + (1 - alpha) * z_2\n", " x_new = vae.decode(z_new)\n", " ax = fig.add_subplot(1, 10, i + 1)\n", " ax.imshow(x_new.view(28, 28).cpu().data.numpy(), cmap='gray')\n", " ax.set_axis_off()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see what we've learned: we use t-SNE to reduce the latent codes to 2D and check whether they cluster by digit :)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 't-SNE of VAE Latent Space on MNIST')" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# take 2000 samples\n", "num_samples = 2000\n", "sample_dataloader = DataLoader(train_dataset, batch_size=num_samples, shuffle=True, drop_last=True)\n", "samples, labels = next(iter(sample_dataloader))\n", "\n", "labels = labels.data.cpu().numpy()\n", "# encode the samples to get their latent codes z\n", "_,_, _, z = vae(samples.view(num_samples, X_DIM).to(device))\n", "\n", "# t-SNE\n", "perplexity = 15.0\n", "t_sne = TSNE(n_components=2, perplexity=perplexity)\n", "z_embedded = t_sne.fit_transform(z.data.cpu().numpy())\n", "\n", "# plot\n", "fig = plt.figure(figsize=(10 ,8))\n", "ax = fig.add_subplot(1, 1, 1)\n", "for i in np.unique(labels):\n", " ax.scatter(z_embedded[labels==i,0], z_embedded[labels==i, 1], label=str(i))\n", "ax.legend()\n", "ax.grid()\n", "ax.set_title(\"t-SNE of VAE Latent Space on MNIST\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### Thanks" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "skip" } }, "source": [ "**Credits**\n", "\n", "This tutorial was written by [Moshe Kimhi](https://www.linkedin.com/in/moshekimhi/)
\n", "To re-use, please provide attribution and link to the original.\n", "\n", "\n", "Some code snippets were adapted from:
\n", "[kernel-density-estimation](https://jakevdp.github.io/PythonDataScienceHandbook/05.13-kernel-density-estimation.html)\n", "\n", "Image sources:\n", "https://developers.google.com/machine-learning/gan/generator
\n", "https://towardsdatascience.com/comprehensive-introduction-to-turing-learning-and-gans-part-2-fd8e4a70775" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" }, "rise": { "scroll": true }, "toc-autonumbering": false }, "nbformat": 4, "nbformat_minor": 4 }