{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# KL Divergence Layers\n", "\n", "> In this post, we cover an easy way to handle the KL divergence term with TensorFlow Probability layer objects. This is a summary of the lecture \"Probabilistic Deep Learning with Tensorflow 2\" from Imperial College London.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Coursera, Tensorflow_probability, ICL]\n", "- image: images/fashion_mnist_generated2.png" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow_probability as tfp\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "tfd = tfp.distributions\n", "tfpl = tfp.layers\n", "tfb = tfp.bijectors\n", "\n", "plt.rcParams['figure.figsize'] = (10, 6)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensorflow Version: 2.5.0\n", "Tensorflow Probability Version: 0.13.0\n" ] } ], "source": [ "print(\"Tensorflow Version: \", tf.__version__)\n", "print(\"Tensorflow Probability Version: \", tfp.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Samples\n", "```python\n", "from tensorflow.keras.models import Sequential, Model\n", "from tensorflow.keras.layers import Dense\n", "\n", "latent_size = 4\n", "prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))\n", "\n", "encoder = Sequential([\n", " Dense(64, activation='relu', input_shape=(12,)),\n", " Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),\n", " tfpl.MultivariateNormalTriL(latent_size),\n", " tfpl.KLDivergenceAddLoss(prior) # adds the KL term to the model's losses so it is optimized during training\n", "])\n", "\n", "decoder = Sequential([\n", " Dense(64, activation='relu', input_shape=(latent_size,)),\n", "
Dense(tfpl.IndependentNormal.params_size(12)),\n", " tfpl.IndependentNormal(12)\n", "])\n", "\n", "vae = Model(inputs=encoder.input, outputs=decoder(encoder.output))\n", "vae.compile(loss=lambda x, pred: -pred.log_prob(x))\n", "vae.fit(train_data, epochs=20) # train_data: placeholder dataset of (x, x) pairs\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also control how the KL term is computed. With `use_exact_kl=True`, the layer uses the analytic KL divergence, which requires a closed form to be registered for the pair of distributions. With `use_exact_kl=False` (the default), it uses a Monte Carlo estimate instead. The `weight` argument multiplies the KL term by a constant." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "encoder = Sequential([\n", " Dense(64, activation='relu', input_shape=(12,)),\n", " Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),\n", " tfpl.MultivariateNormalTriL(latent_size),\n", " tfpl.KLDivergenceAddLoss(prior, use_exact_kl=False, weight=10) # Monte Carlo estimate of the KL divergence, scaled by 10\n", "])\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "encoder = Sequential([\n", " Dense(64, activation='relu', input_shape=(12,)),\n", " Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),\n", " tfpl.MultivariateNormalTriL(latent_size,\n", " convert_to_tensor_fn=tfp.distributions.Distribution.sample),\n", " tfpl.KLDivergenceAddLoss(prior) # adds the KL term to the model's losses so it is optimized during training\n", "])\n", "```\n", "\n", "Here the encoder output is converted to a tensor by drawing a sample from the multivariate normal distribution (`tfd.Distribution.sample` is also the default `convert_to_tensor_fn`). Because `KLDivergenceAddLoss` uses `tf.convert_to_tensor` as its default `test_points_fn`, this sample is the point at which the Monte Carlo KL estimate is evaluated. If you set `convert_to_tensor_fn` to `tfd.Distribution.mean` or `tfd.Distribution.mode` instead, the mean or mode becomes the tensor used in that approximation."
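, "\n", "\n", "The difference between `use_exact_kl=True` and `use_exact_kl=False` can also be seen outside of a model. The snippet below is a minimal NumPy sketch, not the TFP implementation. It assumes that q(z|x) is a diagonal Gaussian with made-up parameters (`q_loc`, `q_scale`) and that p(z) = N(0, I). The helpers `log_normal_diag`, `exact_kl` and `mc_kl` are hypothetical. `exact_kl` is the closed-form KL[q || p]. `mc_kl` draws test points from q, evaluates log q(z) - log p(z) at each one and averages over axis 0, which is roughly what `test_points_fn=lambda q: q.sample(n)` with `test_points_reduce_axis=0` does for a single input.

```python
import numpy as np

# Hypothetical parameters of a diagonal Gaussian q(z|x) for a single input x
# (latent_size = 4, as in the post). The prior p(z) is N(0, I).
rng = np.random.default_rng(0)
q_loc = np.array([0.5, -1.0, 0.2, 0.0])
q_scale = np.array([0.8, 1.5, 0.5, 1.0])


def log_normal_diag(z, loc, scale):
    # Log density of a diagonal Gaussian, summed over the last (event) axis
    return np.sum(-0.5 * np.log(2 * np.pi) - np.log(scale)
                  - 0.5 * ((z - loc) / scale) ** 2, axis=-1)


def exact_kl(loc, scale):
    # Closed-form KL[N(loc, diag(scale^2)) || N(0, I)]: the use_exact_kl=True analogue
    return np.sum(0.5 * (scale ** 2 + loc ** 2 - 1.0) - np.log(scale))


def mc_kl(loc, scale, n_samples):
    # Draw test points z ~ q with shape (n_samples, latent_size) ...
    z = loc + scale * rng.standard_normal((n_samples, loc.shape[0]))
    # ... evaluate log q(z) - log p(z) at each one, then average over axis 0
    log_ratio = log_normal_diag(z, loc, scale) - log_normal_diag(z, 0.0, 1.0)
    return np.mean(log_ratio, axis=0)


print('exact KL:', exact_kl(q_loc, q_scale))
for n in (1, 10, 10000):
    print(f'MC estimate with {n} samples:', mc_kl(q_loc, q_scale, n))
```

With a single test point the estimate is noisy. Averaging over more samples brings it closer to the exact value, at the cost of extra computation."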
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "encoder = Sequential([\n", " Dense(64, activation='relu', input_shape=(12,)),\n", " Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),\n", " tfpl.MultivariateNormalTriL(latent_size),\n", " tfpl.KLDivergenceAddLoss(prior, use_exact_kl=False, weight=10,\n", " test_points_fn=lambda q: q.sample(10), # 10 samples for test points\n", " test_points_reduce_axis=0) # average the KL estimates over the sample axis\n", "])\n", "```\n", "\n", "Here `test_points_fn` draws 10 samples from q(z|x) for each input, so the test points have shape `(10, batch_size, latent_size)`. `test_points_reduce_axis=0` then averages the resulting Monte Carlo KL estimates over that sample axis." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another way to add the KL term is to pass a `tfpl.KLDivergenceRegularizer` as the `activity_regularizer` of the distribution layer.\n", "\n", "```python\n", "encoder = Sequential([\n", " Dense(64, activation='relu', input_shape=(12,)),\n", " Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),\n", " tfpl.MultivariateNormalTriL(latent_size,\n", " activity_regularizer=tfpl.KLDivergenceRegularizer(\n", " prior, weight=10, use_exact_kl=False,\n", " test_points_fn=lambda q: q.sample(10),\n", " test_points_reduce_axis=0))\n", "])\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tutorial" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.models import Sequential, Model\n", "from tensorflow.keras.layers import Dense, Flatten, Reshape" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Load Fashion-MNIST and rescale pixel values into the open interval (0, 1),\n", "# as required by the Beta likelihood used in the decoder\n", "\n", "(X_train, _), (X_test, _) = tf.keras.datasets.fashion_mnist.load_data()\n", "X_train = X_train.astype('float32') / 256. + 0.5 / 256\n", "X_test = X_test.astype('float32') / 256.
+ 0.5 / 256\n", "example_X = X_test[:16]\n", "\n", "batch_size = 32\n", "X_train = tf.data.Dataset.from_tensor_slices((X_train, X_train)).batch(batch_size)\n", "X_test = tf.data.Dataset.from_tensor_slices((X_test, X_test)).batch(batch_size)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Define latent_size and the prior, p(z)\n", "\n", "latent_size = 4\n", "prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Define the encoding distribution using a tfpl.KLDivergenceAddLoss layer\n", "\n", "event_shape = (28, 28)\n", "\n", "encoder = Sequential([\n", " Flatten(input_shape=event_shape),\n", " Dense(128, activation='relu'),\n", " Dense(64, activation='relu'),\n", " Dense(32, activation='relu'),\n", " Dense(16, activation='relu'),\n", " Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),\n", " tfpl.MultivariateNormalTriL(latent_size),\n", " tfpl.KLDivergenceAddLoss(prior) # estimate KL[ q(z|x) || p(z)]\n", "])\n", "\n", "# Samples z_j from q(z | x_j)\n", "# then computes log q(z_j | x_j) - log p(z_j)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# See how `KLDivergenceAddLoss` affects `encoder.losses`\n", "# encoder.losses before the network has received any inputs\n", "\n", "encoder.losses" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Pass a batch of images through the encoder\n", "\n", "encoder(example_X)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" 
} ], "source": [ "# See how encoder.losses has changed\n", "\n", "encoder.losses" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Re-specify the encoder using `weight` and `test_points_fn`\n", "\n", "encoder = Sequential([\n", " Flatten(input_shape=event_shape),\n", " Dense(128, activation='relu'),\n", " Dense(64, activation='relu'),\n", " Dense(32, activation='relu'),\n", " Dense(16, activation='relu'),\n", " Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),\n", " tfpl.MultivariateNormalTriL(latent_size),\n", " tfpl.KLDivergenceAddLoss(prior,\n", " use_exact_kl=False,\n", " weight=1.5,\n", " test_points_fn=lambda q: q.sample(10),\n", " test_points_reduce_axis=0) # estimate KL[ q(z|x) || p(z)]\n", "])\n", "\n", "# Test points have shape (n_samples, batch_size, dim_z):\n", "# z_{ij}, the i-th sample for x_j, sits at (i, j, :) in the tensor of samples\n", "# and is mapped to log q(z_{ij}|x_j) - log p(z_{ij})\n", "# => the tensor of KL estimates has shape (n_samples, batch_size) and is averaged over axis 0" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Replace `KLDivergenceAddLoss` with a `KLDivergenceRegularizer` on the distribution layer\n", "divergence_regularizer = tfpl.KLDivergenceRegularizer(prior,\n", " use_exact_kl=False,\n", " test_points_fn=lambda q: q.sample(5),\n", " test_points_reduce_axis=0)\n", "\n", "encoder = Sequential([\n", " Flatten(input_shape=event_shape),\n", " Dense(128, activation='relu'),\n", " Dense(64, activation='relu'),\n", " Dense(32, activation='relu'),\n", " Dense(16, activation='relu'),\n", " Dense(tfpl.MultivariateNormalTriL.params_size(latent_size)),\n", " tfpl.MultivariateNormalTriL(latent_size,\n", " activity_regularizer=divergence_regularizer),\n", "])" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Specify the decoder, p(x|z)\n", "\n", "decoder = Sequential([\n", " Dense(16, activation='relu', input_shape=(latent_size,)),\n", "
Dense(32, activation='relu'),\n", " Dense(64, activation='relu'),\n", " Dense(128, activation='relu'),\n", " Dense(2*event_shape[0]*event_shape[1], activation='exponential'), # keeps both Beta concentrations positive\n", " Reshape((event_shape[0], event_shape[1], 2)),\n", " tfpl.DistributionLambda(\n", " lambda t: tfd.Independent( # treat the 28x28 per-pixel Betas as one event\n", " tfd.Beta(concentration1=t[..., 0],\n", " concentration0=t[..., 1])\n", " )\n", " )\n", "])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: If you encounter an error like \"NotImplementedError: Cannot convert a symbolic Tensor (gradients/stateless_random_gamma/StatelessRandomGammaV2_grad/sub:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported\", the problem is on the NumPy side: install NumPy `1.19.x` instead of `1.20.x`. See the [reference](https://stackoverflow.com/questions/58479556/notimplementederror-cannot-convert-a-symbolic-tensor-2nd-target0-to-a-numpy)." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Connect the encoder and decoder to form the VAE\n", "\n", "vae = Model(inputs=encoder.inputs, outputs=decoder(encoder.outputs))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Define a loss that only estimates the expected reconstruction error,\n", "# -E_{z ~ q(z | x)}[log p(x | z)]; the KL term is added by the encoder's regularizer\n", "\n", "def log_loss(X_true, p_x_given_z):\n", " return -tf.reduce_sum(p_x_given_z.log_prob(X_true))" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n", "1875/1875 [==============================] - 19s 9ms/step - loss: -53792.4219 - val_loss: -63781.0430\n", "Epoch 2/10\n", "1875/1875 [==============================] - 16s 9ms/step - loss: -63665.5078 - val_loss: -64181.7930\n", "Epoch 3/10\n", "1875/1875 [==============================] - 16s 8ms/step - loss: -66448.9219 - val_loss:
-67975.6172\n", "Epoch 4/10\n", "1875/1875 [==============================] - 16s 9ms/step - loss: -68327.5859 - val_loss: -70727.4141\n", "Epoch 5/10\n", "1875/1875 [==============================] - 16s 8ms/step - loss: -70031.3906 - val_loss: -70685.8516\n", "Epoch 6/10\n", "1875/1875 [==============================] - 15s 8ms/step - loss: -71203.7734 - val_loss: -64922.7461\n", "Epoch 7/10\n", "1875/1875 [==============================] - 16s 8ms/step - loss: -72169.3125 - val_loss: -73782.7109\n", "Epoch 8/10\n", "1875/1875 [==============================] - 15s 8ms/step - loss: -73042.2422 - val_loss: -70419.5547\n", "Epoch 9/10\n", "1875/1875 [==============================] - 15s 8ms/step - loss: -73685.4453 - val_loss: -73172.9297\n", "Epoch 10/10\n", "1875/1875 [==============================] - 16s 8ms/step - loss: -74290.2891 - val_loss: -75949.5312\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Compile and fit the model\n", "\n", "vae.compile(loss=log_loss)\n", "vae.fit(X_train, validation_data=X_test, epochs=10)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Generate an example reconstruction\n", "\n", "example_reconstruction = vae(example_X).mean().numpy().squeeze()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the example reconstructions\n", "\n", "f, axs = plt.subplots(2, 6, figsize=(16, 5))\n", "\n", "for j in range(6):\n", " axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')\n", " axs[1, j].imshow(example_reconstruction[j, :, :].squeeze(), cmap='binary')\n", " axs[0, j].axis('off')\n", " axs[1, j].axis('off')" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Generate example reconstructions by sampling from p(x|z) instead of taking its mean\n", "\n", "example_reconstruction = vae(example_X).sample().numpy().squeeze()\n", "\n", "# Plot the example reconstructions\n", "\n", "f, axs = plt.subplots(2, 6, figsize=(16, 5))\n", "\n", "for j in range(6):\n", " axs[0, j].imshow(example_X[j, :, :].squeeze(), cmap='binary')\n", " axs[1, j].imshow(example_reconstruction[j, :, :].squeeze(), cmap='binary')\n", " axs[0, j].axis('off')\n", " axs[1, j].axis('off')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }