{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Maximizing the ELBO\n", "\n", "> In this post, we will cover the complete implementation of a Variational AutoEncoder (VAE), trained by maximizing the ELBO objective function. This is a summary of the lecture \"Probabilistic Deep Learning with Tensorflow 2\" from Imperial College London.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Coursera, Tensorflow_probability, ICL]\n", "- image: images/fashion_mnist_generated.png" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow_probability as tfp\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from matplotlib.patches import Ellipse\n", "from IPython.display import HTML, Image\n", "\n", "tfd = tfp.distributions\n", "tfpl = tfp.layers\n", "tfb = tfp.bijectors\n", "\n", "plt.rcParams['figure.figsize'] = (10, 6)\n", "plt.rcParams[\"animation.html\"] = \"jshtml\"\n", "plt.rcParams['animation.embed_limit'] = 2**128" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensorflow Version: 2.5.0\n", "Tensorflow Probability Version: 0.13.0\n" ] } ], "source": [ "print(\"Tensorflow Version: \", tf.__version__)\n", "print(\"Tensorflow Probability Version: \", tfp.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prior Distribution\n", "\n", "$ \\text{latent variable } z \\sim N(0, I) = p(z) \\\\\n", " p(x \\vert z) = \\text{decoder}(z) \\\\\n", " x \\sim p(x \\vert z) $" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Approximating True Posterior distribution\n", "\n", "$ \\text{encoder }(x) = q(z \\vert x) \\simeq p(z \\vert x) \\\\\n", " \\begin{aligned} \\log p(x) & \\ge \\mathbb{E}_{z \\sim q(z \\vert x)}[-\\log q(z \\vert x) + \\log p(x, z)] \\quad \\leftarrow \\text{maximizing this lower bound} \\\\\n", " &= - \\mathrm{KL} (q(z \\vert x) \\vert \\vert p(z)) + \\mathbb{E}_{z \\sim q(z \\vert x)}[\\log p(x \\vert z)] \\quad \\leftarrow \\text{Evidence Lower Bound (ELBO)} \\end{aligned}$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sample Encoder Architecture" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "latent_size = 2\n", "event_shape = (28, 28, 1)\n", "\n", "encoder = Sequential([\n", "    Conv2D(8, (5, 5), strides=2, activation='tanh', input_shape=event_shape),\n", "    Conv2D(8, (5, 5), strides=2, activation='tanh'),\n", "    Flatten(),\n", "    Dense(64, activation='tanh'),\n", "    Dense(2 * latent_size),\n", "    tfpl.DistributionLambda(lambda t: tfd.MultivariateNormalDiag(\n", "        loc=t[..., :latent_size], scale_diag=tf.math.exp(t[..., latent_size:]))),\n", "], name='encoder')\n", "\n", "encoder(X_train[:16])\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sample Decoder Architecture" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The decoder is almost the reverse of the encoder.\n", "\n", "```python\n", "decoder = Sequential([\n", "    Dense(64, activation='tanh', input_shape=(latent_size, )),\n", "    Dense(128, activation='tanh'),\n", "    Reshape((4, 4, 8)),  # Reshape into the (height, width, channels) form required by the convolutional layers\n", "    Conv2DTranspose(8, (5, 5), strides=2, output_padding=1, activation='tanh'),\n", "    Conv2DTranspose(8, (5, 5), strides=2, output_padding=1, activation='tanh'),\n", "    Conv2D(1, (3, 3), padding='SAME'),\n", "    Flatten(),\n", "    tfpl.IndependentBernoulli(event_shape)\n", "], name='decoder')\n", "\n", "decoder(tf.random.normal([16, latent_size]))\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prior Distribution for zero-mean Gaussian with identity covariance matrix\n", 
"```python\n", "prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ELBO objective function\n", " One way to implement ELBO function is to use Analytical computation of KL divergence.\n", " \n", "```python\n", "def loss_fn(X_true, approx_posterior, X_pred, prior_dist):\n", " \"\"\"\n", " X_true: batch of data examples\n", " approx_posterior: the output of encoder\n", " X_pred: output of decoder\n", " prior_dist: Prior distribution\n", " \"\"\"\n", " return tf.reduce_mean(tfd.kl_divergence(approx_posterior, prior_dist) - X_pred.log_prob(X_true))\n", "```\n", "\n", "The other way is using Monte Carlo Sampling instead of analyticall with the KL Divergence.\n", "\n", "```python\n", "def loss_fn(X_true, approx_posterior, X_pred, prior_dist):\n", " reconstruction_loss = -X_pred.log_prob(X_true)\n", " approx_posterior_sample = approx_posterior.sample()\n", " kl_approx = (approx_posterior.log_prob(approx_posterior_sample) - prior_dist.log_prob(approx_posterior_sample))\n", " return tf.reduce_mean(kl_approx + reconstruction_loss)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Calculating Gradient of Loss function\n", "```python\n", "@tf.function\n", "def get_loss_and_grads(x):\n", " with tf.GradientTape() as tape:\n", " approx_posterior = encoder(x)\n", " approx_posterior_sample = approx_posterior.sample()\n", " X_pred = decoder(approx_posterior_sample)\n", " current_loss = loss_fn(x, approx_posterior, X_pred, prior)\n", " grads = tape.gradient(current_loss, encoder.trainable_variables + decoder.trainable_variables)\n", " return current_loss, grads\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training Loop\n", "\n", "```python\n", "optimizer = tf.keras.optimizers.Adam()\n", "for epoch in range(num_epochs):\n", " for train_batch in train_data:\n", " loss, grads = get_loss_and_grads(train_batch)\n", " 
optimizer.apply_gradients(zip(grads, encoder.trainable_variables + decoder.trainable_variables))\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test\n", "```python\n", "z = prior.sample(1) # (1, 2)\n", "x = decoder(z).sample() # (1, 28, 28, 1)\n", "\n", "X_encoded = encoder(X_sample)\n", "\n", "def vae(inputs):\n", " approx_posterior = encoder(inputs)\n", " decoded = decoder(approx_posterior.sample())\n", " return decoded.sample()\n", "\n", "reconstruction = vae(X_sample)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tutorial" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Review of terminology:\n", "- $p(z)$ = prior\n", "- $q(z|x)$ = encoding distribution\n", "- $p(x|z)$ = decoding distribution\n", "\n", "$$\n", "\\begin{aligned}\n", "\\log p(x) &\\geq \\mathrm{E}_{Z \\sim q(z | x)}\\big[−\\log q(Z | x) + \\log p(x, Z)\\big]\\\\\n", " &= - \\mathrm{KL}\\big[ \\ q(z | x) \\ || \\ p(z) \\ \\big] + \\mathrm{E}_{Z \\sim q(z | x)}\\big[\\log p(x | Z)\\big] \n", "\\end{aligned}\n", "$$" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.models import Sequential, Model\n", "from tensorflow.keras.layers import Dense, Flatten, Reshape" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Import Fashion MNIST, make it a Tensorflow Dataset\n", "\n", "(X_train, _), (X_test, _) = tf.keras.datasets.fashion_mnist.load_data()\n", "X_train = X_train.astype('float32') / 255.\n", "X_test = X_test.astype('float32') / 255.\n", "example_X = X_test[:16]\n", "\n", "batch_size = 64\n", "X_train = tf.data.Dataset.from_tensor_slices(X_train).batch(batch_size)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From 
/home/chanseok/anaconda3/envs/torch/lib/python3.7/site-packages/tensorflow_probability/python/distributions/distribution.py:346: calling MultivariateNormalDiag.__init__ (from tensorflow_probability.python.distributions.mvn_diag) with scale_identity_multiplier is deprecated and will be removed after 2020-01-01.\n", "Instructions for updating:\n", "`scale_identity_multiplier` is deprecated; please combine it into `scale_diag` directly instead.\n" ] } ], "source": [ "# Define the encoding distribution, q(z | x)\n", "\n", "latent_size = 2\n", "event_shape = (28, 28)\n", "\n", "encoder = Sequential([\n", "    Flatten(input_shape=event_shape),\n", "    Dense(256, activation='relu'),\n", "    Dense(128, activation='relu'),\n", "    Dense(64, activation='relu'),\n", "    Dense(32, activation='relu'),\n", "    Dense(2 * latent_size),\n", "    tfpl.DistributionLambda(\n", "        lambda t: tfd.MultivariateNormalDiag(\n", "            loc=t[..., :latent_size],\n", "            scale_diag=tf.math.exp(t[..., latent_size:])\n", "        )\n", "    )\n", "])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Pass an example image through the network - should return a batch of MultivariateNormalDiag\n", "\n", "encoder(example_X)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Define the decoding distribution, p(x | z)\n", "\n", "decoder = Sequential([\n", "    Dense(32, activation='relu'),\n", "    Dense(64, activation='relu'),\n", "    Dense(128, activation='relu'),\n", "    Dense(256, activation='relu'),\n", "    Dense(tfpl.IndependentBernoulli.params_size(event_shape)),\n", "    tfpl.IndependentBernoulli(event_shape)\n", "])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Pass a batch of examples to the decoder\n", "\n", "decoder(tf.random.normal([16, latent_size]))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Define the prior, p(z) - a standard bivariate Gaussian\n", "prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The loss function we need to estimate is\n", "$$\n", "-\\mathrm{ELBO} = \\mathrm{KL}[ \\ q(z|x) \\ || \\ p(z) \\ ] - \\mathrm{E}_{Z \\sim q(z|x)}[\\log p(x|Z)]\n", "$$\n", "where $x = (x_1, x_2, \\ldots, x_n)$ refers to all observations and $z = (z_1, z_2, \\ldots, z_n)$ refers to the corresponding latent variables.\n", "\n", "Assuming the examples are independent, we can write this as a sum over examples:\n", "$$\n", "\\sum_j \\Big( \\mathrm{KL}[ \\ q(z_j|x_j) \\ || \\ p(z_j) \\ ] - \\mathrm{E}_{Z_j \\sim q(z_j|x_j)}[\\log p(x_j|Z_j)] \\Big)\n", "$$" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Specify the loss function, an estimate of the -ELBO\n", "\n", "def loss(x, encoding_dist, sampled_decoding_dist, prior):\n", "    return tf.reduce_sum(\n", "        tfd.kl_divergence(encoding_dist, prior) - sampled_decoding_dist.log_prob(x)\n", "    )" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Define a function that returns the loss and its gradients\n", "\n", "@tf.function\n", "def get_loss_and_grads(x):\n", "    with tf.GradientTape() as tape:\n", "        encoding_dist = encoder(x)\n", "        sampled_z = encoding_dist.sample()\n", "        sampled_decoding_dist = decoder(sampled_z)\n", "        current_loss = loss(x, encoding_dist, sampled_decoding_dist, prior)\n", "    grads = tape.gradient(current_loss, encoder.trainable_variables + decoder.trainable_variables)\n", "    return current_loss, grads" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-ELBO after epoch 1: 8990\n", "-ELBO after epoch 2: 8858\n", 
"-ELBO after epoch 3: 8782\n", "-ELBO after epoch 4: 8820\n", "-ELBO after epoch 5: 8716\n", "-ELBO after epoch 6: 8664\n", "-ELBO after epoch 7: 8727\n", "-ELBO after epoch 8: 8667\n", "-ELBO after epoch 9: 8810\n", "-ELBO after epoch 10: 8675\n" ] } ], "source": [ "# Compile and train the model\n", "num_epochs = 10\n", "optimizer = tf.keras.optimizers.Adam()\n", "\n", "for i in range(num_epochs):\n", " for train_batch in X_train:\n", " current_loss, grads = get_loss_and_grads(train_batch)\n", " optimizer.apply_gradients(zip(grads, encoder.trainable_variables + decoder.trainable_variables))\n", " \n", " print('-ELBO after epoch {}: {:.0f}'.format(i + 1, current_loss.numpy()))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Connect encoder and decoder, compute a reconstruction\n", "\n", "def vae(inputs):\n", " approx_posterior = encoder(inputs)\n", " decoding_dist = decoder(approx_posterior.sample())\n", " return decoding_dist.sample()\n", "\n", "example_reconstruction = vae(example_X).numpy().squeeze()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot examples against reconstructions\n", "\n", "f, axs = plt.subplots(2, 6, figsize=(16, 5))\n", "\n", "for j in range(6):\n", " axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')\n", " axs[1, j].imshow(example_reconstruction[j, :, :], cmap='binary')\n", " axs[0, j].axis('off')\n", " axs[1, j].axis('off')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the model has lack of reconstruction from grayscale image, So using mean for reconstruction gets more satisfied results." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Connect encoder and decoder, compute a reconstruction with mean\n", "\n", "def vae_mean(inputs):\n", " approx_posterior = encoder(inputs)\n", " decoding_dist = decoder(approx_posterior.sample())\n", " return decoding_dist.mean()\n", "\n", "example_reconstruction = vae_mean(example_X).numpy().squeeze()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot examples against reconstructions\n", "\n", "f, axs = plt.subplots(2, 6, figsize=(16, 5))\n", "\n", "for j in range(6):\n", " axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')\n", " axs[1, j].imshow(example_reconstruction[j, :, :], cmap='binary')\n", " axs[0, j].axis('off')\n", " axs[1, j].axis('off')" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Generate an example - sample a z value, then sample a reconstruction from p(x|z)\n", "\n", "z = prior.sample(6)\n", "generated_x = decoder(z).sample()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA4sAAACNCAYAAAAeou/jAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIXUlEQVR4nO3d0XIiOQwFUNjK///y7FNqdhEEocjdtvuc14GGSmsMKnPl+58/f24AAADwX/+c/QYAAACYj2YRAACAQLMIAABAoFkEAAAg0CwCAAAQfL35d6NSr+1+wGuosTfu93gbNppirMZO8FhTz+op85hFqLEXnq0tj7ru+5GvdQI1lpSpg2dG1uEidafGGO1ljdlZBAAAINAsAgAAEGgWAQAACDSLAAAABPc3wV5h12sTqE7qHN5g+Ei7Jf84lUEQmVp59bjK+1mk7tTYrT5YJCNbd6NeawJq7Da2xp55rIWu11djXJQBNwAAAORpFgEAAAg0iwAAAARfZ78BWFElG9GZ+cpkNSbNXZBUyeNUMzubH5jOrS9X2JlPHJU5Y23ZGrtYvh9OY2cRAACAQLMIAABAoFkEAAAg0CwCAAAQGHADmxLsX1vX8JEuhijxG2plf13DjzoHdVWGKKlV+D87iwAAAASaRQAAAALNIgAAAIHMIhRkchCZ3EPXIdQyFmvJ5EnPvqeVrI9c4zy6Mq+P1+nMk1XfE3Pq+jx7ploHsvvwe3YWAQAACDSLAAAABJpFAAAAAs0iAAAAwf1N2FcS+NrGpdX/2qLGqocBdx1ivHBoX429MHIww8UGK6mxF0YOJKlYpJ6eUWMvHFlj1c/TRepOjTHayxqzswgAAECgWQQAACDQLAIAABB8nf0GYFeZzNnIg6odRryOkQedj7zvamxes+UR2V/X51nneth1bbgyO4sAAAAEmkUAAAACzSIAAACBZhEAAIDAgBtoUBk48+p5lQB+9trMKVMHRw9vUD9rO3Kwx8h1LPNazKFaB9X73nVt4Gd2FgEAAAg0iwAAAASaRQAAAALNIgAAAIEBN3CQakA/M9DB0If9jbzHmdpUY7zStY5VB5aozXV01QG8Yo3oZ2cRAACAQLMIAABAoFkEAAAgkFmEQbpyFw6q3k8lD1itp2p+Q02tLVNjR2bD1CG3W2+erLJuymfvz/3rZ2cRAACAQLMIAABAoFkEAAAg0CwCAAAQGHADgzyGrB00zLeza8N
Ah/2NqrFsrWRer/IYtbq2zkFLmdrI/D9QU/AzO4sAAAAEmkUAAAACzSIAAADBJTKLfqPOGWY7FJu1VLI21bXOGrmfIzOK1jFut3pOdWSGW23OaeR98dnVz84iAAAAgWYRAACAQLMIAABAoFkEAAAguMSAG+iWCWePDO07aHgvXcOPqtdRK/xGV90ZRrK/zFpTXY+sdesYeR+qdeB71Gt2FgEAAAg0iwAAAASaRQAAAIJLZBY7f/8uK8btVrvHnXUo27OX6jpSrYOurJi1bm2ZOqjmf+QRr6eaXX18jHWFb5XaGJndv+p3fjuLAAAABJpFAAAAAs0iAAAAgWYRAACA4BIDbjKy4ftK2Paqgdir6xoaolauqetQ8676UYdrqXw2ZQZDZOuga7CJwTh7qdTF7ZYbosR+utaxjJGD5lavVTuLAAAABJpFAAAAAs0iAAAAgWYRAACA4OMBN10DXh51hj8rgfiR4dOjg61XCNse6ez6NURpf9XBIl0DQjLPs66sLbNGjBwmU60xNXU9nUOUjhwCdlVnf+c+cohbV8+Tfd6oQWXP/PRadhYBAAAINIsAAAAEmkUAAACCjzOLj2Y8MLfyG/Wjf8felSPyW/vxVs1zqZW9VA+zrlIvaxtVL9UD07tyPOqSb5kam/E76spGfp/u+h7eufZVXn/kd/euv9Gnz7GzCAAAQKBZBAAAINAsAgAAEGgWAQAACD4ecNN1IGVnuLNy7WrYdFTYfuQBpZnQt6EBnxk1vKGTezqv2YYudL0fNTePrjVq5GfcbP8P2J816ndG/v1GDsHpet8jB/x0PWbEwE47iwAAAASaRQAAAALNIgAAAMH9ze9Y/bj72o4IlCxZY5WsTWdmp5JHmjSrsX2NdWUcqo8ZadKaerR9jT0zWw51tvfT7JI1llH5bOo81D1j0pp6pMYmMCIPOJGXNWZnEQAAgECzCAAAQKBZBAAAINAsAgAAEHyd/QZGcOA8MzJ85JqqB+1m1rHZDjXfPPy/lK7hR9XP09lqk3kdWWPWI37jqvVjZxEAAIBAswgAAECgWQQAACDYMrN41d8Us55qVi1zHdaxS75LHfKt6xB1NcUrXTUG/MzOIgAAAIFmEQAAgECzCAAAQKBZBAAAINhywA10yhw03nXg9bNrH3kdzpEZdDRyUINa2U9mUNaRwz8MGuF2q691XcNsfFbC5+wsAgAAEGgWAQAACDSLAAAABJpFAAAAAgNu4I2RgyGqoX2B/LVl6qdriFJG5f08e566nMeR9fNM14AdNba2Sh1m77GhSXAMO4sAAAAEmkUAAAACzSIAAACBzCI06MrjVK8tx7OWs/NkFSu8R+bVmVVjL9X7PjIPCfxlZxEAAIBAswgAAECgWQQAACDQLAIAABAYcANvVAfVdA0x6bq2YP+8Vrg3K7xH5lFZo6xj+6nc95H3OPN5CvyfnUUAAAACzSIAAACBZhEAAIBAZhEaPMs8nH2IuRzGHKoZmUr9dNZhpX7kgeaVqY2R61j12upnL533s2uNBH5mZxEAAIBAswgAAECgWQQAACDQLAIAABAYcAMFmQOnK4955tnzKoMpBPvP0TXMJnOPRw5VytSvGptXdY068trqZz9dn4OV13rGEC74nJ1FAAAAAs0iAAAAgWYRAACAQGYR3sjkGToPs+46KFsOYx1ducZOR78e56uuGV3Pkyfbz8jsfGXdVE/wOTuLAAAABJpFAAAAAs0iAAAAgWYRAACA4C7sCwAAwCM7iwAAAASaRQAAAALNIgAAAIFmEQAAgECzCAAAQKBZBAAAIPgX9oQnRXOPjCgAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Display generated_x\n", "\n", "f, axs = plt.subplots(1, 6, figsize=(16, 5))\n", "for j in range(6):\n", " axs[j].imshow(generated_x[j, :, :].numpy().squeeze(), cmap='binary')\n", " axs[j].axis('off')" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# Generate an example - sample a z value, then sample a reconstruction from p(x|z)\n", "\n", "z = prior.sample(6)\n", "generated_x = decoder(z).mean()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA4sAAACNCAYAAAAeou/jAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAlXUlEQVR4nO3dW29V19XG8ZEeQsAYGxtjG4cz5hCUplHVQ6q0aaO2qtSbKu1VpfZz9HP0g/QilapIiar0oFSpmjQhJYSDccBgGzAnEyANPbxX1dv5jAfvwcbA3ub/u5uLuddeh7Hm2guvMeYT//nPfwIAAAAAgP/1uUe9AQAAAACA3sPDIgAAAAAg4WERAAAAAJDwsAgAAAAASHhYBAAAAAAkX+jw7z1XKlWrty4vL6c+r776atM+efJk6rN9+/aO33Xnzp20bN26dU379u3bqc/p06eb9vXr11OfX/3qV017YGCg4/Y8Ak88hO9YlRhzVX112aeffpr63Lx5My3T83XmzJnUZ2FhoWnPzc2lPuPj4017ZGQk9fn3v/+dlul2njt3rmOfw4cPpz4HDhxo2lu2bEl9Nm3alJY99dRTaZn63OdW7f+Z+ibGSl9k4vBf//pXWvbJJ580bY2niBxTMzMzqc8///nPpr1+/frUx50rjfsnnsinYWJiomnv27cv9dm1a1fTduPYF76QbzPu+x6gvokxFz86Rrj7krsPXbp0qWm7ceTatWsrtiMiNmzY0LSffPLJjtsYEXHr1q2m/dlnn6U+ukzHzIiIPXv2NG2Nywg/jul2uutAl91HXPZNjKFvEWP3oTLzg+tTGRMe8v3sQbrrjvCXRQAAAABAwsMiAAAAACDhYREAAAAAkPCwCAAAAABInuiQ9Nlzya6aSP/222+nPq+88krTvnr1auqjRSe0UESET+TX41VJiHXrfu2115r29773vdRnFYuIdKtnE6r1uLtjfPny5abtYuVvf/tbWnb8+PGmvbS0lPpoYSVXGOLEiRNN251PF2OuEI/SwiI7d+5MfbTow/T0dOrjipZoLE5NTaU+WgTnPmK1Z2OsQsejixcvpj7vvfdeWvbuu+827WPHjqU+s7OzTfvvf/976qOFTVyhD1f0RuPXjWMaG4cOHUp9NKaee+651OcrX/lKWjY5Odm0XRGcVdSzMabH3RVD0rHlyJEjqY+LMY2X+fn51OfChQtNW4vSROT7mRuzPv/5z6dlOo65ce2LX/xi0x4dHU19tKDNwYMHUx9X4OuFF15YcT0RuWBdpQjOXfRsjGHNIMaKKgUPq8VsKoVxuvn90+13VdfVJQrcAAAAAADqeFgEAAAAACQ8LAIA
AAAAkgeaKPIg6Du958+fT300j8jl7CjNXYjw7wHrut1Ew5rP5XLeNH/u5ZdfTn16IGexJ7j3uDVHUXNvIiJeffXVFdsREYuLi2mZxsvY2Fjqo7k1LldNt9vlAzkbN25s2lu3bu34/W5S7Bs3bjTtd955J/V566230rKTJ0827Z///OepzzPPPNO0XR7TGpqoNiJ8HGr8/PrXv059Pvroo7Rsbm6uabsJ0zXXWuMiImJ4eLhp6wTqET6fTHPF3PineY0ur1LHNp0IPiJiZmYmLfvBD37QtF3O2Vob/1z86HXr8vr++Mc/Nu3XX3899Tl79mxadv369abtxh+NA3ev1OvYXdcu11L7DQ4Opj537txp2i73+x//+EfTdvuxsLDQcdlPf/rT1EdzZ/V4ROTzttbGNaCfVPL63O8hHaO6HcccvVdVttHdl53KdmvO/4MYo9bW3RgAAAAAsCp4WAQAAAAAJDwsAgAAAAASHhYBAAAAAEnfFbjRZE9XPEILM7gJnysJqO5zWlhFE/QdN2G8TqLc7WScjwOXdPzJJ5807ePHj6c+WtDGrWfbtm1pmSYeu8IiAwMDTdvFgSYZDw0NpT4uEVuXuXVrIQYXY/r9rlCPKyjx4YcfNu033ngj9RkfH2/argjPA55o/aFz8fPmm2827d/97nepz5UrV9IyLcjkCpsodzw1Dtw44opwaYxViq+4/dDr0MWqFlqKyIVUduzYkfq462Wt0WPsCrVoTGkBqoiImzdvpmV6blzxFi3MoMXZIvI4cvv27dTHxa9+zo11Ota6Phq/en9363Gfm56eTn10THTHCEBv0/tX5XeV4+6VlWIx2sfdT3WZG7Mq92FHx203jt1v0Rv+sggAAAAASHhYBAAAAAAkPCwCAAAAAJK+SyrSd4pdrpqq5B66fCS3TN8fdn30fWH3rvDs7OyK2+PW87jQ97bdsdH8KZ24OiLntrgJpysTlrsJ5/W8b9q0qeO6Xay479ft1EmpI2oTlus2ue93OWb6fUePHk19dBL5zZs3pz66b/0+mbVOch4R8fvf/75pnzhxIvVx50pzzNx50HHLrUevDZdzUZn81+W86XXoxiP9vvn5+dTH5SzqNr388supj8Zvv8dPJZ/U3c8+/vjjpu1yR11s6PFy46jez9xYp9vochbdunVdLsZ13W4c1T6aixlRyxHS4xgRcevWrabtcjYBrD3uXlmpQeK431ZK87rd79FK7Qn3XfqbbXh4OPXR3xP3ej/lL4sAAAAAgISHRQAAAABAwsMiAAAAACDhYREAAAAAkPRdgRtN5HTFE5RLftck0eoEmapSPGLdunVpmSbSVibefFxUJqWfm5tr2q4Iiyb0uvNZKczgikdosQ+3ni1btjRtVxjCnXctFnPt2rWO2+j27eLFi03bxarb7g0bNjRtLcYUEfH222837d27d3fcxsq10kv0mGpRn4iII0eONG03Hrn40Zh2yeY6RrlzrOvWcxfhC9PoeV+txH5XNMDF7+nTp1dsR0Ts37+/afdbgRs9X5UCN8eOHUt9NO7cNetirDJRdeUep1yMOTr+ukJdOkZ0UygionZMXPEpLZbjCnXpMeq3OATWOr1G3TiyWkUt3Vij43ilGJ37LldgTD+nRbnuti6lY5sr/LkS/rIIAAAAAEh4WAQAAAAAJDwsAgAAAACSvstZ1Pd33UTZqvIessvncO8PV3J79F1gt27dj8ok648rlwe1sLDQtC9cuJD66HF373o7+m63y/mq5K1obk1l4mj3Obf/Gj8uxnTS10peXETE5cuXO36/Th7u8tJ0Yth+y1lUZ86cScsWFxebthsfKuOPywPrZkxw57jbfOzKZ3R/q9+lE8vPzMykPnrc+n2MdMdBxyQXY8vLyx3XXcl5dxPea9y5bew2R8/l9qhKXmMl18fFho5JZ8+eTX2Wlpaa9vbt2++6rQBqKvcXN65UPlf5HVXJR3T3apcPrdupec6uj7sP6rLq7zGtdeGe
eebn55u227eBgYGO27jSWN/fd18AAAAAwAPBwyIAAAAAIOFhEQAAAACQ8LAIAAAAAEj6rsCNJoBWkv9d0qYWr3GJ9VqgIyInzbtE0kohDy1IwES//0+PjZvMXgvcuAIHWphGC7dE+OItWvTBTV6qy1zStRaYccnT69at67jMFaaoFMHRmHb7cfPmzbTsqaeeatobN25MfbQwxPnz51MfLRbhCgX1Mo1DLeoTka//SsGiCF88S1WS/TXG3Dl2Y4teL+67tI8b6/Scun11dIy8ePFi6qOFTfotflSl0M/p06dTH722K0VgIvJ17OJA191N4aO7fa6yrkphHo0VV2Dixo0baZkWKnNjpI5jbnv6vTDXWtFtbFbw++vh63bMqHzOjYc6brpiMm7dev27+1ClCJcu0/E5olY8x91j9XNHjx5NffS3wr59+1If9/vhv/jLIgAAAAAg4WERAAAAAJDwsAgAAAAASHhYBAAAAAAkfV/gxiWta/EIV/xEE9ld0upLL72Ulv3mN7/puI26bpc0r8mmbhsfV5pk7ArTaEGD0dHRjn1cYrBb5talXGEapcnCrmCSi41K8SPtUynCoMU0InKhoIic+Dw+Pp76DA4ONm0tFBGRC6K4Y9bLhQX0GJ87dy710X1057NS4Mb10cJCleT/6jHW7XQFdyr7psvcfrjY1HW7Ajc6tuuY2W/c8dPrxhVqqdzzXNEbLaBw69at1EevbXf+dLtdEYRKbLh7bKV4jy5zRblc8SUtVOb2/9KlS03bFcbo98JK/UDHNjfWVYqGuGtDuXFEY7rb+1Kl+Irro9/XS78HK/ed1bqPd1vgRscaFyt6TF2sVO5V7tzo59w4ouOPK2bjxiilRTYjcoHB2dnZ1Ee/b3JyMvUZGhq66/f2TkQCAAAAAHoGD4sAAAAAgISHRQAAAABA0nc5i5q/4N4NVu4d50o+1be//e207LXXXmvaLsdD399270/rhMG9nLv1sFXy8TT/z+UZ6rvlbhJUl/+iMebyaDReKpNL3759Oy1z+TCaU+Hebdf4de+6V96jd8dWc+U0ViMixsbGmrbbD5dH1E/0unU5i3req/us8ePOjS5zuRKaa1PNw6iMNzpuunUPDAw07UrOmVv31atXUx+9XlbKp+hFuo9ujFheXl6x7VSvY5eHqvT+5eJXr20XB+4eV4lfXZf7fv2cu1e7z+m63VivubKV3xO4u0rOmbsONH7cvdLlNWsevotDvTZ27dqV+ujvicq1E5H3xY11ui+VHPbK75mHRc+pu3dU8jK7+a7quvW8uxjTMcKNGZX8fhdjuk1urDl79mzTdrmHjuaxuzFKa09oLnZExJEjR5r297///dSHnEUAAAAAwD3hYREAAAAAkPCwCAAAAABIeFgEAAAAACR9V+BGE1BdYY1KQq4mqboiHl/72tfSMk2ud5MoV4pHaIGQXpqE9VGrTGa/Z8+epu2SlTXp1yUUu4ISlfOn58slPet+VGI1wif3K/d9SvfXTabtJmadmppq2q4ggBbB2b17d+rjCgr1k8rE8RoHLvncXdtaGMYVeKkk7ev3uWPu4k6LLLjYqMSY7ptOhB7hk/017i9fvpz66Ng6MTHRcXt6mRujdB8rxbTc+LBp06a0TM+f+36Nn24LLVUKO7l16+e6nQC8MlG328alpaWO29htsY7HQSV+9Pi5+7COfx9++GHqc+LEibRMx40tW7akPlq0w33/vn37mrb7PejiTtelk6NH1CZj37p1a9P+xje+kfqstd+IlUJHrp/ro8tcESGNQ1fEyMWv3hvdeHD9+vWm7caaxcXFpq0FbyJ8ET0t0OTWfeHChabtCtXMzc01bVcEZ+/evWnZf62t6AMAAAAArAoeFgEAAAAACQ+LAAAAAICk73IWNSfH5cjou+Xu3WTN+zh48GDq43JkNNfQ5droO9buXfcvfelLTXutvY9+P/R4uXyc/fv3N23NoYvI74i7iatPnTqVls3Pzzdtl5eq
cbd+/frUR3U74bPLQ6vkI+oy1+f5559Py5599tmmfeDAgdRHrw3NuYjwuXL9RCf1dvmtet26CcvdJOKa9+DyEPRzLldDrxWXB+KWKXdt6DIX47qN7rtc/ojm5rlrTPOYXK5IJb+4V7jt1/wll7Oon3PH2N0/uslHdPmlOv5Uj7med3cfruS86r3axao7bppHNDIykvpo3o67DiuTgvd7HFb20Y1tGi/uPOh1/NFHH6U+H3zwQdN2Y63L9dN40RzUiJzPdeXKldTn5MmTTdvlPjq6bx9//HHqo9eBy5ncuXNn0z58+HDq4/IoHwaN7WrOcKc+LlbcGKXf5+6DGptuPNTPuf1wY5T+bnJ9dF3unqdcjLucxTNnzjRt93tY783uN4fW/tDfx53whAIAAAAASHhYBAAAAAAkPCwCAAAAABIeFgEAAAAASd8VuKkkxGuyqUta1yR5neQ9wk9sqQVujh8/3nEbXUECTaDupwT5B02PnyvwMjo62rRdHGgfdz5d0ZeFhYWO26ifc9uoceeSt93nKvGjk7q7whT6/a7Aw7e+9a20bHp6umnv2LEj9dEk78HBwdRH96PfYlwnU3aToSu3j65YgkuS77QuF6uatF8pdOLW7YoG6BhZmTDZxYFLttft1GMd0f8FbnR73fGrFELQY+wKR7kCUzomuHFEj59bd2VSaBebOrZVCi1VCjS5GNNiNhF5u9026jVducZ7STeFRVwc6hih116En3BeC3K4AjN6bbtzrHHn7tWO+z6lBW5crGixFfd7cNu2bWmZHjdX4EbHejceah8Xh4+qwE03XIzpMlcwycWGHmM3/uj9xN0HdZkrlOXuJ5V161jnxnXtMzk5mfq44kd6nNzvMd3up59+OvXR4+/G0ZXwl0UAAAAAQMLDIgAAAAAg4WERAAAAAJD0Xc6ivtuuE01G1HJtdBLLqampjt8Vkd8Fdu8vVyYjd+8Uw3Pvlm/cuLFpuzwAPQ8uv8Pl8bncFqXv0bt31F2umKrk8bj1aM6ie/9f495dB9u3b0/L9F36Sq5nZaLaflOZMF25c+Um0dVxw8WmrsvFeCUfycWz7ksld7aS6+TGYzceViaD11yjSn5WL9Htdbk2GmPuGOt63HjorlHNFauMa27dOka569rtmy5z36/jRmUcdduotQQi8uTr7vrVuHO5Yrr/7vsfhsoYUcn5dJPbv/fee03b1WLQ6zEi4vLly03bXf8TExNN213rlWPqxijdf5f7rFz8aqy4bTx16lRa5molqAMHDjRtraUQka/7Si7zo+LiUJe5ONRz4yald+vW2KjkGro+lZzFyj3G7Zvem/X5wm2jXjsRfozUdbvfbHqP3blzZ+pz48aNpu1qKayEvywCAAAAABIeFgEAAAAACQ+LAAAAAICEh0UAAAAAQNJ3BW40KbVSmMEVndDE5Mqk4hG5sIpLdtViI2491UlnH0eaCOwS0jWh1x1j/ZwrAuNUCkHoeXeTqWpispsM2E20q0nWlYISLsb1GLkCD65oihZkcX26STrvZS6xXSemdvFTSezXYkQRuYCCK2xSKfpQ6ePOg8a4K1CiMeViTK87V8zHqRTm0ePvju2jKjaiKkUf3DHWIg+VfXTn090H9RxXxj+3HzqOuP2ojBHuc5V7ZaftudvndF9cYRct+uCKblSKkD0M7vydO3euabvzp9fWzMxM6vP66683bXc83fGrFHjROHT3V703uu9yy7QQjCseo+OInvOI/HusMmZGROzZsyctU3qM3HnU4kGVYmoPQmUcc/Qacfuo58od48o91o0j+vvdxW9lHO22wJfumxtHtKCN+z147NixtGzHjh1N2xWm0d+R+pwSkX+HVO/V/8VfFgEAAAAACQ+LAAAAAICEh0UAAAAAQMLDIgAAAAAg6bsCN1pk4fbt26lPpXiDJjSPjY2Vvn9ycrJpu+RfTYh1ybauIAC8SvECRwshuMIIblll3Zps75LfNVZd8rT7Lk2gdjGuMebW
rcvcNq5fvz4te/LJJzuuu1IEqJ8K3DhLS0tN2yXfVwqsuGRzLbrg6PFzsaKx4QpOVIpwVYrXVL7fFfNxMaYFftw1roUoeqXQSLfcMe5mHHFjlotDPX7u3Og5dTGuBVLcuXIFOfR8uT66Lrcfev92x0jHrIi83e7437p1q2m7QmWV+8HDoAUyIiLef//9pu0KRelx/8tf/pL6/PnPf27ae/fuTX0qhZVc8RyNAxe/Oh64gk0uNrSwyJUrV1KfxcXFpj0/P5/6XLp0qWm74h/u+pmdnW3aLg41ptxvP+1z8eLF1OfQoUNp2cNQGXc17vR8RuRj7M6Vu9ZGRkaaduX6rxS+dNz1o3HvjoeOLRoXERF/+tOfmvZvf/vb1Oeb3/xmWjY1NdW0h4eHUx+9Ntw26jJ3jFbCXxYBAAAAAAkPiwAAAACAhIdFAAAAAEDSdzmLmv+i7/M67l17ze1xOYvuc+Pj4x376HvPro+bWBN1ekzd++i6rDKZakRtwnDl8mE0N6M6Ca3mNLht1PfN3XoqubuVPDx3bPs9H7FCj3slD8BxOYuVCYp13S6fQ8dDF6vuc3reK7nXLh9H86FczqTLjdAcHZeHUtnGXqbn1I0RlXOs59TlTrll+n2VOHBjhJ53N2Y4um+VdVfyyl3uWmVy+Epet5v4vVfi7tSpU2nZ0aNHm7aLAz1fLmdOf9e4CcPdcdDr38X48ePHm7Yb63Tdbjxw8aOTn7ucU82fc310m1zOoNumSn0BPbZujNbPXbhwIfV5GCr3uMq9wsWB7tPCwkLq4+4V+n2Dg4Opj8Z9Ja/Qcb91dF3uHGu8vPPOO6mP7v++fftSn+eeey4tm56ebtpu/zUf9IMPPkh9nn322aZ9r7/h+MsiAAAAACDhYREAAAAAkPCwCAAAAABIeFgEAAAAACR9V+BGE+I1edj1cZMBa7Ly6Oho6ft37ty54ndF5IRyl0jqEvnxYFUmFXfcudLJnF0caLEEl2DtkuZ13Y5ud6UIj+Nis1IYaK1xseGKXSg9fy5B303wXKHntDJhuYvVyvjjikfoManEgbsONm/enJZpQQB33PR6qRQT6mWVAjMuVvQ4uHGkMuG0o+OIiwO9f7qxxsWm9qsUCKkUmHAFv1xhF40pV9hFt6lSBONR0cnJIyK2b9/etLXgS0TE0tJS03YFXnQ9ro+7/jR+XR8t5nfjxo3UR4uAufOpxbwicvy660c/586x7q8r4uTGPz1ulQJjbqzT47Zr167U52Fw218p9KZFhObm5lKfxcXFpn316tXS92ssuHFEv9+d40pRQDf+aGy4fXvzzTebtiswMzEx0bR/9KMfpT7Dw8NpmTp//nxa9v7773fcRr0Pu+tQt/F/8ZdFAAAAAEDCwyIAAAAAIOFhEQAAAACQ9F3Oor5T7HIlKvkb+o5zNeejMvmny/tQvZwbsVZoroB7H97lB+rn3DnWPu79d+1TeR8+IudvuD4aY+49fn1H3eVn9vIk1I+aHpvKhMEuZ7CSj1eZjNzROHAx5sYa3W6X112JX83jcdeK2/+KyoTfvUyPn9t+zQdyx0r76Pjgvisi3xtdzlUlD1XvlS53zMWPxosbf/WYuHVrHqzbV3fcdJn7raDjpsuH7JVc2b1796ZlmsfoxnjNTXJjvu6j22d3bPT7XF6oLqucBzfWuHOs8evWXRmjdJnG/N1s2LChabtjpNvkxnW9fiYnJ0vfv9pcHpsuc+PY/Px8015YWEh9/vrXvzZtzaWN8OdP64m431p6/FzuvI4tLp7cby3NtXzjjTdSn3fffbdpf/WrX019Xnzxxabt8mtnZ2fTMj2Wuj0REVeuXGnaf/jDH1Ifjal7rXPBXxYBAAAAAAkPiwAAAACAhIdFAAAAAEDCwyIAAAAAIOn7Ajdu8lZNunYJuZUJtx1dV7fFQFySLlaXJrZXCixE5ERol3RdKQigcegS211BAP2c+35d5vqoSmJ9RO8UdHiY3D5r8SOXNK/nSgse
RPhzrOOGK1BUKZSl63EFS1xs6Hl345huk+uj++8KbLj91++vFObptwI3GlPXr1/v+Bl3rjTuXKy4+NXxzp0/vQ+58UC/r1LMJiLHootN3Td3X9Tz7q6Lyve761eLXFTvEY+CK8ihk2i77dfYqPRx++xiTPu5c6xx5+JQl1UK1TiV7a4U2KlcB26bKvfOyjEaGhrquJ4HwY3VWrTHXWtarMVdo1pgxR3PwcHBtEwL3FQKXFWKGLlYcUVntMDM8PBw6vOTn/ykaX/nO99JfXbs2NG0XXFFR39T7N69O/V56623mvZLL72U+jz//PNNW49rJ/xlEQAAAACQ8LAIAAAAAEh4WAQAAAAAJH2Xs6jve4+Pj6c+mvfg3hHXd6or+UER+f1hl6Ok63LvWFcm3Mb90fPu8hAq593lOOj5c7k2lThw79Zr3oCLX+1TmZT66tWrqY/b7kpuyONAcxpczpMeP5fP4I6nfs7lT2iMVfKynEoejRuPNKfE9dHcb5cP5fJvNDbdunVdvZI71i13HvScupwdzQPduHFj6uMmEdfj5a513SbN13Z9HHdtqMqk6pXx2F0HLtdqbGysabtxXLfJxVivxJ271nWZOzaqUmeh21oMTjc58JW8vrst62bdq3XcutXNeXwQKteWO58jIyNN210zBw8ebNruXunGP71G3RilY6K71nUcrdQ7cet+4YUXUp9nnnmmaWsusdsmlx/ptltzPd0Yqcfk4sWLqY8efzcer4S/LAIAAAAAEh4WAQAAAAAJD4sAAAAAgISHRQAAAABAsiYL3GhCrkv+r0zQ6WghhsqE7ZXJXLH6NCH9xo0bqY+bGFUTiCsJ8S5+dDJbV4TB0XW7oiFarMLFmE4C7vb1ypUrpW1a6yqTmrtrXZe5pH237m4LeShXtKny/TqOueINuu7KpO5u4vnNmzenZZXJ2DVeXR/dpl4aVyv3GD0ObozQWHFF1a5du5aW6fFyxRv0+9051iJGrphNpeiCi0PdxkoRHLcf7rhpsYzKJNguxnQccMeol+Kuk8q2rub+PKpiLVgd+tvCxYbeT9y1vri42LSrcaHXthtHdZvcNup44NbjnhX0nr5169bUR4tpuXGscq+qFMirFDx0RdCmp6ebtisUtBKuYgAAAABAwsMiAAAAACDhYREAAAAAkPCwCAAAAABI+q7AjXLFEzTZ1iXfawLu8vJy6fv0c67oRaftifBFS9A9V3RAl7kCB65YgktyVprA7GJME7hdH7fdmqxcKexRKXSihSoifNGfSvGVtcadBz0O7jrWZPOhoaGOfdz3VcYRN2Zo8QG3Hvf9+jlXkEDjt1LoqVLoxC2rFA/q5THTxY9yBR00plxhAj0OLg7d2FY5XlrkoRIHrsCN2zeNF1dQQuOwUnTBxaFbpkUn3D1eC0q4whR6TCrnGuhHlWKQbvzRa8IV3NPr0f0ecfdPvae4669SjK5SqMp9bsuWLU3bFaHRcaNSIMqNmW78qYzj+jtWj3VEHrfvtYgVf1kEAAAAACQ8LAIAAAAAEh4WAQAAAABJ3+csuveHR0dHm7absHhgYKBpu/enK/lkLsdC3zF2OR73OiEm7p2+a1/NT9Qcwcq75e69cn0n3OUVuhwvzQmoTF5bydl078O7nEXNLaqsu58mpXYq+Qsun0HPjU78ez/r1mPqJh7XuHOxUhlrXPzqdrtxrDIZs/tcJcdD193LOYvd5qhs2rSpabvxoJLX5z6n2+TyiPR8uT66zI1jbmzRbXL7X8nV1c+568kdf5082x03rXlQyRkF1ioX6/p7wOVH6zXqflfpGOHuZ+76099jn376aeqj2+S2Ub/f3as2bNiQlum44XI2lVu3cuOhG3/1+10NC91ulx9+v88c/GURAAAAAJDwsAgAAAAASHhYBAAAAAAkPCwCAAAAAJK+L3DjiidMTEw07ZmZmdRHE0ldQmylsEel6IVLiCVp/sGrTCbrimZoYQZXvEGL5bgCE5XiI5UJ012y
slumNBHc7asr+qMFLB6HSajdPmpCuIsDLbC1ffv21MdNNDwyMtK0XbJ/hY4jbjx09Npw8asx5q4f3X9X4Ef3NSIXFnEFxnTfKkV4eqnQkm6bO356vFzBKY0fV6jAxaYeCzdmaB9XmEHHI3eMXdEbPX+V4j2ueIV+rrr/Opm2Kx6hBTXcOdL9fRzGQzye3H1If7e4a1S5+5COB+5e6cYf3aZuimJF5OtYC2FG+AI3+v3uGFXuO9rHfaZSzNB9To+3Gw/1uN3rvZK/LAIAAAAAEh4WAQAAAAAJD4sAAAAAgGRN5ixqroJ7x1hzE1zORSV/w73jrO8G68TLEb2VW7NWaY6Oy89zeTQaG+4dfY2pSn6ry5lx8avLXK6R5m+5dVdyhty6dbsfhxwdd63rdetypfS4T01NpT4uN2J6erppuxjrZvLfaj6FnlMXG7rM5YrouOnyUNz+64TpLg41n6+Sz/GodHv9a87ntWvXUh/NP3H3HDeOaGy4GKvchyq5lxVuG3Wb3L5pjA0MDKQ+bpmuy93jdV9crg/wuHDxr+NWJT/aXcc6Hrlr1o0t+jlXX0S5e1Ulr899vy5brftQNWdRt7Oyje7Y6v7f63707t0XAAAAAPDI8LAIAAAAAEh4WAQAAAAAJDwsAgAAAACSvi9w45I0dcJnlxCqCbCusIAruqBJopWJ3nu5MMNapsfdTer92WefpWVXr15t2i4ONMnaracymbUrjKFFL5aXl1MfLdbgkrU1ydztv0sEd/uy1rlzowVuKoWqxsfHU599+/alZT/+8Y+btht/tCCTK1Ci57R6jlVlUnW3/zoeHj58OPVx458WvdFrLiLHdKWIzKPi9lHvDVqwJyJicnKyabtzrMfdFTFyx31paalpu+taxzYXB3qMXZ9qsQql53jjxo2pj97P9ZhFROzevTst27ZtW9N217hutyvCU5lMG1gL3L2im/h369FxxBXBcfSe5sYI5bZRl1X2NaK73++VY1T9Paj9KveabrdpJTzFAAAAAAASHhYBAAAAAAkPiwAAAACApO9zFl0ezc9+9rOm7SZDPnDgQNP+4Q9/mPq43BDNSfrFL36R+gwNDTXt7373u6mP5qXh/rj3sTUf5utf/3rq88tf/jItm5uba9ouZ1Anir18+XLHPm497r15nfzdxaEaGRlJy3Tic7ee/fv3p2WaE1R9t7+f6TGPiHjxxRebtssV0M99+ctfTn3GxsbSMs1jdPlkmhvmcsXu3LnTtDXmInweo3LnWJe5nDvN8XJ5KC4PT3PFzpw5k/ocOnSoaU9MTHTcxl6ix+bpp59OffTetGfPntRHr1t3z3Nj2yuvvNK0Xfxo7vX169dTH40fl8Pt7rF6vbjxR+NF750ReRx3fTSv0bl06VLH79+yZUvqo9/XyzEHrLZucvbceKCq+eaaj1fJXXf3wUruZXVZN7rNY+yV31r8ZREAAAAAkPCwCAAAAABIeFgEAAAAACQ8LAIAAAAAkid6ZVJjAAAAAEDv4C+LAAAAAICEh0UAAAAAQMLDIgAAAAAg4WERAAAAAJDwsAgAAAAASHhYBAAAAAAk/werIeoxaadbOwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Display generated_x\n", "\n", "f, axs = plt.subplots(1, 6, figsize=(16, 5))\n", "for j in range(6):\n", " axs[j].imshow(generated_x[j, :, :].numpy().squeeze(), cmap='binary')\n", " axs[j].axis('off')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What if we use Monte Carlo Sampling for kl divergence?" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "encoder = Sequential([\n", " Flatten(input_shape=event_shape),\n", " Dense(256, activation='relu'),\n", " Dense(128, activation='relu'),\n", " Dense(64, activation='relu'),\n", " Dense(32, activation='relu'),\n", " Dense(2 * latent_size),\n", " tfpl.DistributionLambda(\n", " lambda t: tfd.MultivariateNormalDiag(\n", " loc=t[..., :latent_size],\n", " scale_diag=tf.math.exp(t[..., latent_size:])\n", " )\n", " )\n", "])\n", "\n", "decoder = Sequential([\n", " Dense(32, activation='relu'),\n", " Dense(64, activation='relu'),\n", " Dense(128, activation='relu'),\n", " Dense(256, activation='relu'),\n", " Dense(tfpl.IndependentBernoulli.params_size(event_shape)),\n", " tfpl.IndependentBernoulli(event_shape)\n", "])\n", "\n", "# Define the prior, p(z) - a standard bivariate Gaussian\n", "prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "def loss(x, encoding_dist, sampled_decoding_dist, prior, sampled_z):\n", " reconstruction_loss = -sampled_decoding_dist.log_prob(x)\n", " kl_approx = (encoding_dist.log_prob(sampled_z) - prior.log_prob(sampled_z))\n", " return tf.reduce_sum(kl_approx + reconstruction_loss)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "@tf.function\n", "def get_loss_and_grads(x):\n", " with tf.GradientTape() as tape:\n", " encoding_dist = encoder(x)\n", " sampled_z = encoding_dist.sample()\n", " 
        sampled_decoding_dist = decoder(sampled_z)\n", "        current_loss = loss(x, encoding_dist, sampled_decoding_dist, prior, sampled_z)\n", "    grads = tape.gradient(current_loss, encoder.trainable_variables + decoder.trainable_variables)\n", "    return current_loss, grads" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-ELBO after epoch 1: 8914\n", "-ELBO after epoch 2: 8802\n", "-ELBO after epoch 3: 8799\n", "-ELBO after epoch 4: 8743\n", "-ELBO after epoch 5: 8790\n", "-ELBO after epoch 6: 8716\n", "-ELBO after epoch 7: 8787\n", "-ELBO after epoch 8: 8686\n", "-ELBO after epoch 9: 8650\n", "-ELBO after epoch 10: 8813\n" ] } ], "source": [ "# Train the model with a custom training loop\n", "num_epochs = 10\n", "optimizer = tf.keras.optimizers.Adam()\n", "\n", "for i in range(num_epochs):\n", "    for train_batch in X_train:\n", "        current_loss, grads = get_loss_and_grads(train_batch)\n", "        optimizer.apply_gradients(zip(grads, encoder.trainable_variables + decoder.trainable_variables))\n", "\n", "    # Report the loss of the last batch in the epoch\n", "    print('-ELBO after epoch {}: {:.0f}'.format(i + 1, current_loss.numpy()))" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "# Connect encoder and decoder, compute a reconstruction from the decoder mean\n", "\n", "def vae_mean(inputs):\n", "    approx_posterior = encoder(inputs)\n", "    decoding_dist = decoder(approx_posterior.sample())\n", "    return decoding_dist.mean()\n", "\n", "example_reconstruction = vae_mean(example_X).numpy().squeeze()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot examples against reconstructions\n", "\n", "f, axs = plt.subplots(2, 6, figsize=(16, 5))\n", "\n", "for j in range(6):\n", "    axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')\n", "    axs[1, j].imshow(example_reconstruction[j, :, :], cmap='binary')\n", "    axs[0, j].axis('off')\n", "    axs[1, j].axis('off')" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Generate examples - sample z values from the prior, then take the mean of p(x|z)\n", "\n", "z = prior.sample(6)\n", "generated_x = decoder(z).mean()\n", "\n", "# Display generated_x\n", "\n", "f, axs = plt.subplots(1, 6, figsize=(16, 5))\n", "for j in range(6):\n", "    axs[j].imshow(generated_x[j, :, :].numpy().squeeze(), cmap='binary')\n", "    axs[j].axis('off')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }