{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from __future__ import print_function\n",
    "import keras\n",
    "from keras.models import Sequential, Model, load_model\n",
    "from keras import backend as K\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "import isolearn.keras as iso\n",
    "\n",
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.cm as cm\n",
    "\n",
    "#Import the APARENT helpers by name (rather than `import *`) so the origin of every\n",
    "#helper used in later cells is visible and no other names leak into the namespace.\n",
    "from aparent.predictor import get_aparent_encoder, find_polya_peaks, score_polya_peaks\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Load base APARENT model and input encoder\n",
    "\n",
    "MODEL_PATH = '../saved_models/aparent_large_lessdropout_all_libs_no_sampleweights.h5'\n",
    "\n",
    "#compile=False: the model is never trained in this notebook (it is only handed to the\n",
    "#APARENT helpers below), and the saved file carries no training configuration anyway,\n",
    "#so skipping compilation yields the same model without Keras' 'not compiled' warning.\n",
    "aparent_model = load_model(MODEL_PATH, compile=False)\n",
    "\n",
    "#NOTE(review): document what lib_bias=4 selects inside get_aparent_encoder\n",
    "aparent_encoder = get_aparent_encoder(lib_bias=4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(seq) = 1610\n"
     ]
    }
   ],
   "source": [
    "#Example: pA sites from APADB (gene = PSMC6) with a non-pA repeat sequence in between\n",
    "\n",
    "#Proximal and distal PAS sequences\n",
    "seq_prox = 'AGATAGTGGTATAAGAAAGCATTTCTTATGACTTATTTTGTATCATTTGTTTTCCTCATCTAAAAAGTTGAATAAAATCTGTTTGATTCAGTTCTCCTACATATATATTCTTGTCTTTTCTGAGTATATTTACTGTGGTCCTTTAGGTTCTTTAGCAAGTAAACTATTTGATAACCCAGATGGATTGTGGATTTTTGAATATTAT'\n",
    "seq_dist = 'TGGATTGTGGATTTTTGAATATTATTTTAAAATAGTACACATACTTAATGTTCATAAGATCATCTTCTTAAATAAAACATGGATGTGTGGGTATGTCTGTACTCCTCCTTTCAGAAAGTGTTTACATATTCTTCATCTACTGTGATTAAGCTCATTGTTGGTTAATTGAAAATATACATGCACATCCATAACTTTTTAAAGAGTA'\n",
    "\n",
    "#Filler: a 40-nt repeat unit tiled 30 times to space the two PAS sequences apart\n",
    "spacer_unit = 'GATTGTGGATTTTTGAGTATTATTTTATTATTGTTCGCAT'\n",
    "n_spacer_repeats = 30\n",
    "\n",
    "seq = ''.join([seq_prox, spacer_unit * n_spacer_repeats, seq_dist])\n",
    "\n",
    "print(\"len(seq) = {}\".format(len(seq)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Peak positions = [91, 177, 1497]\n",
      "PolyA profile shape = (1610,)\n"
     ]
    }
   ],
   "source": [
    "#Step 1: Detect peaks by scanning APARENT across the sequence (strided).\n",
    "\n",
    "#Peak-calling settings, collected in one dict so they are easy to inspect and tune\n",
    "peak_calling_kwargs = {\n",
    "    'sequence_stride' : 5,\n",
    "    'conv_smoothing' : True,\n",
    "    'peak_min_height' : 0.01,\n",
    "    'peak_min_distance' : 50,\n",
    "    'peak_prominence' : (0.01, None),\n",
    "}\n",
    "\n",
    "peak_ixs, polya_profile = find_polya_peaks(aparent_model, aparent_encoder, seq, **peak_calling_kwargs)\n",
    "\n",
    "print(\"Peak positions = {}\".format(peak_ixs))\n",
    "print(\"PolyA profile shape = {}\".format(polya_profile.shape))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Peak PAS scores (log odds) = [2.413, 0.965, 1.688]\n"
     ]
    }
   ],
   "source": [
    "#Step 2: Score the pA signals at the peak intensity positions across the sequence.\n",
    "\n",
    "#Scoring settings, collected in one dict so they are easy to inspect and tune\n",
    "peak_scoring_kwargs = {\n",
    "    'sequence_stride' : 1,\n",
    "    'strided_agg_mode' : 'max',\n",
    "    'iso_scoring_mode' : 'both',\n",
    "    'score_unit' : 'log',\n",
    "}\n",
    "\n",
    "peak_iso_scores = score_polya_peaks(aparent_model, aparent_encoder, seq, peak_ixs, **peak_scoring_kwargs)\n",
    "\n",
    "print(\"Peak PAS scores (log odds) = {}\".format(peak_iso_scores))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 864x216 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#Finally plot the predicted pA profile, annotate the peaks and plot the predicted PAS scores.\n",
    "\n",
    "#The panels and table below pair peaks with scores by index, so require one score per peak.\n",
    "assert len(peak_iso_scores) == len(peak_ixs), \"expected one PAS score per detected peak\"\n",
    "\n",
    "#Builtin range instead of np.arange(..., dtype=np.int): the np.int alias was deprecated in\n",
    "#NumPy 1.20 and removed in 1.24, where it raises AttributeError.\n",
    "peak_nums = list(range(len(peak_ixs)))\n",
    "peak_names = [\"Peak \" + str(i) for i in peak_nums]\n",
    "\n",
    "fig_profile, (ax_profile, ax_scores) = plt.subplots(1, 2, figsize=(12, 3), gridspec_kw={'width_ratios': [2, 1]})\n",
    "\n",
    "#Left panel: predicted pA profile with a dashed line at each detected peak position\n",
    "ax_profile.plot(np.arange(len(seq)), polya_profile, linewidth=2, color='orange')\n",
    "\n",
    "for pos in peak_ixs :\n",
    "    ax_profile.axvline(x=pos, ymin=0, linewidth=2, color='darkblue', linestyle='--')\n",
    "\n",
    "ax_profile.set_title(\"pA Cleavage Profile\", fontsize=16)\n",
    "\n",
    "#x ticks sit at the peak positions and are labelled with the peak number\n",
    "ax_profile.set_xticks(peak_ixs)\n",
    "ax_profile.set_xticklabels([str(i) for i in peak_nums], fontsize=14)\n",
    "ax_profile.tick_params(axis='y', labelsize=14)\n",
    "\n",
    "ax_profile.set_xlabel(\"Sequence Position\", fontsize=14)\n",
    "ax_profile.set_ylabel(\"pA Magnitude\", fontsize=14)\n",
    "\n",
    "#Right panel: one bar per detected peak, height = PAS score\n",
    "ax_scores.bar(peak_nums, peak_iso_scores, color='deepskyblue', edgecolor='black', linewidth=2)\n",
    "\n",
    "ax_scores.set_title(\"PAS Scores\", fontsize=16)\n",
    "\n",
    "ax_scores.set_xticks(peak_nums)\n",
    "ax_scores.set_xticklabels(peak_names, fontsize=14)\n",
    "ax_scores.tick_params(axis='y', labelsize=14)\n",
    "\n",
    "ax_scores.set_xlabel(\"Detected Peaks\", fontsize=14)\n",
    "ax_scores.set_ylabel(\"PAS Score\", fontsize=14)\n",
    "\n",
    "fig_profile.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "#Summary table: one row per peak with its position and PAS score\n",
    "peak_info = [[str(peak_ix) + \"nt\", peak_score] for peak_ix, peak_score in zip(peak_ixs, peak_iso_scores)]\n",
    "\n",
    "if len(peak_info) == 0 :\n",
    "    #matplotlib's table() indexes cellText[0], so it cannot build an empty table\n",
    "    print(\"No peaks detected; skipping summary table.\")\n",
    "else :\n",
    "    fig_table, ax_table = plt.subplots(figsize=(6, 4))\n",
    "\n",
    "    colors = plt.cm.BuPu(np.linspace(0, 0.5, len(peak_names)))\n",
    "    table = ax_table.table(cellText=peak_info, rowLabels=peak_names, rowColours=colors, colLabels=[\"Location\", \"PAS Score\"], loc='center')\n",
    "\n",
    "    #Without this, matplotlib may auto-shrink the font below the requested size at draw time\n",
    "    table.auto_set_font_size(False)\n",
    "    table.set_fontsize(14)\n",
    "    table.scale(1.5, 1.5)\n",
    "\n",
    "    ax_table.axis('tight')\n",
    "    ax_table.axis('off')\n",
    "\n",
    "    fig_table.tight_layout()\n",
    "    plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tensorflow]",
   "language": "python",
   "name": "conda-env-tensorflow-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}