{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Variational Autoencoder for pan-cancer gene expression\n", "\n", "**Gregory Way 2017**\n", "\n", "This script trains and outputs results for a [variational autoencoder (VAE)](https://arxiv.org/abs/1312.6114)\n", "applied to gene expression data across 33 different cancer types from The Cancer Genome Atlas (TCGA).\n", "\n", "A VAE approximates the data generating function for the cancer data and learns the lower dimensional manifold a tumor occupies in gene expression space. By compressing the gene expression space into a lower dimensional space, the VAE would, ideally, learn biological principles, such as cancer hallmark pathway activations, that help explain how tumors are similar and different. The VAE is also a generative model with a latent space that can be interpolated to observe transitions between cancer states.\n", "\n", "The particular model trained in this notebook consists of gene expression input (5000 most variably expressed genes by median absolute deviation) compressed down into two length 100 vectors (mean and log variance encoded spaces), which are made differentiable through the reparameterization trick of sampling an epsilon vector from a standard normal distribution. The encoded layer is then decoded back to the original 5000 dimensions through a single reconstruction layer. I included a layer of batch normalization in the encoding step to prevent dead nodes. The encoding scheme also uses relu activation while the decoder uses a sigmoid activation to enforce positive activations. All weights are glorot uniform initialized. \n", "\n", "Another trick used here to encourage manifold learning is _warm start_ as discussed in [Sonderby et al. 2016](https://arxiv.org/abs/1602.02282). With warm starts, we add a parameter _beta_, which controls the contribution of the KL divergence loss in the total VAE loss (reconstruction + (beta * KL)). 
In this setting, the model begins training deterministically as a vanilla autoencoder (_beta_ = 0), and _beta_ is then increased linearly by a step of _kappa_ after each epoch until _beta_ = 1. After a parameter sweep, we observed that _kappa_ has little influence on training; therefore, we set _kappa_ = 1, which is a full VAE.\n", "\n", "Much of this script is inspired by the [keras variational_autoencoder.py example](https://github.com/fchollet/keras/blob/master/examples/variational_autoencoder.py)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import os\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "import tensorflow as tf\n", "from keras.layers import Input, Dense, Lambda, Layer, Activation\n", "from keras.layers.normalization import BatchNormalization\n", "from keras.models import Model\n", "from keras import backend as K\n", "from keras import metrics, optimizers\n", "from keras.callbacks import Callback\n", "import keras\n", "\n", "import pydot\n", "import graphviz\n", "from keras.utils import plot_model\n", "from keras_tqdm import TQDMNotebookCallback\n", "from IPython.display import SVG\n", "from keras.utils.vis_utils import model_to_dot" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.1.3\n" ] }, { "data": { "text/plain": [ "'1.4.0'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(keras.__version__)\n", "tf.__version__" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "plt.style.use('seaborn-notebook')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "sns.set(style=\"white\", color_codes=True)\n", "sns.set_context(\"paper\", 
rc={\"font.size\":14,\"axes.titlesize\":15,\"axes.labelsize\":20,\n", "                             'xtick.labelsize':14, 'ytick.labelsize':14})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Functions and Classes\n", "\n", "This will facilitate connections between layers and also custom hyperparameters" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Function for reparameterization trick to make model differentiable\n", "def sampling(args):\n", "    \n", "    import tensorflow as tf\n", "    # Function with args required for Keras Lambda function\n", "    z_mean, z_log_var = args\n", "\n", "    # Draw epsilon of the same shape from a standard normal distribution\n", "    epsilon = K.random_normal(shape=tf.shape(z_mean), mean=0.,\n", "                              stddev=epsilon_std)\n", "    \n", "    # The latent vector is non-deterministic and differentiable\n", "    # with respect to z_mean and z_log_var\n", "    z = z_mean + K.exp(z_log_var / 2) * epsilon\n", "    return z\n", "\n", "\n", "class CustomVariationalLayer(Layer):\n", "    \"\"\"\n", "    Define a custom layer that learns and performs the training\n", "    This function is borrowed from:\n", "    https://github.com/fchollet/keras/blob/master/examples/variational_autoencoder.py\n", "    \"\"\"\n", "    def __init__(self, **kwargs):\n", "        # https://keras.io/layers/writing-your-own-keras-layers/\n", "        self.is_placeholder = True\n", "        super(CustomVariationalLayer, self).__init__(**kwargs)\n", "\n", "    def vae_loss(self, x_input, x_decoded):\n", "        reconstruction_loss = original_dim * metrics.binary_crossentropy(x_input, x_decoded)\n", "        kl_loss = - 0.5 * K.sum(1 + z_log_var_encoded - K.square(z_mean_encoded) - \n", "                                K.exp(z_log_var_encoded), axis=-1)\n", "        # Multiply by the beta variable itself so that warm-up updates reach the loss;\n", "        # K.get_value(beta) would bake in the value beta had when the graph was built\n", "        return K.mean(reconstruction_loss + (beta * kl_loss))\n", "\n", "    def call(self, inputs):\n", "        x = inputs[0]\n", "        x_decoded = inputs[1]\n", "        loss = self.vae_loss(x, x_decoded)\n", "        self.add_loss(loss, inputs=inputs)\n", "        # We won't actually use the output.\n", "        return x" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Implementing Warm-up as described in Sonderby et al. LVAE\n", "\n", "This is modified code from https://github.com/fchollet/keras/issues/2595" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "class WarmUpCallback(Callback):\n", "    def __init__(self, beta, kappa):\n", "        self.beta = beta\n", "        self.kappa = kappa\n", "    # Behavior on each epoch: increase beta by kappa, capped at 1\n", "    def on_epoch_end(self, epoch, logs=None):\n", "        if K.get_value(self.beta) < 1:\n", "            K.set_value(self.beta, min(K.get_value(self.beta) + self.kappa, 1.0))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "np.random.seed(123)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Gene Expression Data" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(10459, 5000)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " RPS4Y1 XIST KRT5 AGR2 CEACAM5 KRT6A \\\n", "TCGA-02-0047-01 0.678296 0.289910 0.034230 0.0 0.0 0.084731 \n", "TCGA-02-0055-01 0.200633 0.654917 0.181993 0.0 0.0 0.100606 \n", "\n", " KRT14 CEACAM6 DDX3Y KDM5D ... FAM129A \\\n", "TCGA-02-0047-01 0.031863 0.037709 0.746797 0.687833 ... 0.440610 \n", "TCGA-02-0055-01 0.050011 0.092586 0.103725 0.140642 ... 0.620658 \n", "\n", " C8orf48 CDK5R1 FAM81A C13orf18 GDPD3 SMAGP \\\n", "TCGA-02-0047-01 0.428782 0.732819 0.634340 0.580662 0.294313 0.458134 \n", "TCGA-02-0055-01 0.363207 0.592269 0.602755 0.610192 0.374569 0.722420 \n", "\n", " C2orf85 POU5F1B CHST2 \n", "TCGA-02-0047-01 0.478219 0.168263 0.638497 \n", "TCGA-02-0055-01 0.271356 0.160465 0.602560 \n", "\n", "[2 rows x 5000 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rnaseq_file = os.path.join('data', 'pancan_scaled_zeroone_rnaseq.tsv.gz')\n", "rnaseq_df = pd.read_table(rnaseq_file, index_col=0)\n", "print(rnaseq_df.shape)\n", "rnaseq_df.head(2)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Split 10% test set randomly\n", "test_set_percent = 0.1\n", "rnaseq_test_df = rnaseq_df.sample(frac=test_set_percent)\n", "rnaseq_train_df = rnaseq_df.drop(rnaseq_test_df.index)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initialize variables and hyperparameters" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Set hyper parameters\n", "original_dim = rnaseq_df.shape[1]\n", "latent_dim = 100\n", "\n", "batch_size = 50\n", "epochs = 50\n", "learning_rate = 0.0005\n", "\n", "epsilon_std = 1.0\n", "beta = K.variable(0)\n", "kappa = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Encoder" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Input place holder for RNAseq data with specific input size\n", "rnaseq_input = 
Input(shape=(original_dim, ))\n", "\n", "# Input layer is compressed into a mean and log variance vector of size `latent_dim`\n", "# Each layer is initialized with glorot uniform weights and each step (dense connections,\n", "# batch norm, and relu activation) is applied separately\n", "# Each vector of length `latent_dim` is connected to the rnaseq input tensor\n", "z_mean_dense_linear = Dense(latent_dim, kernel_initializer='glorot_uniform')(rnaseq_input)\n", "z_mean_dense_batchnorm = BatchNormalization()(z_mean_dense_linear)\n", "z_mean_encoded = Activation('relu')(z_mean_dense_batchnorm)\n", "\n", "z_log_var_dense_linear = Dense(latent_dim, kernel_initializer='glorot_uniform')(rnaseq_input)\n", "z_log_var_dense_batchnorm = BatchNormalization()(z_log_var_dense_linear)\n", "z_log_var_encoded = Activation('relu')(z_log_var_dense_batchnorm)\n", "\n", "# Return the encoded and randomly sampled z vector\n", "# Takes two keras layers as input to the custom sampling function layer with a `latent_dim` output\n", "z = Lambda(sampling, output_shape=(latent_dim, ))([z_mean_encoded, z_log_var_encoded])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Decoder" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# The decoding layer is much simpler with a single layer and sigmoid activation\n", "decoder_to_reconstruct = Dense(original_dim, kernel_initializer='glorot_uniform', activation='sigmoid')\n", "rnaseq_reconstruct = decoder_to_reconstruct(z)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Connect the encoder and decoder to make the VAE\n", "\n", "The `CustomVariationalLayer()` includes the VAE loss function (reconstruction + (beta * KL)), which is what will drive our model to learn an interpretable representation of gene expression space.\n", "\n", "The VAE is compiled with an Adam optimizer and built-in custom loss function. 
The `loss_weights` parameter ensures beta is updated at each epoch end callback" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "__________________________________________________________________________________________________\n", "Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", "input_1 (InputLayer) (None, 5000) 0 \n", "__________________________________________________________________________________________________\n", "dense_1 (Dense) (None, 100) 500100 input_1[0][0] \n", "__________________________________________________________________________________________________\n", "dense_2 (Dense) (None, 100) 500100 input_1[0][0] \n", "__________________________________________________________________________________________________\n", "batch_normalization_1 (BatchNor (None, 100) 400 dense_1[0][0] \n", "__________________________________________________________________________________________________\n", "batch_normalization_2 (BatchNor (None, 100) 400 dense_2[0][0] \n", "__________________________________________________________________________________________________\n", "activation_1 (Activation) (None, 100) 0 batch_normalization_1[0][0] \n", "__________________________________________________________________________________________________\n", "activation_2 (Activation) (None, 100) 0 batch_normalization_2[0][0] \n", "__________________________________________________________________________________________________\n", "lambda_1 (Lambda) (None, 100) 0 activation_1[0][0] \n", " activation_2[0][0] \n", "__________________________________________________________________________________________________\n", "dense_3 (Dense) (None, 5000) 505000 lambda_1[0][0] \n", "__________________________________________________________________________________________________\n", 
"custom_variational_layer_1 (Cus [(None, 5000), (None 0 input_1[0][0] \n", " dense_3[0][0] \n", "==================================================================================================\n", "Total params: 1,506,000\n", "Trainable params: 1,505,600\n", "Non-trainable params: 400\n", "__________________________________________________________________________________________________\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/gway/anaconda3/envs/tybalt/lib/python3.5/site-packages/ipykernel/__main__.py:4: UserWarning: Output \"custom_variational_layer_1\" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to \"custom_variational_layer_1\" during training.\n" ] } ], "source": [ "adam = optimizers.Adam(lr=learning_rate)\n", "vae_layer = CustomVariationalLayer()([rnaseq_input, rnaseq_reconstruct])\n", "vae = Model(rnaseq_input, vae_layer)\n", "vae.compile(optimizer=adam, loss=None, loss_weights=[beta])\n", "\n", "vae.summary()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Visualize the connections of the custom VAE model\n", "output_model_file = os.path.join('figures', 'onehidden_vae_architecture.png')\n", "plot_model(vae, to_file=output_model_file)\n", "\n", "SVG(model_to_dot(vae).create(prog='dot', format='svg'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model\n", "\n", "The training data is shuffled after every epoch and 10% of the data is held out for calculating validation loss." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "CPU times: user 12min 51s, sys: 1min 34s, total: 14min 25s\n", "Wall time: 5min 16s\n" ] } ], "source": [ "%%time\n", "hist = vae.fit(np.array(rnaseq_train_df),\n", "               shuffle=True,\n", "               epochs=epochs,\n", "               verbose=0,\n", "               batch_size=batch_size,\n", "               validation_data=(np.array(rnaseq_test_df), None),\n", "               callbacks=[WarmUpCallback(beta, kappa),\n", "                          TQDMNotebookCallback(leave_inner=True, leave_outer=True)])" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Visualize training performance\n", "history_df = pd.DataFrame(hist.history)\n", "hist_plot_file = os.path.join('figures', 'onehidden_vae_training.pdf')\n", "ax = history_df.plot()\n", "ax.set_xlabel('Epochs')\n", "ax.set_ylabel('VAE Loss')\n", "fig = ax.get_figure()\n", "fig.savefig(hist_plot_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compile and output trained models\n", "\n", "We are interested in:\n", "\n", "1. The model to encode/compress the input gene expression data\n", " * Can possibly be used to compress other tumors\n", "2. The model to decode/decompress the latent space back into gene expression space\n", " * This is our generative model\n", "3. The latent space compression of all pan cancer TCGA samples\n", " * Non-linear reduced dimension representation of tumors can be used as features for various tasks\n", " * Supervised learning tasks predicting specific gene inactivation events\n", " * Interpolating across this space to observe how gene expression changes between two cancer states\n", "4. 
The weights used to compress each latent node\n", " * Potentially indicate learned biology differentially activating tumors" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Encoder model" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Model to compress input\n", "encoder = Model(rnaseq_input, z_mean_encoded)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Encode rnaseq into the hidden/latent representation - and save output\n", "encoded_rnaseq_df = encoder.predict_on_batch(rnaseq_df)\n", "encoded_rnaseq_df = pd.DataFrame(encoded_rnaseq_df, index=rnaseq_df.index)\n", "\n", "encoded_rnaseq_df.columns.name = 'sample_id'\n", "encoded_rnaseq_df.columns = encoded_rnaseq_df.columns + 1\n", "encoded_file = os.path.join('data', 'encoded_rnaseq_onehidden_warmup_batchnorm.tsv')\n", "encoded_rnaseq_df.to_csv(encoded_file, sep='\\t')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Decoder (generative) model" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# build a generator that can sample from the learned distribution\n", "decoder_input = Input(shape=(latent_dim, )) # can generate from any sampled z vector\n", "_x_decoded_mean = decoder_to_reconstruct(decoder_input)\n", "decoder = Model(decoder_input, _x_decoded_mean)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save the encoder/decoder models for future investigation" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "encoder_model_file = os.path.join('models', 'encoder_onehidden_vae.hdf5')\n", "decoder_model_file = os.path.join('models', 'decoder_onehidden_vae.hdf5')\n", "\n", "encoder.save(encoder_model_file)\n", "decoder.save(decoder_model_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Interpretation - Sanity Check\n", "\n", "\n", "### Observe the 
distribution of node activations.\n", "\n", "We want to ensure that the model is learning a distribution of feature activations, and not zeroing out features." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sample_id\n", "82 29584.533203\n", "5 29578.066406\n", "28 24349.054688\n", "16 24086.214844\n", "6 24084.980469\n", "63 23750.429688\n", "8 23380.603516\n", "57 23047.580078\n", "87 23010.695312\n", "37 22798.029297\n", "dtype: float32\n" ] }, { "data": { "text/plain": [ "sample_id\n", "91 14282.995117\n", "45 14082.565430\n", "34 13749.231445\n", "18 13509.525391\n", "97 13373.916992\n", "32 13035.963867\n", "92 12693.304688\n", "2 12593.957031\n", "20 11392.033203\n", "4 10859.074219\n", "dtype: float32" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# What are the most and least activated nodes\n", "sum_node_activity = encoded_rnaseq_df.sum(axis=0).sort_values(ascending=False)\n", "\n", "# Top 10 most active nodes\n", "print(sum_node_activity.head(10))\n", "\n", "# Bottom 10 least active nodes\n", "sum_node_activity.tail(10)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Histogram of node activity for all 100 latent features\n", "sum_node_activity.hist()\n", "plt.xlabel('Activation Sum')\n", "plt.ylabel('Count');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What does an example distribution of two latent features look like?" 
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Example of node activation distribution for the first two latent features\n", "plt.figure(figsize=(6, 6))\n", "plt.scatter(encoded_rnaseq_df.iloc[:, 0], encoded_rnaseq_df.iloc[:, 1])\n", "plt.xlabel('Latent Feature 1')\n", "plt.ylabel('Latent Feature 2');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Observe reconstruction fidelity" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " RPS4Y1 XIST KRT5 AGR2 CEACAM5 KRT6A \\\n", "TCGA-02-0047-01 0.691638 0.196529 0.194204 0.035744 0.027784 0.055622 \n", "TCGA-02-0055-01 0.099189 0.592460 0.188617 0.106664 0.044037 0.100698 \n", "\n", " KRT14 CEACAM6 DDX3Y KDM5D ... FAM129A \\\n", "TCGA-02-0047-01 0.064214 0.043093 0.721925 0.690764 ... 0.451367 \n", "TCGA-02-0055-01 0.115720 0.069524 0.075663 0.055188 ... 0.564506 \n", "\n", " C8orf48 CDK5R1 FAM81A C13orf18 GDPD3 SMAGP \\\n", "TCGA-02-0047-01 0.486186 0.788487 0.680068 0.575140 0.324611 0.385545 \n", "TCGA-02-0055-01 0.519069 0.646698 0.590475 0.624337 0.385787 0.570150 \n", "\n", " C2orf85 POU5F1B CHST2 \n", "TCGA-02-0047-01 0.605859 0.210192 0.644549 \n", "TCGA-02-0055-01 0.246259 0.181708 0.649352 \n", "\n", "[2 rows x 5000 columns]" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# How well does the model reconstruct the input RNAseq data\n", "input_rnaseq_reconstruct = decoder.predict(np.array(encoded_rnaseq_df))\n", "input_rnaseq_reconstruct = pd.DataFrame(input_rnaseq_reconstruct, index=rnaseq_df.index,\n", " columns=rnaseq_df.columns)\n", "input_rnaseq_reconstruct.head(2)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " gene mean gene abs(sum)\n", "PPAN-P2RY11 -0.020364 0.230511\n", "GSTT1 0.024710 0.229753\n", "GSTM1 0.005650 0.216558\n", "TBC1D3G -0.010398 0.194532\n", "RPS28 0.012242 0.176380" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reconstruction_fidelity = rnaseq_df - input_rnaseq_reconstruct\n", "\n", "gene_mean = reconstruction_fidelity.mean(axis=0)\n", "gene_abssum = reconstruction_fidelity.abs().sum(axis=0).divide(rnaseq_df.shape[0])\n", "gene_summary = pd.DataFrame([gene_mean, gene_abssum], index=['gene mean', 'gene abs(sum)']).T\n", "gene_summary.sort_values(by='gene abs(sum)', ascending=False).head()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Mean of gene reconstruction vs. absolute reconstructed difference per sample\n", "g = sns.jointplot('gene mean', 'gene abs(sum)', data=gene_summary, stat_func=None);" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:tybalt]", "language": "python", "name": "conda-env-tybalt-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.5" } }, "nbformat": 4, "nbformat_minor": 1 }