{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# RealNVP for the LSUN bedroom dataset\n", "\n", "> In this post, we take a look at an application of RealNVP. This is a homework assignment for the course \"Probabilistic Deep Learning with Tensorflow 2\" by Imperial College London.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Coursera, Tensorflow_probability, ICL]\n", "- image: images/realnvp_lsun.png" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow_probability as tfp\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "tfd = tfp.distributions\n", "tfpl = tfp.layers\n", "tfb = tfp.bijectors\n", "\n", "plt.rcParams['figure.figsize'] = (10, 6)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensorflow Version: 2.5.0\n", "Tensorflow Probability Version: 0.13.0\n" ] } ], "source": [ "print(\"Tensorflow Version: \", tf.__version__)\n", "print(\"Tensorflow Probability Version: \", tfp.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Three example bedroom images]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The LSUN Bedroom Dataset\n", "\n", "In this post, you will use a subset of the [LSUN dataset](https://www.yf.io/p/lsun). This is a large-scale image dataset with 10 scene and 20 object categories. A subset of the LSUN bedroom dataset has been provided, and has already been downsampled and preprocessed into smaller, fixed-size images.\n", "\n", "* F. Yu, A. Seff, Y. Zhang, S. Song, T. Funkhouser and J. Xiao. \"LSUN: Construction of a Large-scale Image Dataset using Deep Learning with Humans in the Loop\". 
[arXiv:1506.03365](https://arxiv.org/abs/1506.03365), 10 Jun 2015 \n", "\n", "Our goal is to develop the RealNVP normalising flow architecture using bijector subclassing, and use it to train a generative model of the LSUN bedroom data subset. For full details on the RealNVP model, refer to the original paper:\n", "\n", "* L. Dinh, J. Sohl-Dickstein and S. Bengio. \"Density estimation using Real NVP\". [arXiv:1605.08803](https://arxiv.org/abs/1605.08803), 27 Feb 2017." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the dataset\n", "\n", "The following functions will be useful for loading and preprocessing the dataset. The subset you will use for this assignment consists of 10,000 training images, 1,000 validation images and 1,000 test images.\n", "\n", "The images have been downsampled to 32 x 32 x 3 in order to simplify the training process.\n", "\n", "> Note: The dataset is too large to host on GitHub. Please refer to the official homework assignment page on Coursera."
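] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Functions for loading and preprocessing the images\n", "\n", "def load_image(img):\n", "    img = tf.image.random_flip_left_right(img)\n", "    return img, img\n", "\n", "def load_dataset(split):\n", "    train_list_ds = tf.data.Dataset.from_tensor_slices(np.load('./dataset/lsun/{}.npy'.format(split)))\n", "    train_ds = train_list_ds.map(load_image)\n", "    return train_ds" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Load the training, validation and testing datasets splits\n", "\n", "train_ds = load_dataset('train')\n", "val_ds = load_dataset('val')\n", "test_ds = load_dataset('test')" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Shuffle the datasets\n", "\n", "shuffle_buffer_size = 1000\n", "train_ds = train_ds.shuffle(shuffle_buffer_size)\n", "val_ds = val_ds.shuffle(shuffle_buffer_size)\n", "test_ds = test_ds.shuffle(shuffle_buffer_size)" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Display a few examples\n", "\n", "n_img = 4\n", "f, axs = plt.subplots(n_img, n_img, figsize=(14, 14))\n", "\n", "for k, image in enumerate(train_ds.take(n_img**2)):\n", "    i = k // n_img\n", "    j = k % n_img\n", "    axs[i, j].imshow(image[0])\n", "    axs[i, j].axis('off')\n", "f.subplots_adjust(wspace=0.01, hspace=0.03)" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Batch the Dataset objects\n", "\n", "batch_size = 64\n", "train_ds = train_ds.batch(batch_size)\n", "val_ds = val_ds.batch(batch_size)\n", "test_ds = test_ds.batch(batch_size)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Affine coupling layer\n", "\n", "We will begin the development of the RealNVP architecture with the core bijector that is called the _affine coupling layer_. This bijector can be described as follows: suppose that $x$ is a $D$-dimensional input, and let $d < D$. Then the output $y$ of the affine coupling layer is given by the following equations:\n", "\n", "$$\n", "y_{1:d} = x_{1:d} \\tag{1}\n", "$$\n", "\n", "$$\n", "y_{d+1:D} = x_{d+1:D}\\odot \\exp(s(x_{1:d})) + t(x_{1:d}), \\tag{2}\n", "$$\n", "\n", "where $s$ and $t$ are functions from $\\mathbb{R}^d$ to $\\mathbb{R}^{D-d}$ that define the log-scale and shift operations applied to $x_{d+1:D}$ respectively.\n", "\n", "The log of the Jacobian determinant for this layer is given by $\\sum_{j}s(x_{1:d})_j$.\n", "\n", "The inverse operation is given by\n", "\n", "$$\n", "x_{1:d} = y_{1:d} \\tag{3}\n", "$$\n", "\n", "$$\n", "x_{d+1:D} = \\left(y_{d+1:D} - t(y_{1:d})\\right)\\odot \\exp(-s(y_{1:d})). \\tag{4}\n", "$$\n", "\n", "In practice, we will implement these equations using a binary mask $b$:\n", "\n", "$$\n", "\\text{Forward pass:}\\qquad y = b\\odot x + (1 - b)\\odot\\left(x\\odot\\exp(s(b\\odot x)) + t(b\\odot x)\\right), \\tag{5}\n", "$$\n", "\n", "$$\n", "\\text{Inverse pass:}\\qquad x = b\\odot y + (1 - b)\\odot\\left(y - t(b\\odot y)\\right)\\odot\\exp(-s(b\\odot y)). \\tag{6}\n", "$$\n", "\n", "Our inputs $x$ will be a batch of 3-dimensional Tensors with `height`, `width` and `channels` dimensions. As in the original architecture, we will use both spatial 'checkerboard' masks and channel-wise masks.\n", "\n", "Figure 1. Spatial checkerboard mask (left) and channel-wise mask (right). From the original paper." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### Custom model for log-scale and shift\n", "\n", "Here, I built a custom model for the shift and log-scale parameters that are used in the affine coupling layer bijector. In total, the network should have 14 layers (including the `Input` layer)."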
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras import Input\n", "from tensorflow.keras.models import Model\n", "from tensorflow.keras.layers import Conv2D, BatchNormalization\n", "from tensorflow.keras.regularizers import l2\n", "from tensorflow.keras.optimizers import Adam" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def get_conv_resnet(input_shape, filters):\n", "    \"\"\"\n", "    This function should build a CNN ResNet model according to the above specification,\n", "    using the functional API. The function takes input_shape as an argument, which should be\n", "    used to specify the shape in the Input layer, as well as a filters argument, which\n", "    should be used to specify the number of filters in (some of) the convolutional layers.\n", "    Your function should return the model.\n", "    \"\"\"\n", "    h0 = Input(shape=input_shape)\n", "    \n", "    # 1st skip connection\n", "    y = Conv2D(filters=filters, kernel_size=3, padding='SAME', activation='relu', kernel_regularizer=l2(l=5e-5))(h0)\n", "    y = BatchNormalization()(y)\n", "    y = Conv2D(filters=input_shape[-1], kernel_size=3, padding='SAME', activation='relu', kernel_regularizer=l2(l=5e-5))(y)\n", "    y = BatchNormalization()(y)\n", "    h1 = tf.math.add(y, h0)\n", "    \n", "    # 2nd skip connection\n", "    y = Conv2D(filters=filters, kernel_size=3, padding='SAME', activation='relu', kernel_regularizer=l2(l=5e-5))(h1)\n", "    y = BatchNormalization()(y)\n", "    y = Conv2D(filters=input_shape[-1], kernel_size=3, padding='SAME', activation='relu', kernel_regularizer=l2(l=5e-5))(y)\n", "    y = BatchNormalization()(y)\n", "    y = tf.math.add(y, h1)\n", "    h2 = Conv2D(filters=2 * input_shape[-1], kernel_size=3, padding='SAME', activation='linear', kernel_regularizer=l2(l=5e-5))(y)\n", "    shift, log_scale = tf.split(h2, num_or_size_splits=2, axis=-1)\n", "    y = tf.math.tanh(log_scale)\n", "    model = Model(inputs=h0, outputs=[shift, y])\n", "    return model" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model\"\n", "__________________________________________________________________________________________________\n", "Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", "input_1 (InputLayer) [(None, 32, 32, 3)] 0 \n", "__________________________________________________________________________________________________\n", "conv2d (Conv2D) (None, 32, 32, 32) 896 input_1[0][0] \n", "__________________________________________________________________________________________________\n", "batch_normalization (BatchNorma (None, 32, 32, 32) 128 conv2d[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d_1 (Conv2D) (None, 32, 32, 3) 867 batch_normalization[0][0] \n", "__________________________________________________________________________________________________\n", "batch_normalization_1 (BatchNor (None, 32, 32, 3) 12 conv2d_1[0][0] \n", "__________________________________________________________________________________________________\n", "tf.math.add (TFOpLambda) (None, 32, 32, 3) 0 batch_normalization_1[0][0] \n", " input_1[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d_2 (Conv2D) (None, 32, 32, 32) 896 tf.math.add[0][0] \n", "__________________________________________________________________________________________________\n", "batch_normalization_2 (BatchNor (None, 32, 32, 32) 128 conv2d_2[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d_3 (Conv2D) (None, 32, 32, 3) 867 batch_normalization_2[0][0] \n", "__________________________________________________________________________________________________\n", "batch_normalization_3 (BatchNor (None, 32, 32, 3) 12 conv2d_3[0][0] \n", "__________________________________________________________________________________________________\n", "tf.math.add_1 (TFOpLambda) (None, 32, 32, 3) 0 batch_normalization_3[0][0] \n", " tf.math.add[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d_4 (Conv2D) (None, 32, 32, 6) 168 tf.math.add_1[0][0] \n", "__________________________________________________________________________________________________\n", "tf.split (TFOpLambda) [(None, 32, 32, 3), 0 conv2d_4[0][0] \n", "__________________________________________________________________________________________________\n", "tf.math.tanh (TFOpLambda) (None, 32, 32, 3) 0 tf.split[0][1] \n", "==================================================================================================\n", "Total params: 3,974\n", "Trainable params: 3,834\n", "Non-trainable params: 140\n", "__________________________________________________________________________________________________\n" ] } ], "source": [ "# Test your function and print the model summary\n", "\n", "conv_resnet = get_conv_resnet((32, 32, 3), 32)\n", "conv_resnet.summary()" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Plot the model graph\n", "\n", "tf.keras.utils.plot_model(conv_resnet, show_layer_names=False, rankdir='LR')" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 32, 32, 3)\n", "(1, 32, 32, 3)\n" ] } ], "source": [ "# Check the output shapes are as expected\n", "\n", "print(conv_resnet(tf.random.normal((1, 32, 32, 3)))[0].shape)\n", "print(conv_resnet(tf.random.normal((1, 32, 32, 3)))[1].shape)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Binary masks\n", "\n", "Now that we have built the shift and log-scale model, we can implement the affine coupling layer. We first need functions to create the binary masks $b$ as described above. The following function creates the spatial 'checkerboard' mask.\n", "\n", "It takes a length-2 `shape` as input, which corresponds to the `height` and `width` dimensions, as well as an `orientation` argument (an integer equal to `0` or `1`) that determines which way round the zeros and ones are entered into the Tensor." ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Function to create the checkerboard mask\n", "\n", "def checkerboard_binary_mask(shape, orientation=0):\n", "    height, width = shape[0], shape[1]\n", "    height_range = tf.range(height)\n", "    width_range = tf.range(width)\n", "    height_odd_inx = tf.cast(tf.math.mod(height_range, 2), dtype=tf.bool)\n", "    width_odd_inx = tf.cast(tf.math.mod(width_range, 2), dtype=tf.bool)\n", "    odd_rows = tf.tile(tf.expand_dims(height_odd_inx, -1), [1, width])\n", "    odd_cols = tf.tile(tf.expand_dims(width_odd_inx, 0), [height, 1])\n", "    checkerboard_mask = tf.math.logical_xor(odd_rows, odd_cols)\n", "    if orientation == 1:\n", "        checkerboard_mask = tf.math.logical_not(checkerboard_mask)\n", "    return tf.cast(tf.expand_dims(checkerboard_mask, -1), tf.float32)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "This function creates a rank-3 Tensor to mask the `height`, `width` and `channels` dimensions of the input. We can take a look at this checkerboard mask for some example inputs below. In order to make the Tensors easier to inspect, we will squeeze out the single channel dimension (which is always 1 for this mask)."
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Run the checkerboard_binary_mask function to see an example\n", "# NB: we squeeze the shape for easier viewing. The full shape is (4, 4, 1)\n", "\n", "tf.squeeze(checkerboard_binary_mask((4, 4), orientation=0))" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The `orientation` should be 0 or 1, and determines which way round the binary entries are\n", "\n", "tf.squeeze(checkerboard_binary_mask((4, 4), orientation=1))" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "def channel_binary_mask(num_channels, orientation=0):\n", "    \"\"\"\n", "    This function takes an integer num_channels and orientation (0 or 1) as\n", "    arguments. It should create a channel-wise binary mask with \n", "    dtype=tf.float32, according to the above specification.\n", "    The function should then return the binary mask.\n", "    \"\"\"\n", "    mask_list = []\n", "\n", "    for i in range(num_channels):\n", "        if i < num_channels // 2:\n", "            mask_list.append(orientation)\n", "        else:\n", "            mask_list.append(not orientation)\n", "\n", "    mask = tf.cast(tf.constant(np.array([[mask_list]])),dtype=tf.float32)\n", "    return mask" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Run your function to see an example channel-wise binary mask\n", "\n", "channel_binary_mask(6, orientation=0)" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "def forward(x, b, shift_and_log_scale_fn):\n", "    \"\"\"\n", "    This function takes the input Tensor x, binary mask b and callable\n", "    shift_and_log_scale_fn as arguments.\n", "    This function should implement the forward transformation in equation (5)\n", "    and return the output Tensor y, which will have the same shape as x\n", "    \"\"\"\n", "    x_b = x * b\n", "    shift, log_scale = shift_and_log_scale_fn(x_b)\n", "    y = x_b + (1 - b) * (x * tf.math.exp(log_scale) + shift)\n", "    return y\n", "\n", "def inverse(y, b, shift_and_log_scale_fn):\n", "    \"\"\"\n", "    This function takes the input Tensor y, binary mask b and callable\n", "    shift_and_log_scale_fn as arguments.\n", "    This function should implement the inverse transformation in equation (6)\n", "    and return the output Tensor x, which will have the same shape as y\n", "    \"\"\"\n", "    y_b = y * b\n", "    shift, log_scale = shift_and_log_scale_fn(y_b)\n", "    x = y_b + (1 - b) * (y - shift) * tf.math.exp(-log_scale)\n", "    return x" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The new bijector class also requires the `log_det_jacobian` methods to be implemented. Recall that the log of the Jacobian determinant of the forward transformation is given by $\\sum_{j}s(x_{1:d})_j$, where $s$ is the log-scale function of the affine coupling layer. " ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "def forward_log_det_jacobian(x, b, shift_and_log_scale_fn):\n", "    \"\"\"\n", "    This function takes the input Tensor x, binary mask b and callable\n", "    shift_and_log_scale_fn as arguments.\n", "    This function should compute and return the log of the Jacobian determinant \n", "    of the forward transformation in equation (5)\n", "    \"\"\"\n", "    x_b = x * b\n", "    shift, log_scale = shift_and_log_scale_fn(x_b)\n", "    return tf.reduce_sum(log_scale * (1 - b), [1, 2, 3])\n", "    \n", "\n", "def inverse_log_det_jacobian(y, b, shift_and_log_scale_fn):\n", "    \"\"\"\n", "    This function takes the input Tensor y, binary mask b and callable\n", "    shift_and_log_scale_fn as arguments.\n", "    This function should compute and return the log of the Jacobian determinant \n", "    of the inverse transformation in equation (6)\n", "    \"\"\"\n", "    y_b = y * b\n", "    shift, log_scale = shift_and_log_scale_fn(y_b)\n", "    return tf.reduce_sum(-log_scale * (1 - b), [1, 2, 3])" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "class AffineCouplingLayer(tfb.Bijector):\n", "    \"\"\"\n", "    Class to implement the affine coupling layer.\n", "    Complete the __init__ and _get_mask methods according to the instructions above.\n", "    \"\"\"\n", "\n", "    def __init__(self, shift_and_log_scale_fn, mask_type, orientation, **kwargs):\n", "        \"\"\"\n", "        The class initialiser takes the shift_and_log_scale_fn callable, mask_type,\n", "        orientation and possibly extra keyword arguments. It should call the \n", "        base class initialiser, passing any extra keyword arguments along. \n", "        It should also set the required arguments as class attributes.\n", "        \"\"\"\n", "        super(AffineCouplingLayer, self).__init__(**kwargs, forward_min_event_ndims=3)\n", "        self.shift_and_log_scale_fn = shift_and_log_scale_fn\n", "        self.mask_type = mask_type\n", "        self.orientation = orientation\n", "        \n", "    \n", "    def _get_mask(self, shape):\n", "        \"\"\"\n", "        This internal method should use the binary mask functions above to compute\n", "        and return the binary mask, according to the arguments passed in to the\n", "        initialiser.\n", "        \"\"\"\n", "        height, width, channels = shape[-3:]\n", "        \n", "        if self.mask_type == 'checkerboard':\n", "            mask = checkerboard_binary_mask((height, width), self.orientation)\n", "        elif self.mask_type == 'channel':\n", "            mask = channel_binary_mask(channels, self.orientation)\n", "        return mask\n", "\n", "    def _forward(self, x):\n", "        b = self._get_mask(x.shape)\n", "        return forward(x, b, self.shift_and_log_scale_fn)\n", "\n", "    def _inverse(self, y):\n", "        b = self._get_mask(y.shape)\n", "        return inverse(y, b, self.shift_and_log_scale_fn)\n", "\n", "    def _forward_log_det_jacobian(self, x):\n", "        b = self._get_mask(x.shape)\n", "        return forward_log_det_jacobian(x, b, self.shift_and_log_scale_fn)\n", "\n", "    def _inverse_log_det_jacobian(self, y):\n", "        b = self._get_mask(y.shape)\n", "        return inverse_log_det_jacobian(y, b, self.shift_and_log_scale_fn)" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# Test your function by creating an instance of the AffineCouplingLayer class\n", "\n", "affine_coupling_layer = AffineCouplingLayer(conv_resnet, 'channel', orientation=1, \n", "                                            name='affine_coupling_layer')" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TensorShape([16, 32, 32, 3])" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The following should return a Tensor of the same shape 
as the input\n", "\n", "affine_coupling_layer.forward(tf.random.normal((16, 32, 32, 3))).shape" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TensorShape([16])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The following should compute a log_det_jacobian for each event in the batch\n", "\n", "affine_coupling_layer.forward_log_det_jacobian(tf.random.normal((16, 32, 32, 3)), event_ndims=3).shape" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Combining the affine coupling layers\n", "\n", "In the affine coupling layer, part of the input remains unchanged in the transformation $(5)$. In order to allow transformation of all of the input, several coupling layers are composed, with the orientation of the mask being reversed in subsequent layers.\n", "\n", "![Coupling layers](image/alternating_masks.png)\n", "\n", "Figure 2. RealNVP alternates the orientation of masks from one affine coupling layer to the next. From the original paper.\n", "\n", "Our model design will be similar to the original architecture; we will compose three affine coupling layers with checkerboard masking, followed by a batch normalization bijector (`tfb.BatchNormalization` is a built-in bijector), followed by a squeezing operation, followed by three more affine coupling layers with channel-wise masking and a final batch normalization bijector. \n", "\n", "The squeezing operation divides the spatial dimensions into 2x2 squares, and reshapes a Tensor of shape `(H, W, C)` into a Tensor of shape `(H // 2, W // 2, 4 * C)` as shown in Figure 1.\n", "\n", "The squeezing operation is also a bijective operation, and has been provided for you in the class below." ] },
{ "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "# Bijector class for the squeezing operation\n", "\n", "class Squeeze(tfb.Bijector):\n", "    \n", "    def __init__(self, name='Squeeze', **kwargs):\n", "        super(Squeeze, self).__init__(forward_min_event_ndims=3, is_constant_jacobian=True, \n", "                                      name=name, **kwargs)\n", "\n", "    def _forward(self, x):\n", "        input_shape = x.shape\n", "        height, width, channels = input_shape[-3:]\n", "        y = tfb.Reshape((height // 2, 2, width // 2, 2, channels), event_shape_in=(height, width, channels))(x)\n", "        y = tfb.Transpose(perm=[0, 2, 1, 3, 4])(y)\n", "        y = tfb.Reshape((height // 2, width // 2, 4 * channels),\n", "                        event_shape_in=(height // 2, width // 2, 2, 2, channels))(y)\n", "        return y\n", "\n", "    def _inverse(self, y):\n", "        input_shape = y.shape\n", "        height, width, channels = input_shape[-3:]\n", "        x = tfb.Reshape((height, width, 2, 2, channels // 4), event_shape_in=(height, width, channels))(y)\n", "        x = tfb.Transpose(perm=[0, 2, 1, 3, 4])(x)\n", "        x = tfb.Reshape((2 * height, 2 * width, channels // 4),\n", "                        event_shape_in=(height, 2, width, 2, channels // 4))(x)\n", "        return x\n", "\n", "    def _forward_log_det_jacobian(self, x):\n", "        return tf.constant(0., x.dtype)\n", "\n", "    def _inverse_log_det_jacobian(self, y):\n", "        return tf.constant(0., y.dtype)\n", "\n", "    def _forward_event_shape_tensor(self, input_shape):\n", "        height, width, channels = input_shape[-3], input_shape[-2], input_shape[-1]\n", "        return height // 2, width // 2, 4 * channels\n", "\n", "    def _inverse_event_shape_tensor(self, output_shape):\n", "        height, width, channels = output_shape[-3], output_shape[-2], output_shape[-1]\n", "        return height * 2, width * 2, channels // 4" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "You can see the effect of the squeezing operation on some example inputs in the cells below. In the forward transformation, each spatial dimension is halved, whilst the channel dimension is multiplied by 4. The opposite happens in the inverse transformation." ] },
{ "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TensorShape([10, 16, 16, 12])" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Test the Squeeze bijector\n", "\n", "squeeze = Squeeze()\n", "squeeze(tf.ones((10, 32, 32, 3))).shape" ] },
{ "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TensorShape([10, 8, 8, 24])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Test the inverse operation\n", "\n", "squeeze.inverse(tf.ones((10, 4, 4, 96))).shape" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We can now construct a block of coupling layers according to the architecture described above. 
The chained bijector has the following structure:\n", "\n", "* Three `AffineCouplingLayer` bijectors with `\"checkerboard\"` masking with orientations `0, 1, 0` respectively\n", "* A `BatchNormalization` bijector\n", "* A `Squeeze` bijector\n", "* Three more `AffineCouplingLayer` bijectors with `\"channel\"` masking with orientations `0, 1, 0` respectively\n", "* Another `BatchNormalization` bijector\n", "\n", "The function takes the following arguments:\n", "* `shift_and_log_scale_fns`: a list or tuple of six conv_resnet models\n", "  * The first three models in this list are used in the three coupling layers with checkerboard masking\n", "  * The last three models in this list are used in the three coupling layers with channel masking\n", "* `squeeze`: an instance of the `Squeeze` bijector" ] },
{ "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "def realnvp_block(shift_and_log_scale_fns, squeeze):\n", "    \"\"\"\n", "    This function takes a list or tuple of six conv_resnet models, and an \n", "    instance of the Squeeze bijector.\n", "    The function should construct the chain of bijectors described above,\n", "    using the conv_resnet models in the coupling layers.\n", "    The function should then return the chained bijector.\n", "    \"\"\"\n", "    bijectors = []\n", "    orientations = [0, 1, 0]\n", "    for i in range(3):\n", "        bijectors.append(AffineCouplingLayer(shift_and_log_scale_fn=shift_and_log_scale_fns[i], \n", "                                             mask_type='checkerboard',\n", "                                             orientation=orientations[i]))\n", "    bijectors.append(tfb.BatchNormalization()) \n", "    bijectors.append(squeeze)\n", "    \n", "    for i in range(3, 6):\n", "        bijectors.append(AffineCouplingLayer(shift_and_log_scale_fn=shift_and_log_scale_fns[i],\n", "                                             mask_type='channel',\n", "                                             orientation=orientations[i % 3]))\n", "    bijectors.append(tfb.BatchNormalization())\n", "    \n", "    # tfb.Chain applies its bijectors from last to first, hence the reversal.\n", "    # Note that the [:-1] slice leaves out the final BatchNormalization appended above.\n", "    flow_bijector = tfb.Chain(list(reversed(bijectors[:-1])))\n", "    return flow_bijector" ] },
{ "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "# Run your function to create an instance of the bijector\n", "\n", "checkerboard_fns = []\n", "for _ in range(3):\n", "    checkerboard_fns.append(get_conv_resnet((32, 32, 3), 512))\n", "channel_fns = []\n", "for _ in range(3):\n", "    channel_fns.append(get_conv_resnet((16, 16, 12), 512))\n", "    \n", "block = realnvp_block(checkerboard_fns + channel_fns, squeeze)" ] },
{ "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TensorShape([10, 16, 16, 12])" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Test the bijector on a dummy input\n", "\n", "block.forward(tf.random.normal((10, 32, 32, 3))).shape" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Multiscale architecture\n", "\n", "The final component of the RealNVP is the multiscale architecture. The squeeze operation reduces the spatial dimensions but increases the channel dimensions. After one of the blocks of coupling-squeeze-coupling that you have implemented above, half of the dimensions are factored out as latent variables, while the other half is further processed through subsequent layers. This results in latent variables that represent different scales of features in the model.\n", "\n", "![Multiscale architecture](image/multiscale.png)\n", "\n", "Figure 3. RealNVP creates latent variables at different scales by factoring out half of the dimensions at each scale. From the original paper.\n", "\n", "The final scale does not use the squeezing operation, and instead applies four affine coupling layers with alternating checkerboard masks.\n", "\n", "The multiscale architecture for two latent variable scales is implemented for you in the following bijector." ] },
{ "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "# Bijector to implement the multiscale architecture\n", "\n", "class RealNVPMultiScale(tfb.Bijector):\n", "    \n", "    def __init__(self, **kwargs):\n", "        super(RealNVPMultiScale, self).__init__(forward_min_event_ndims=3, **kwargs)\n", "\n", "        # First level\n", "        shape1 = (32, 32, 3)  # Input shape\n", "        shape2 = (16, 16, 12)  # Shape after the squeeze operation\n", "        shape3 = (16, 16, 6)  # Shape after factoring out the latent variable\n", "        self.conv_resnet1 = get_conv_resnet(shape1, 64)\n", "        self.conv_resnet2 = get_conv_resnet(shape1, 64)\n", "        self.conv_resnet3 = get_conv_resnet(shape1, 64)\n", "        self.conv_resnet4 = get_conv_resnet(shape2, 128)\n", "        self.conv_resnet5 = get_conv_resnet(shape2, 128)\n", "        self.conv_resnet6 = get_conv_resnet(shape2, 128)\n", "        self.squeeze = Squeeze()\n", "        self.block1 = realnvp_block([self.conv_resnet1, self.conv_resnet2,\n", "                                     self.conv_resnet3, self.conv_resnet4,\n", "                                     self.conv_resnet5, self.conv_resnet6], self.squeeze)\n", "\n", "        # Second level\n", "        self.conv_resnet7 = get_conv_resnet(shape3, 128)\n", "        self.conv_resnet8 = get_conv_resnet(shape3, 128)\n", "        self.conv_resnet9 = get_conv_resnet(shape3, 128)\n", "        self.conv_resnet10 = get_conv_resnet(shape3, 128)\n", "        self.coupling_layer1 = AffineCouplingLayer(self.conv_resnet7, 'checkerboard', 0)\n", "        self.coupling_layer2 = AffineCouplingLayer(self.conv_resnet8, 'checkerboard', 1)\n", "        self.coupling_layer3 = AffineCouplingLayer(self.conv_resnet9, 'checkerboard', 0)\n", "        self.coupling_layer4 = AffineCouplingLayer(self.conv_resnet10, 'checkerboard', 1)\n", "        self.block2 = tfb.Chain([self.coupling_layer4, self.coupling_layer3,\n", "                                 self.coupling_layer2, self.coupling_layer1])\n", "\n", "    def _forward(self, x):\n", "        h1 = self.block1.forward(x)\n", "        z1, h2 = tf.split(h1, 2, axis=-1)\n", "        z2 = self.block2.forward(h2)\n", "        return tf.concat([z1, z2], axis=-1)\n", "        \n", "    def _inverse(self, y):\n", "        z1, z2 = tf.split(y, 2, axis=-1)\n", "        h2 = self.block2.inverse(z2)\n", "        h1 = tf.concat([z1, h2], axis=-1)\n", "        return self.block1.inverse(h1)\n", "\n", "    def _forward_log_det_jacobian(self, x):\n", "        log_det1 = self.block1.forward_log_det_jacobian(x, event_ndims=3)\n", "        h1 = self.block1.forward(x)\n", "        _, h2 = tf.split(h1, 2, axis=-1)\n", "        log_det2 = self.block2.forward_log_det_jacobian(h2, event_ndims=3)\n", "        return log_det1 + log_det2\n", "\n", "    def _inverse_log_det_jacobian(self, y):\n", "        z1, z2 = tf.split(y, 2, axis=-1)\n", "        h2 = self.block2.inverse(z2)\n", "        log_det2 = self.block2.inverse_log_det_jacobian(z2, event_ndims=3)\n", "        h1 = tf.concat([z1, h2], axis=-1)\n", "        log_det1 = self.block1.inverse_log_det_jacobian(h1, event_ndims=3)\n", "        return log_det1 + log_det2\n", "\n", "    def _forward_event_shape_tensor(self, input_shape):\n", "        # Only block1 squeezes, and only once: (H, W, C) -> (H // 2, W // 2, 4 * C)\n", "        height, width, channels = input_shape[-3], input_shape[-2], input_shape[-1]\n", "        return height // 2, width // 2, 4 * channels\n", "\n", "    def _inverse_event_shape_tensor(self, output_shape):\n", "        height, width, channels = output_shape[-3], output_shape[-2], output_shape[-1]\n", "        return 2 * height, 2 * width, channels // 4" ] },
{ "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "# Create an instance of the multiscale architecture\n", "\n", "multiscale_bijector = RealNVPMultiScale()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Data preprocessing bijector\n", "\n", "We will also preprocess the image data before sending it through the RealNVP model. 
To do this, for a Tensor $x$ of pixel values in $[0, 1]^D$, we transform $x$ according to the following:\n", "\n", "$$\n", "T(x) = \\text{logit}\\left(\\alpha + (1 - 2\\alpha)x\\right),\\tag{7}\n", "$$\n", "\n", "where $\\alpha$ is a parameter, and the logit function is the inverse of the sigmoid function, and is given by \n", "\n", "$$\n", "\\text{logit}(p) = \\log (p) - \\log (1 - p).\n", "$$" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "def get_preprocess_bijector(alpha):\n", " \"\"\"\n", " This function should create a chained bijector that computes the \n", " transformation T in equation (7) above.\n", " This can be computed using built-in bijectors from the bijectors module.\n", " Your function should then return the chained bijector.\n", " \"\"\"\n", " return tfb.Chain([tfb.Invert(tfb.Sigmoid()), tfb.Shift(shift=alpha), tfb.Scale(scale=(1 - 2 * alpha))])" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "# Create an instance of the preprocess bijector\n", "\n", "preprocess = get_preprocess_bijector(0.05)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the RealNVP model\n", "\n", "Finally, we will train our RealNVP model on the LSUN bedroom data subset.\n", "\n", "We will use the following model class to help with the training process."
] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "# Helper class for training\n", "\n", "class RealNVPModel(Model):\n", "\n", " def __init__(self, **kwargs):\n", " super(RealNVPModel, self).__init__(**kwargs)\n", " self.preprocess = get_preprocess_bijector(0.05)\n", " self.realnvp_multiscale = RealNVPMultiScale()\n", " self.bijector = tfb.Chain([self.realnvp_multiscale, self.preprocess])\n", " \n", " def build(self, input_shape):\n", " output_shape = self.bijector(tf.expand_dims(tf.zeros(input_shape[1:]), axis=0)).shape\n", " self.base = tfd.Independent(tfd.Normal(loc=tf.zeros(output_shape[1:]), scale=1.),\n", " reinterpreted_batch_ndims=3)\n", " self._bijector_variables = (\n", " list(self.bijector.variables))\n", " self.flow = tfd.TransformedDistribution(\n", " distribution=self.base,\n", " bijector=tfb.Invert(self.bijector),\n", " )\n", " super(RealNVPModel, self).build(input_shape)\n", "\n", " def call(self, inputs, training=None, **kwargs):\n", " return self.flow\n", "\n", " def sample(self, batch_size):\n", " sample = self.base.sample(batch_size)\n", " return self.bijector.inverse(sample)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "# Create an instance of the RealNVPModel class\n", "\n", "realnvp_model = RealNVPModel()\n", "realnvp_model.build((1, 32, 32, 3))" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total trainable variables:\n", "315156\n" ] } ], "source": [ "# Compute the number of variables in the model\n", "\n", "print(\"Total trainable variables:\")\n", "print(sum([np.prod(v.shape) for v in realnvp_model.trainable_variables]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the model's `call` method returns the `TransformedDistribution` object. Also, we have set up our datasets to return the input image twice as a 2-tuple. 
This is so we can train our model with negative log-likelihood as normal." ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "# Define the negative log-likelihood loss function\n", "\n", "def nll(y_true, y_pred):\n", " return -y_pred.log_prob(y_true)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is recommended to use the GPU accelerator hardware on Colab to train this model, as it can take some time to train. Note that it is not required to train the model in order to pass this assignment. For optimal results, a larger model should be trained for longer." ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/20\n", "938/938 [==============================] - 97s 94ms/step - loss: -2471.9236 - val_loss: -5264.1875\n", "Epoch 2/20\n", "938/938 [==============================] - 86s 91ms/step - loss: -5968.5244 - val_loss: -6436.4800\n", "Epoch 3/20\n", "938/938 [==============================] - 85s 91ms/step - loss: -6787.8242 - val_loss: -7112.7544\n", "Epoch 4/20\n", "938/938 [==============================] - 87s 92ms/step - loss: -7243.9448 - val_loss: -7462.9360\n", "Epoch 5/20\n", "938/938 [==============================] - 86s 92ms/step - loss: -7573.0698 - val_loss: -7738.2793\n", "Epoch 6/20\n", "938/938 [==============================] - 86s 92ms/step - loss: -7801.1133 - val_loss: -7942.5913\n", "Epoch 7/20\n", "938/938 [==============================] - 86s 91ms/step - loss: -7978.1836 - val_loss: -8082.8511\n", "Epoch 8/20\n", "938/938 [==============================] - 86s 92ms/step - loss: -8132.2954 - val_loss: -8124.5166\n", "Epoch 9/20\n", "938/938 [==============================] - 88s 94ms/step - loss: -8242.4521 - val_loss: -8312.3848\n", "Epoch 10/20\n", "938/938 [==============================] - 89s 95ms/step - loss: -8339.6230 - val_loss: -8414.3115\n", "Epoch 11/20\n", "938/938 
[==============================] - 90s 96ms/step - loss: -8124.6929 - val_loss: -8372.0430\n", "Epoch 12/20\n", "938/938 [==============================] - 89s 95ms/step - loss: -8453.0684 - val_loss: -8446.0449\n", "Epoch 13/20\n", "938/938 [==============================] - 87s 92ms/step - loss: -8527.1289 - val_loss: -8574.1846\n", "Epoch 14/20\n", "938/938 [==============================] - 86s 92ms/step - loss: -8586.6006 - val_loss: -8650.7441\n", "Epoch 15/20\n", "938/938 [==============================] - 88s 94ms/step - loss: -8639.4971 - val_loss: -8638.4111\n", "Epoch 16/20\n", "938/938 [==============================] - 86s 91ms/step - loss: -8680.0918 - val_loss: -8715.3389\n", "Epoch 17/20\n", "938/938 [==============================] - 89s 94ms/step - loss: -8709.9248 - val_loss: -8777.8477\n", "Epoch 18/20\n", "938/938 [==============================] - 87s 92ms/step - loss: -8758.4932 - val_loss: -8764.7441\n", "Epoch 19/20\n", "938/938 [==============================] - 86s 92ms/step - loss: -8789.6807 - val_loss: -8831.9355\n", "Epoch 20/20\n", "938/938 [==============================] - 87s 92ms/step - loss: -8811.7500 - val_loss: -8797.5312\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Compile and train the model\n", "\n", "realnvp_model.compile(loss=nll, optimizer=Adam())\n", "realnvp_model.fit(train_ds, validation_data=val_ds, epochs=20)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "157/157 [==============================] - 3s 19ms/step - loss: -8791.1699\n" ] }, { "data": { "text/plain": [ "-8791.169921875" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate the model\n", "\n", "realnvp_model.evaluate(test_ds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate some samples" ] }, { 
"cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "# Sample from the model\n", "\n", "samples = realnvp_model.sample(8).numpy()" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Display the samples\n", "\n", "n_img = 8\n", "f, axs = plt.subplots(2, n_img // 2, figsize=(14, 7))\n", "\n", "for k, image in enumerate(samples):\n", " i = k % 2\n", " j = k // 2\n", " axs[i, j].imshow(np.clip(image, 0., 1.))\n", " axs[i, j].axis('off')\n", "f.subplots_adjust(wspace=0.01, hspace=0.03)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }