{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import keras\n", "from keras.models import Sequential, Model, load_model\n", "\n", "from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda\n", "from keras.layers import Conv2D, MaxPooling2D, Conv1D, MaxPooling1D, LSTM, ConvLSTM2D, GRU, BatchNormalization, LocallyConnected2D, Permute\n", "from keras.layers import Concatenate, Reshape, Softmax, Conv2DTranspose, Embedding, Multiply\n", "from keras.callbacks import ModelCheckpoint, EarlyStopping, Callback\n", "from keras import regularizers\n", "from keras import backend as K\n", "import keras.losses\n", "\n", "import tensorflow as tf\n", "from tensorflow.python.framework import ops\n", "\n", "import isolearn.keras as iso\n", "\n", "import numpy as np\n", "\n", "import tensorflow as tf\n", "import logging\n", "logging.getLogger('tensorflow').setLevel(logging.ERROR)\n", "\n", "import pandas as pd\n", "\n", "import os\n", "import pickle\n", "import numpy as np\n", "\n", "import scipy.sparse as sp\n", "import scipy.io as spio\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "import isolearn.keras as iso\n", "\n", "from seqprop.visualization import *\n", "from seqprop.generator import *\n", "from seqprop.predictor import *\n", "from seqprop.optimizer import *\n", "\n", "from definitions.deepsea import load_saved_predictor\n", "\n", "import warnings\n", "warnings.simplefilter(\"ignore\")\n", "\n", "from keras.backend.tensorflow_backend import set_session\n", "\n", "def contain_tf_gpu_mem_usage() :\n", " config = tf.ConfigProto()\n", " config.gpu_options.allow_growth = True\n", " sess = tf.Session(config=config)\n", " set_session(sess)\n", "\n", "contain_tf_gpu_mem_usage()\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "\n", "#Define target isoform loss 
function\n", "def get_earthmover_loss(target_output_ix, pwm_start=0, pwm_end=70, pwm_target_bits=1.8, pwm_entropy_weight=0.0) :\n", " \n", " punish_c = 0.0\n", " punish_g = 0.0\n", " punish_aa = 0.0\n", " \n", " entropy_mse = get_margin_entropy(pwm_start=pwm_start, pwm_end=pwm_end, min_bits=pwm_target_bits)\n", " \n", " punish_c_func = get_punish_c(pwm_start=pwm_start, pwm_end=pwm_end)\n", " punish_g_func = get_punish_g(pwm_start=pwm_start, pwm_end=pwm_end)\n", " punish_aa_func = get_punish_aa(pwm_start=pwm_start, pwm_end=pwm_end)\n", "\n", " def loss_func(predictor_outputs) :\n", " pwm_logits, pwm, sampled_pwm, pred_bind, pred_score = predictor_outputs\n", "\n", " #Specify costs\n", " fitness_loss = -1.0 * K.mean(pred_score[..., target_output_ix], axis=0)\n", " \n", " seq_loss = 0.0\n", " seq_loss += punish_c * K.mean(punish_c_func(sampled_pwm), axis=0)\n", " seq_loss += punish_g * K.mean(punish_g_func(sampled_pwm), axis=0)\n", " seq_loss += punish_aa * K.mean(punish_aa_func(sampled_pwm), axis=0)\n", " \n", " entropy_loss = pwm_entropy_weight * entropy_mse(pwm)\n", " \n", " #Compute total loss\n", " total_loss = fitness_loss + seq_loss + entropy_loss\n", "\n", " return K.reshape(K.sum(total_loss, axis=0), (1,))\n", " \n", " def val_loss_func(predictor_outputs) :\n", " pwm_logits, pwm, sampled_pwm, pred_bind, pred_score = predictor_outputs\n", "\n", " #Specify costs\n", " fitness_loss = -1.0 * K.mean(pred_score[..., target_output_ix], axis=0)\n", " \n", " seq_loss = 0.0\n", " seq_loss += punish_c * K.mean(punish_c_func(sampled_pwm), axis=0)\n", " seq_loss += punish_g * K.mean(punish_g_func(sampled_pwm), axis=0)\n", " seq_loss += punish_aa * K.mean(punish_aa_func(sampled_pwm), axis=0)\n", " \n", " entropy_loss = pwm_entropy_weight * entropy_mse(pwm)\n", " \n", " #Compute total loss\n", " total_loss = fitness_loss + seq_loss + entropy_loss\n", "\n", " return K.reshape(K.mean(total_loss, axis=0), (1,))\n", " \n", " return loss_func, val_loss_func\n", "\n", "\n", "def 
get_nop_transform() :\n", " \n", " def _transform_func(pwm) :\n", " \n", " return pwm\n", " \n", " return _transform_func\n", "\n", "class ValidationCallback(Callback):\n", " def __init__(self, val_name, val_loss_model, val_steps) :\n", " self.val_name = val_name\n", " self.val_loss_model = val_loss_model\n", " self.val_steps = val_steps\n", " \n", " self.val_loss_history = []\n", " \n", " #Track val loss\n", " self.val_loss_history.append(self.val_loss_model.predict(x=None, steps=self.val_steps)[0])\n", " \n", " def on_batch_end(self, batch, logs={}) :\n", " #Track val loss\n", " val_loss_value = self.val_loss_model.predict(x=None, steps=self.val_steps)[0]\n", " self.val_loss_history.append(val_loss_value)\n", "\n", "#Function for running SeqProp on a set of objectives to optimize\n", "def run_seqprop(target_output_ixs, sequence_templates, loss_funcs, val_loss_funcs, transform_funcs, n_sequences=1, n_samples=1, n_valid_samples=1, eval_mode='sample', normalize_logits=False, n_epochs=10, steps_per_epoch=100) :\n", " \n", " n_objectives = len(sequence_templates)\n", " \n", " seqprop_predictors = []\n", " valid_monitors = []\n", " train_histories = []\n", " valid_histories = []\n", " \n", " for obj_ix in range(n_objectives) :\n", " print(\"Optimizing objective \" + str(obj_ix) + '...')\n", " \n", " sequence_template = sequence_templates[obj_ix]\n", " loss_func = loss_funcs[obj_ix]\n", " val_loss_func = val_loss_funcs[obj_ix]\n", " transform_func = transform_funcs[obj_ix]\n", " target_output_ix = target_output_ixs[obj_ix]\n", " \n", " #Build Generator Network\n", " _, seqprop_generator = build_generator(seq_length=len(sequence_template), n_sequences=n_sequences, n_samples=n_samples, sequence_templates=[sequence_template * n_sequences], batch_normalize_pwm=normalize_logits, pwm_transform_func=transform_func, validation_sample_mode='sample')\n", " #for layer in seqprop_generator.layers :\n", " # if 'policy' not in layer.name :\n", " # layer.name += \"_trainversion\"\n", 
" #Validation generator: no PWM transform, 'sample' mode; master_generator=seqprop_generator presumably shares weights with the training generator -- TODO confirm\n", " _, valid_generator = build_generator(seq_length=len(sequence_template), n_sequences=n_sequences, n_samples=n_valid_samples, sequence_templates=[sequence_template * n_sequences], batch_normalize_pwm=normalize_logits, pwm_transform_func=None, validation_sample_mode='sample', master_generator=seqprop_generator)\n", " #Rename validation layers to avoid name collisions with the training graph\n", " for layer in valid_generator.layers :\n", " #if 'policy' not in layer.name :\n", " layer.name += \"_valversion\"\n", " \n", " #Build Predictor Network and hook it on the generator PWM output tensor\n", " _, seqprop_predictor = build_predictor(seqprop_generator, load_saved_predictor(model_path, library_context=None), n_sequences=n_sequences, n_samples=n_samples, eval_mode=eval_mode)\n", " #for layer in seqprop_predictor.layers :\n", " # if '_trainversion' not in layer.name and 'policy' not in layer.name :\n", " # layer.name += \"_trainversion\"\n", " _, valid_predictor = build_predictor(valid_generator, load_saved_predictor(model_path, library_context=None), n_sequences=n_sequences, n_samples=n_valid_samples, eval_mode='sample')\n", " for layer in valid_predictor.layers :\n", " if '_valversion' not in layer.name :# and 'policy' not in layer.name :\n", " layer.name += \"_valversion\"\n", " \n", " #Build Loss Model (In: Generator seed, Out: Loss function)\n", " _, loss_model = build_loss_model(seqprop_predictor, loss_func)\n", " _, valid_loss_model = build_loss_model(valid_predictor, val_loss_func)\n", " \n", " #Specify Optimizer to use\n", " #opt = keras.optimizers.SGD(lr=0.5)\n", " #opt = keras.optimizers.SGD(lr=0.1, momentum=0.9, decay=0, nesterov=True)\n", " opt = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999)\n", "\n", " #Compile Loss Model (Minimize self)\n", " loss_model.compile(loss=lambda true, pred: pred, optimizer=opt)\n", "\n", " #Logit transform helper; currently only referenced by the commented-out measure_func below\n", " def get_logit(p) :\n", " return np.log(p / (1. 
- p))\n", " \n", " #Specify callback entities\n", " #measure_func = lambda pred_outs: np.mean(get_logit(np.expand_dims(pred_outs[0], axis=0) if len(pred_outs[0].shape) <= 2 else pred_outs[0]), axis=0)\n", " #Extract the target output's score from predictor outputs and average over axis 0 (presumably the sample axis)\n", " measure_func = lambda pred_outs: np.mean(np.expand_dims(np.expand_dims(pred_outs[1][..., target_output_ix], axis=-1), axis=0) if len(pred_outs[1].shape) <= 2 else np.expand_dims(pred_outs[1][..., target_output_ix], axis=-1), axis=0)\n", " \n", " #train_monitor = FlexibleSeqPropMonitor(predictor=seqprop_predictor, plot_on_train_end=False, plot_every_epoch=False, track_every_step=True, measure_func=measure_func, measure_name='Binding Log Odds', plot_pwm_start=500, plot_pwm_end=700, sequence_template=sequence_template, plot_pwm_indices=np.arange(n_sequences).tolist(), figsize=(12, 1.0))\n", " valid_monitor = FlexibleSeqPropMonitor(predictor=valid_predictor, plot_on_train_end=True, plot_every_epoch=False, track_every_step=True, measure_func=measure_func, measure_name='Binding Log Odds', plot_pwm_start=425, plot_pwm_end=575, sequence_template=sequence_template, plot_pwm_indices=np.arange(n_sequences).tolist(), figsize=(12, 1.0))\n", " \n", " train_history = ValidationCallback('loss', loss_model, 1)\n", " valid_history = ValidationCallback('val_loss', valid_loss_model, 1)\n", " \n", " callbacks =[\n", " #EarlyStopping(monitor='loss', min_delta=0.001, patience=5, verbose=0, mode='auto'),\n", " valid_monitor,\n", " train_history,\n", " valid_history\n", " ]\n", " \n", " #Fit Loss Model\n", " _ = loss_model.fit(\n", " [], np.ones((1, 1)), #Dummy training example\n", " epochs=n_epochs,\n", " steps_per_epoch=steps_per_epoch,\n", " callbacks=callbacks\n", " )\n", " \n", " #Drop model references so the returned monitor/history objects do not pin the Keras models in memory -- TODO confirm intent\n", " valid_monitor.predictor = None\n", " train_history.val_loss_model = None\n", " valid_history.val_loss_model = None\n", " \n", " seqprop_predictors.append(seqprop_predictor)\n", " valid_monitors.append(valid_monitor)\n", " train_histories.append(train_history)\n", " valid_histories.append(valid_history)\n",
"\n", " return seqprop_predictors, valid_monitors, train_histories, valid_histories\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "#Specify file path to pre-trained predictor network\n", "#(the model file is resolved relative to the current working directory)\n", "\n", "save_dir = os.path.join(os.getcwd(), '')\n", "model_name = 'deepsea_keras.h5'\n", "model_path = os.path.join(save_dir, model_name)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "