{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Training Tybalt models with two hidden layers\n", "\n", "**Gregory Way 2017**\n", "\n", "This script is an extension of [tybalt_vae.ipynb](tybalt_vae.ipynb). See that script for more details about the base model. Here, I train two alternative Tybalt models with different architectures. Both architectures have two hidden layers:\n", "\n", "1. **Model A**: 5000 input -> 100 hidden -> 100 latent -> 100 hidden -> 5000 output\n", "2. **Model B**: 5000 input -> 300 hidden -> 100 latent -> 300 hidden -> 5000 output\n", "\n", "This notebook trains _both_ models. The optimal hyperparameters for each model were selected independently through a grid search.\n", "\n", "The original Tybalt model compressed the 5000 input genes into 100 latent features through a single layer.\n", "\n", "Much of this script is inspired by the [keras variational_autoencoder.py example](https://github.com/fchollet/keras/blob/master/examples/variational_autoencoder.py).\n", "\n", "## Output\n", "\n", "For both models, the script will output:\n", "\n", "1. The learned latent feature matrix\n", "2. Encoder and decoder Keras models with trained weights\n", "3. 
An abstracted weight matrix" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import os\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "import tensorflow as tf\n", "from keras.layers import Input, Dense, Lambda, Layer, Activation\n", "from keras.layers.normalization import BatchNormalization\n", "from keras.models import Model, Sequential\n", "from keras import backend as K\n", "from keras import metrics, optimizers\n", "from keras.callbacks import Callback\n", "import keras\n", "\n", "import pydot\n", "import graphviz\n", "from keras.utils import plot_model\n", "from IPython.display import SVG\n", "from keras.utils.vis_utils import model_to_dot" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.0.5\n" ] }, { "data": { "text/plain": [ "'1.2.1'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(keras.__version__)\n", "tf.__version__" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "%matplotlib inline\n", "plt.style.use('seaborn-notebook')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "np.random.seed(123)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Functions and Classes\n", "\n", "These functions and classes facilitate connections between layers and support custom hyperparameters." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Function for the reparameterization trick to keep the model differentiable\n", "def sampling(args):\n", " \n", " import 
tensorflow as tf\n", " # Function with args required for Keras Lambda function\n", " z_mean, z_log_var = args\n", "\n", " # Draw epsilon of the same shape from a standard normal distribution\n", " epsilon = K.random_normal(shape=tf.shape(z_mean), mean=0.,\n", " stddev=epsilon_std)\n", " \n", " # The latent vector is non-deterministic and differentiable\n", " # with respect to z_mean and z_log_var\n", " z = z_mean + K.exp(z_log_var / 2) * epsilon\n", " return z\n", "\n", "\n", "class CustomVariationalLayer(Layer):\n", " \"\"\"\n", " Define a custom layer that computes the VAE loss and adds it to the model during training\n", " \"\"\"\n", " def __init__(self, var_layer, mean_layer, **kwargs):\n", " # https://keras.io/layers/writing-your-own-keras-layers/\n", " self.is_placeholder = True\n", " self.var_layer = var_layer\n", " self.mean_layer = mean_layer\n", " super(CustomVariationalLayer, self).__init__(**kwargs)\n", "\n", " def vae_loss(self, x_input, x_decoded):\n", " # Note: `original_dim` and the warm-up variable `beta` are globals defined elsewhere in the notebook\n", " reconstruction_loss = original_dim * metrics.binary_crossentropy(x_input, x_decoded)\n", " kl_loss = - 0.5 * K.sum(1 + self.var_layer - K.square(self.mean_layer) - \n", " K.exp(self.var_layer), axis=-1)\n", " return K.mean(reconstruction_loss + (K.get_value(beta) * kl_loss))\n", "\n", " def call(self, inputs):\n", " x = inputs[0]\n", " x_decoded = inputs[1]\n", " loss = self.vae_loss(x, x_decoded)\n", " self.add_loss(loss, inputs=inputs)\n", " # We won't actually use the output.\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Implementing Warm-up as described in Sonderby et al. 
LVAE\n", "\n", "This is modified code from https://github.com/fchollet/keras/issues/2595" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class WarmUpCallback(Callback):\n", " def __init__(self, beta, kappa):\n", " self.beta = beta\n", " self.kappa = kappa\n", " # Increment beta by kappa at the end of each epoch\n", " def on_epoch_end(self, epoch, logs={}):\n", " if K.get_value(self.beta) <= 1:\n", " K.set_value(self.beta, K.get_value(self.beta) + self.kappa)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Tybalt Model\n", "\n", "The following class implements a Tybalt model with the given input hyperparameters. Currently, only a two-hidden-layer model is supported." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Tybalt():\n", " \"\"\"\n", " Facilitates the training and output of a Tybalt model trained on TCGA RNA-seq gene expression data\n", " \"\"\"\n", " def __init__(self, original_dim, hidden_dim, latent_dim,\n", " batch_size, epochs, learning_rate, kappa, beta):\n", " self.original_dim = original_dim\n", " self.hidden_dim = hidden_dim\n", " self.latent_dim = latent_dim\n", " self.batch_size = batch_size\n", " self.epochs = epochs\n", " self.learning_rate = learning_rate\n", " self.kappa = kappa\n", " self.beta = beta\n", "\n", " def build_encoder_layer(self):\n", " # Input placeholder for RNA-seq data with a specific input size\n", " self.rnaseq_input = Input(shape=(self.original_dim, ))\n", "\n", " # The input layer is compressed into mean and log-variance vectors of size `latent_dim`\n", " # Each layer is initialized with glorot uniform weights, and each step (dense connection,\n", " # batch normalization, and relu activation) is applied separately\n", " # Each vector of length `latent_dim` is connected to the rnaseq input tensor\n", " hidden_dense_linear = Dense(self.hidden_dim, kernel_initializer='glorot_uniform')(self.rnaseq_input)\n", " 
hidden_dense_batchnorm = BatchNormalization()(hidden_dense_linear)\n", " hidden_encoded = Activation('relu')(hidden_dense_batchnorm)\n", "\n", " z_mean_dense_linear = Dense(self.latent_dim, kernel_initializer='glorot_uniform')(hidden_encoded)\n", " z_mean_dense_batchnorm = BatchNormalization()(z_mean_dense_linear)\n", " self.z_mean_encoded = Activation('relu')(z_mean_dense_batchnorm)\n", "\n", " z_log_var_dense_linear = Dense(self.latent_dim, kernel_initializer='glorot_uniform')(hidden_encoded)\n", " z_log_var_dense_batchnorm = BatchNormalization()(z_log_var_dense_linear)\n", " self.z_log_var_encoded = Activation('relu')(z_log_var_dense_batchnorm)\n", "\n", " # Return the encoded and randomly sampled z vector\n", " # Takes two keras layers as input to the custom sampling function layer with a `latent_dim` output\n", " self.z = Lambda(sampling, output_shape=(self.latent_dim, ))([self.z_mean_encoded, self.z_log_var_encoded])\n", " \n", " def build_decoder_layer(self):\n", " # The decoder is a relu hidden layer followed by a sigmoid reconstruction layer,\n", " # both with the default glorot uniform initialization\n", " self.decoder_model = Sequential()\n", " self.decoder_model.add(Dense(self.hidden_dim, activation='relu', input_dim=self.latent_dim))\n", " self.decoder_model.add(Dense(self.original_dim, activation='sigmoid'))\n", " self.rnaseq_reconstruct = self.decoder_model(self.z)\n", " \n", " def compile_vae(self):\n", " adam = optimizers.Adam(lr=self.learning_rate)\n", " vae_layer = CustomVariationalLayer(self.z_log_var_encoded,\n", " self.z_mean_encoded)([self.rnaseq_input, self.rnaseq_reconstruct])\n", " self.vae = Model(self.rnaseq_input, vae_layer)\n", " self.vae.compile(optimizer=adam, loss=None, loss_weights=[self.beta])\n", " \n", " def get_summary(self):\n", " self.vae.summary()\n", " \n", " def visualize_architecture(self, output_file):\n", " # Visualize the connections of the custom VAE model\n", " plot_model(self.vae, to_file=output_file)\n", " 
SVG(model_to_dot(self.vae).create(prog='dot', format='svg'))\n", " \n", " def train_vae(self):\n", " # Note: `rnaseq_train_df` and `rnaseq_test_df` are globals defined after the data are loaded\n", " self.hist = self.vae.fit(np.array(rnaseq_train_df),\n", " shuffle=True,\n", " epochs=self.epochs,\n", " batch_size=self.batch_size,\n", " validation_data=(np.array(rnaseq_test_df), np.array(rnaseq_test_df)),\n", " callbacks=[WarmUpCallback(self.beta, self.kappa)])\n", " \n", " def visualize_training(self, output_file):\n", " # Visualize training performance\n", " history_df = pd.DataFrame(self.hist.history)\n", " ax = history_df.plot()\n", " ax.set_xlabel('Epochs')\n", " ax.set_ylabel('VAE Loss')\n", " fig = ax.get_figure()\n", " fig.savefig(output_file)\n", " \n", " def compress(self, df):\n", " # Model to compress input\n", " self.encoder = Model(self.rnaseq_input, self.z_mean_encoded)\n", " \n", " # Encode rnaseq into the hidden/latent representation and save the output\n", " encoded_df = self.encoder.predict_on_batch(df)\n", " encoded_df = pd.DataFrame(encoded_df, columns=range(1, self.latent_dim + 1),\n", " index=df.index)\n", " return encoded_df\n", " \n", " def get_decoder_weights(self):\n", " # Build a generator that can sample from the learned distribution\n", " decoder_input = Input(shape=(self.latent_dim, )) # can generate from any sampled z vector\n", " _x_decoded_mean = self.decoder_model(decoder_input)\n", " self.decoder = Model(decoder_input, _x_decoded_mean)\n", " weights = []\n", " for layer in self.decoder.layers:\n", " weights.append(layer.get_weights())\n", " return weights\n", " \n", " def predict(self, df):\n", " return self.decoder.predict(np.array(df))\n", " \n", " def save_models(self, encoder_file, decoder_file):\n", " self.encoder.save(encoder_file)\n", " self.decoder.save(decoder_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Gene Expression Data" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "(10459, 5000)\n" ] }, { "data": { "text/html": [ "
\n", "| | RPS4Y1 | XIST | KRT5 | AGR2 | CEACAM5 | KRT6A | KRT14 | CEACAM6 | DDX3Y | KDM5D | ... | FAM129A | C8orf48 | CDK5R1 | FAM81A | C13orf18 | GDPD3 | SMAGP | C2orf85 | POU5F1B | CHST2 |\n", "|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n", "| TCGA-02-0047-01 | 0.678296 | 0.289910 | 0.034230 | 0.0 | 0.0 | 0.084731 | 0.031863 | 0.037709 | 0.746797 | 0.687833 | ... | 0.440610 | 0.428782 | 0.732819 | 0.634340 | 0.580662 | 0.294313 | 0.458134 | 0.478219 | 0.168263 | 0.638497 |\n", "| TCGA-02-0055-01 | 0.200633 | 0.654917 | 0.181993 | 0.0 | 0.0 | 0.100606 | 0.050011 | 0.092586 | 0.103725 | 0.140642 | ... | 0.620658 | 0.363207 | 0.592269 | 0.602755 | 0.610192 | 0.374569 | 0.722420 | 0.271356 | 0.160465 | 0.602560 |\n", "\n", "2 rows × 5000 columns
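The preview above shows expression values already scaled to [0, 1], which the sigmoid reconstruction and binary cross-entropy loss require. A minimal sketch of per-gene zero-one normalization, assuming min-max scaling was the preprocessing used (the notebook loads an already-processed file; the toy values and sample names below are made up):

```python
import pandas as pd

# Toy stand-in for the RNAseq matrix (samples x genes); the real matrix is 10459 x 5000
raw = pd.DataFrame({'RPS4Y1': [2.0, 8.0, 14.0], 'XIST': [1.0, 4.0, 3.0]},
                   index=['sample1', 'sample2', 'sample3'])

# Scale each gene (column) independently into the [0, 1] interval
scaled = (raw - raw.min()) / (raw.max() - raw.min())
```

Because the scaling is per gene, every column spans the full [0, 1] range regardless of its original dynamic range.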
\n", "\n", "| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ... | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 |\n", "|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n", "| TCGA-02-0047-01 | 1.804567 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.0 | 0.0 | 2.082371 | 0.000000 | 0.0 | ... | 5.628737 | 0.882983 | 0.0 | 0.000000 | 1.976136 | 1.912838 | 3.621609 | 0.000000 | 1.947124 | 1.840908 |\n", "| TCGA-02-0055-01 | 0.635178 | 0.0 | 1.591518 | 0.029515 | 1.855888 | 0.0 | 0.0 | 4.964176 | 1.741375 | 0.0 | ... | 1.160538 | 0.000000 | 0.0 | 1.639663 | 0.000000 | 0.000000 | 4.046312 | 0.304179 | 6.382465 | 0.919127 |\n", "\n", "2 rows × 100 columns
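Encoded features like these come from the z_mean layer alone (compress predicts with the encoder ending at z_mean_encoded); during training, z is instead drawn stochastically by the sampling function above. A numpy sketch of that reparameterization step (the function name and seed handling are illustrative, not from the notebook):

```python
import numpy as np

def sample_z(z_mean, z_log_var, epsilon_std=1.0, seed=123):
    # Draw epsilon from a standard normal, then shift and scale it so that
    # z ~ N(z_mean, exp(z_log_var)); z stays differentiable in both arguments
    rng = np.random.default_rng(seed)
    epsilon = epsilon_std * rng.standard_normal(np.shape(z_mean))
    return np.asarray(z_mean) + np.exp(np.asarray(z_log_var) / 2) * epsilon

z_mean = np.zeros((2, 100))      # batch of 2 samples, 100 latent features
z_log_var = np.zeros((2, 100))   # log-variance 0 -> unit variance
z = sample_z(z_mean, z_log_var)  # z.shape == (2, 100)
```

Setting epsilon_std to 0 recovers z_mean exactly, which is why the deterministic encoder output is a sensible latent feature matrix.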
\n", "\n", "| | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ... | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 |\n", "|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n", "| TCGA-02-0047-01 | 0.000000 | 0.0 | 1.250155 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.95250 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |\n", "| TCGA-02-0055-01 | 0.556497 | 0.0 | 0.056864 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.13956 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |\n", "\n", "2 rows × 100 columns
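Turning back to training behavior: WarmUpCallback raises the KL weight beta by kappa after every epoch for as long as beta <= 1, so beta can overshoot to just above 1 before it stops growing. A plain-Python sketch of the resulting schedule (the helper name is illustrative; the real callback mutates a Keras backend variable in place):

```python
def warmup_schedule(kappa, epochs, beta=0.0):
    # Mirror WarmUpCallback.on_epoch_end: add kappa after each epoch while beta <= 1,
    # recording the KL weight in effect during each epoch
    betas = []
    for _ in range(epochs):
        betas.append(beta)
        if beta <= 1:
            beta += kappa
    return betas

print(warmup_schedule(kappa=0.25, epochs=8))
# [0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.25, 1.25]
```

Starting beta at 0 means the first epochs optimize reconstruction almost alone, then the KL term is phased in, as in the Sonderby et al. warm-up scheme.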
\n", "
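As a final sanity check, the quantity computed by CustomVariationalLayer.vae_loss can be written out in plain numpy: a summed per-gene binary cross-entropy (equal to original_dim times Keras' feature-averaged binary_crossentropy) plus the KL divergence from a standard normal. This is a sketch, not the Keras code; the clipping constant eps stands in for Keras' internal epsilon:

```python
import numpy as np

def vae_loss(x, x_hat, z_mean, z_log_var, beta=1.0, eps=1e-7):
    # Summed per-gene binary cross-entropy (reconstruction term)
    x_hat = np.clip(x_hat, eps, 1 - eps)
    reconstruction = -np.sum(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat), axis=-1)
    # KL(N(z_mean, exp(z_log_var)) || N(0, 1)), summed over latent dimensions
    kl = -0.5 * np.sum(1 + z_log_var - np.square(z_mean) - np.exp(z_log_var), axis=-1)
    # Mean over the batch, with the warm-up weight beta on the KL term
    return float(np.mean(reconstruction + beta * kl))

x = np.full((2, 4), 0.5)                      # 2 samples, 4 "genes"
mu, logvar = np.zeros((2, 3)), np.zeros((2, 3))
loss = vae_loss(x, x, mu, logvar)             # KL is 0 when mu=0, logvar=0
```

With a perfect 0.5 reconstruction the loss reduces to 4·ln 2 (the summed cross-entropy), which makes the role of beta easy to see: it only rescales the KL penalty.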