{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SeqProp example: optimizing a toy GTT motif score\n",
    "\n",
    "SeqProp is run against a placeholder predictor that counts `GTT` motifs in the first 30 nt of each generated PWM, with the target score set to 10.\n",
    "\n",
    "The generated PWM is passed through a reverse-complement transform that concatenates its first 30 positions with their reverse complement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import keras\n",
    "from keras.models import Sequential, Model, load_model\n",
    "\n",
    "from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda\n",
    "from keras.layers import Conv2D, MaxPooling2D, Conv1D, MaxPooling1D, LSTM, ConvLSTM2D, GRU, BatchNormalization, LocallyConnected2D, Permute\n",
    "from keras.layers import Concatenate, Reshape, Softmax, Conv2DTranspose, Embedding, Multiply\n",
    "from keras.callbacks import ModelCheckpoint, EarlyStopping, Callback\n",
    "from keras import regularizers\n",
    "from keras import backend as K\n",
    "import keras.losses\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import ops\n",
    "\n",
    "import logging\n",
    "logging.getLogger('tensorflow').setLevel(logging.ERROR)\n",
    "\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import scipy.sparse as sp\n",
    "import scipy.io as spio\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import isolearn.keras as iso\n",
    "\n",
    "# The seqprop star imports are expected to provide build_generator, build_predictor,\n",
    "# build_loss_model, FlexibleSeqPropMonitor and get_target_entropy_sme used below.\n",
    "from seqprop.visualization import *\n",
    "from seqprop.generator import *\n",
    "from seqprop.predictor import *\n",
    "from seqprop.optimizer import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Placeholder predictor\n",
    "\n",
    "`load_dummy_predictor` returns the loader handed to `build_predictor`: a function of the sequence input tensor that returns `(predictor_inputs, predictor_outputs, weight_initializer)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dummy_predictor(motif_window=30) :\n",
    "    \"\"\"Return a placeholder SeqProp predictor loader that scores GTT motifs.\n",
    "\n",
    "    The score sums, over positions i in the first `motif_window` nucleotides,\n",
    "    pwm[i, G] * pwm[i+1, T] * pwm[i+2, T]. On a one-hot input this equals the\n",
    "    number of GTT occurrences lying entirely inside the window.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    motif_window : int, default 30\n",
    "        Number of leading positions scanned for the motif.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    callable\n",
    "        `_load_predictor_func(sequence_input)` returning\n",
    "        `(predictor_inputs, predictor_outputs, initialize_weights_func)`.\n",
    "    \"\"\"\n",
    "\n",
    "    def _initialize_predictor_weights(predictor_model) :\n",
    "        # Nothing to load: this placeholder predictor has no trained weights.\n",
    "        print(\"I am a placeholder function.\")\n",
    "\n",
    "    def _score_motif(pwm) :\n",
    "        # Assumes the nucleotide axis (-2) is ordered A, C, G, T: index 2 = G, index 3 = T.\n",
    "        # Once the nucleotide axis is indexed out, axis -2 is the position axis.\n",
    "        g_at_i = pwm[..., :motif_window - 2, 2, :]\n",
    "        t_at_i1 = pwm[..., 1:motif_window - 1, 3, :]\n",
    "        t_at_i2 = pwm[..., 2:motif_window, 3, :]\n",
    "        return K.sum(g_at_i * t_at_i1 * t_at_i2, axis=-2)\n",
    "\n",
    "    def _load_predictor_func(sequence_input) :\n",
    "        pred_score = Lambda(_score_motif)(sequence_input)\n",
    "\n",
    "        predictor_inputs = []\n",
    "        predictor_outputs = [pred_score]\n",
    "\n",
    "        return predictor_inputs, predictor_outputs, _initialize_predictor_weights\n",
    "\n",
    "    return _load_predictor_func"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss and PWM transforms\n",
    "\n",
    "The loss pulls the predicted motif score toward a target value, with an optional entropy penalty. Both transforms return the first `opt_len` PWM positions followed by their reverse complement; `get_revcomp_transform_opt` is the compact form."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_example_loss(target_score=10, opt_start=0, opt_end=30, opt_target_bits=1.8, entropy_weight=0.0) :\n",
    "    \"\"\"Build a SeqProp loss that drives the predicted motif score toward `target_score`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    target_score : float\n",
    "        Desired predictor output; the score loss is the squared error to this\n",
    "        value, averaged over axis 0.\n",
    "    opt_start, opt_end : int\n",
    "        PWM position range passed to the entropy penalty.\n",
    "    opt_target_bits : float\n",
    "        Target bits passed to the entropy penalty.\n",
    "    entropy_weight : float\n",
    "        Multiplier on the entropy penalty (0.0 zeroes that term).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    callable\n",
    "        `loss_func(predictor_outputs)` expecting\n",
    "        `[pwm_logits, pwm, sampled_pwm, pred_score]` and returning the total\n",
    "        loss summed over axis 0.\n",
    "    \"\"\"\n",
    "    opt_entropy_sme = get_target_entropy_sme(pwm_start=opt_start, pwm_end=opt_end, target_bits=opt_target_bits)\n",
    "\n",
    "    def loss_func(predictor_outputs) :\n",
    "        pwm_logits, pwm, sampled_pwm, pred_score = predictor_outputs\n",
    "\n",
    "        # Squared error between predicted and target motif score.\n",
    "        score_loss = 1.0 * K.mean((pred_score - target_score)**2, axis=0)\n",
    "\n",
    "        # Placeholder for a sequence-level penalty (unused in this example).\n",
    "        seq_loss = 0.0\n",
    "\n",
    "        entropy_loss = entropy_weight * opt_entropy_sme(pwm)\n",
    "\n",
    "        total_loss = score_loss + seq_loss + entropy_loss\n",
    "\n",
    "        return K.sum(total_loss, axis=0)\n",
    "\n",
    "    return loss_func\n",
    "\n",
    "\n",
    "def get_revcomp_transform(opt_len=30) :\n",
    "    \"\"\"Return a transform appending the reverse complement of the first `opt_len` positions.\n",
    "\n",
    "    Assumes the nucleotide axis (-2) is ordered A, C, G, T: the complement is\n",
    "    assembled band by band as T, G, C, A and then reversed along the position\n",
    "    axis (-3).\n",
    "    \"\"\"\n",
    "\n",
    "    def transform_func(pwm) :\n",
    "        pwm_opt = pwm[..., :opt_len, :, :]\n",
    "\n",
    "        a_band = K.expand_dims(pwm_opt[..., :, 0, :], axis=-2)\n",
    "        c_band = K.expand_dims(pwm_opt[..., :, 1, :], axis=-2)\n",
    "        g_band = K.expand_dims(pwm_opt[..., :, 2, :], axis=-2)\n",
    "        t_band = K.expand_dims(pwm_opt[..., :, 3, :], axis=-2)\n",
    "\n",
    "        rev_comp_pwm = K.concatenate([\n",
    "            t_band,\n",
    "            g_band,\n",
    "            c_band,\n",
    "            a_band\n",
    "        ], axis=-2)[..., ::-1, :, :]\n",
    "\n",
    "        return K.concatenate([pwm_opt, rev_comp_pwm], axis=-3)\n",
    "\n",
    "    return transform_func\n",
    "\n",
    "\n",
    "def get_revcomp_transform_opt(opt_len=30) :\n",
    "    \"\"\"Compact equivalent of `get_revcomp_transform`.\n",
    "\n",
    "    Reversing an A, C, G, T ordered nucleotide axis yields T, G, C, A (the\n",
    "    complement), so one `[..., ::-1, ::-1, :]` slice gives the reverse complement.\n",
    "    \"\"\"\n",
    "\n",
    "    def transform_func(pwm) :\n",
    "        pwm_opt = pwm[..., :opt_len, :, :]\n",
    "        return K.concatenate([pwm_opt, pwm_opt[..., ::-1, ::-1, :]], axis=-3)\n",
    "\n",
    "    return transform_func"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimization loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_seqprop(sequence_templates, loss_funcs, transform_funcs, n_sequences=1, n_samples=1, eval_mode='pwm', n_epochs=10, steps_per_epoch=100) :\n",
    "    \"\"\"Run SeqProp once per objective and collect the optimized PWMs and scores.\n",
    "\n",
    "    Objective k is defined by sequence_templates[k], loss_funcs[k] and\n",
    "    transform_funcs[k].\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    sequence_templates : list of str\n",
    "        One template per objective; its length sets the generator's\n",
    "        `seq_length` and the plotted PWM range.\n",
    "    loss_funcs : list of callable\n",
    "        Loss builders, e.g. from `get_example_loss`.\n",
    "    transform_funcs : list of callable\n",
    "        Passed to `build_generator` as `pwm_transform_func`.\n",
    "    n_sequences : int\n",
    "        Number of PWMs optimized per objective.\n",
    "    n_samples : int\n",
    "        Number of one-hot samples drawn per PWM at each gradient step.\n",
    "    eval_mode : str\n",
    "        'pwm' or 'sample', passed to `build_predictor`.\n",
    "    n_epochs, steps_per_epoch : int\n",
    "        Training schedule for `loss_model.fit`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (list, list)\n",
    "        Optimized PWMs and predicted scores, one entry per objective.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the three objective lists differ in length.\n",
    "    \"\"\"\n",
    "    if not (len(sequence_templates) == len(loss_funcs) == len(transform_funcs)) :\n",
    "        raise ValueError(\"sequence_templates, loss_funcs and transform_funcs must have the same length\")\n",
    "\n",
    "    optimized_pwms = []\n",
    "    optimized_scores = []\n",
    "\n",
    "    objectives = zip(sequence_templates, loss_funcs, transform_funcs)\n",
    "    for obj_ix, (sequence_template, loss_func, transform_func) in enumerate(objectives) :\n",
    "        print(\"Optimizing objective \" + str(obj_ix) + '...')\n",
    "\n",
    "        seq_length = len(sequence_template)\n",
    "\n",
    "        # Build Generator Network. The template string is repeated once per sequence\n",
    "        # and wrapped in a one-element list.\n",
    "        # NOTE(review): presumably build_generator splits this concatenated string into\n",
    "        # n_sequences templates; confirm before changing it to [sequence_template] * n_sequences.\n",
    "        _, seqprop_generator = build_generator(seq_length=seq_length, n_sequences=n_sequences, n_samples=n_samples, sequence_templates=[sequence_template * n_sequences], batch_normalize_pwm=True, pwm_transform_func=transform_func)\n",
    "\n",
    "        # Build Predictor Network and hook it on the generator PWM output tensor\n",
    "        _, seqprop_predictor = build_predictor(seqprop_generator, load_dummy_predictor(), n_sequences=n_sequences, n_samples=n_samples, eval_mode=eval_mode)\n",
    "\n",
    "        # Build Loss Model (In: Generator seed, Out: Loss function)\n",
    "        _, loss_model = build_loss_model(seqprop_predictor, loss_func)\n",
    "\n",
    "        opt = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999)\n",
    "\n",
    "        # The loss model outputs the loss itself, so the Keras loss passes it through (minimize self).\n",
    "        loss_model.compile(loss=lambda true, pred: pred, optimizer=opt)\n",
    "\n",
    "        # Track element 0 of the first predictor output when it has more than two dims.\n",
    "        measure_func = lambda pred_outs: pred_outs[0][0, :, :] if len(pred_outs[0].shape) > 2 else pred_outs[0]\n",
    "\n",
    "        callbacks = [\n",
    "            EarlyStopping(monitor='loss', min_delta=0.001, patience=5, verbose=0, mode='auto'),\n",
    "            FlexibleSeqPropMonitor(predictor=seqprop_predictor, plot_every_epoch=False, track_every_step=True, measure_func=measure_func, measure_name='Motif Score', plot_pwm_start=0, plot_pwm_end=seq_length, sequence_template=sequence_template, plot_pwm_indices=[0])\n",
    "        ]\n",
    "\n",
    "        # Fit Loss Model; there is no training data, so a single dummy target is used.\n",
    "        loss_model.fit(\n",
    "            [], np.ones((1, 1)),\n",
    "            epochs=n_epochs,\n",
    "            steps_per_epoch=steps_per_epoch,\n",
    "            callbacks=callbacks\n",
    "        )\n",
    "\n",
    "        # Retrieve optimized PWMs and predicted motif scores\n",
    "        _, optimized_pwm, _, pred_score = seqprop_predictor.predict(x=None, steps=1)\n",
    "\n",
    "        optimized_pwms.append(optimized_pwm)\n",
    "        optimized_scores.append(pred_score)\n",
    "\n",
    "    return optimized_pwms, optimized_scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "seq_template = 'NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNTTTTTTTTTTXXXXXXXXXXXXXXXXXXXX'\n",
    "\n",
    "print(\"Running optimization experiment\")\n",
    "\n",
    "# Number of PWMs to generate per objective\n",
    "n_sequences = 10\n",
    "# Number of one-hot sequences to sample from the PWM at each grad step\n",
    "n_samples = 1\n",
    "# Number of epochs per objective to optimize (increase, e.g. to 10, for a longer run)\n",
    "n_epochs = 1\n",
    "# Number of steps (grad updates) per epoch\n",
    "steps_per_epoch = 2000\n",
    "\n",
    "# Either 'pwm' for relaxed/continuous pwm input sent to the predictor, or 'sample' for proper sampled onehots\n",
    "eval_mode = 'sample'\n",
    "\n",
    "sequence_templates = [\n",
    "    seq_template\n",
    "]\n",
    "\n",
    "losses = [\n",
    "    get_example_loss(\n",
    "        target_score=10.,\n",
    "        opt_start=0,\n",
    "        opt_end=30,\n",
    "        opt_target_bits=1.8,\n",
    "        entropy_weight=0.0\n",
    "    )\n",
    "]\n",
    "\n",
    "transforms = [\n",
    "    get_revcomp_transform_opt()\n",
    "]\n",
    "\n",
    "pwms, scores = run_seqprop(\n",
    "    sequence_templates, losses, transforms,\n",
    "    n_sequences=n_sequences,\n",
    "    n_samples=n_samples,\n",
    "    eval_mode=eval_mode,\n",
    "    n_epochs=n_epochs,\n",
    "    steps_per_epoch=steps_per_epoch\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:aparent]",
   "language": "python",
   "name": "conda-env-aparent-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}