{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "291a4e77", "metadata": {}, "outputs": [], "source": [ "import os\n", "import sys\n", "import copy\n", "import requests\n", "import numpy as np\n", "import pandas as pd\n", "from scipy import stats\n", "from pathlib import Path\n", "from tqdm.notebook import trange\n", "from tqdm.notebook import tqdm\n", "from urllib.request import urlretrieve" ] }, { "cell_type": "code", "execution_count": null, "id": "f0bb0447", "metadata": {}, "outputs": [], "source": [ "# git for functions loading and work path finding\n", "import git\n", "\n", "repo = git.Repo('.', search_parent_directories=True)\n", "work_path = Path(repo.working_tree_dir)\n", "if str(work_path) not in sys.path:\n", " sys.path.append(str(work_path))" ] }, { "cell_type": "code", "execution_count": null, "id": "0bdce859", "metadata": {}, "outputs": [], "source": [ "from function.seqfilter import SeqFilter\n", "from function.utilities import seq_aa_check\n", "from function.cutpondr import CutPONDR\n", "from function.ebi import EbiAPI\n", "\n", "# plot\n", "import plotly\n", "import plotly.graph_objects as go\n", "from plotly.graph_objs import Layout\n", "\n", "# attention map decorator\n", "# https://github.com/luo3300612/Visualizer\n", "from visualizer import get_local\n", "get_local.activate()" ] }, { "cell_type": "markdown", "id": "ac9ecfc9", "metadata": {}, "source": [ "# 1. Load pretrain model" ] }, { "cell_type": "code", "execution_count": null, "id": "874c1216", "metadata": {}, "outputs": [], "source": [ "# download pretrained weight from OSF: https://osf.io/jk29b/\n", "pretrain_model_path = work_path / 'trained_weight.pt'\n", "if not pretrain_model_path.is_file():\n", " url = 'https://osf.io/y2jh8/download'\n", " urlretrieve(url, str(pretrain_model_path))" ] }, { "cell_type": "code", "execution_count": null, "id": "0d8d2119", "metadata": {}, "outputs": [], "source": [ "# load model architecture from 2_model_training.ipynb by IPython's magic command\n", "pretrain_ipynb = str(work_path / '2_model_training.ipynb')" ] }, { "cell_type": "code", "execution_count": null, "id": "4b64578e", "metadata": {}, "outputs": [], "source": [ "%run $pretrain_ipynb" ] }, { "cell_type": "code", "execution_count": null, "id": "ae66e8bf", "metadata": {}, "outputs": [], "source": [ "moco = moco_builder.MoCo(base_encoder=AttenTorchScratch, dim=embed_dim, mlp_dim=moco_mlp_dim, T=nce_temp)\n", "moco.load_state_dict(torch.load(pretrain_model_path, map_location=torch.device('cpu')))\n", "base_encoder = copy.deepcopy(moco.base_encoder)\n", "base_encoder = base_encoder.eval()" ] }, { "cell_type": "markdown", "id": "e048b917", "metadata": {}, "source": [ "# 2. Inference" ] }, { "cell_type": "markdown", "id": "4e4f63e8", "metadata": {}, "source": [ "## 2.1 Function for attention map visualization " ] }, { "cell_type": "code", "execution_count": null, "id": "3a7c65bc", "metadata": {}, "outputs": [], "source": [ "# initialize PONDR web crawler and sequence length function\n", "pondr_driver = CutPONDR()\n", "seqfilter = SeqFilter()\n", "\n", "\n", "# attention score normalize function\n", "def minmaxnormalize(data):\n", " return (data - np.min(data)) / (np.max(data) - np.min(data))\n", "\n", "\n", "# exception for no disorder region\n", "class NoDisorderRegion(Exception):\n", " pass\n", "\n", "\n", "# attention map visualization\n", "class SeqEncodeForPlot():\n", "\n", " def __init__(self, input_type, sequence):\n", " '''\n", " input_type = ['uniprot','custom'] \n", " sequence = ['Q13148','AAAAAAAA']\n", " '''\n", "\n", " # get seq frag\n", " if input_type == 'uniprot':\n", " frag_out = self.__get_vis_seq_uid(sequence)\n", " elif input_type == 'custom':\n", " frag_out = self.__get_vis_seq_custom(sequence)\n", " self.title = frag_out['title']\n", " self.frag_seq = frag_out['frag_seq']\n", " self.seq_length = len(frag_out['frag_seq'])\n", " self.start_mod = frag_out['start_mod']\n", " self.end_mod = frag_out['end_mod']\n", "\n", " # get plot label\n", " self.plot_label = self.__get_plot_label(self.frag_seq, self.start_mod, self.end_mod)\n", "\n", " # encode\n", " self.encode_seq = seqprocess.seq_process_pipe([self.frag_seq])\n", "\n", " # get atten\n", " self.atten = self.__get_atten_from_decorator(self.encode_seq)\n", "\n", " # post process 1: get last cls atten mean\n", " self.last_cls_atten_mean = self.__get_last_cls_atten_mean(self.atten)\n", "\n", " # post process 2: cut cls\n", " self.last_cls_atten_mean_cut = self.__cut_cls(self.last_cls_atten_mean, self.seq_length)\n", "\n", " # post process 3: normalize\n", " self.last_cls_atten_mean_cut_normalize = self.__normalize_cls(self.last_cls_atten_mean_cut)\n", "\n", " # post process 4: feature sep\n", " sep_out = self.__get_feature_sep(self.frag_seq, self.last_cls_atten_mean_cut_normalize)\n", " self.sep_cls_for_plot = sep_out['sep_cls_for_plot']\n", " self.sep_feature_label = sep_out['sep_feature_label']\n", "\n", "###########get seq###########\n", " def __get_vis_seq_custom(self, custom_sequence):\n", " start_mod = 1\n", " end_mod = len(custom_sequence)\n", "\n", " return {\n", " \"title\": '{}, {}~{}'.format('custom', start_mod, end_mod),\n", " \"frag_seq\": custom_sequence,\n", " \"start_mod\": start_mod,\n", " \"end_mod\": end_mod\n", " }\n", "\n", " def __get_vis_seq_uid(self, uniprot_id):\n", " # retriving sequence\n", " seq_info = self.__get_sequence_online(uniprot_id, \"VSL2\", pondr_driver)\n", " gene_name = seq_info['gene_name']\n", " od_ident = seq_info['od_ident']\n", " protein_sequence = seq_info['protein_sequence']\n", "\n", " # od_ident length filter\n", " od_ident = seqfilter.length_filter_by_od_ident(od_ident, disorder_filter_length=40, order_filter_length=10)\n", " od_index = seqfilter.get_od_index(od_ident)['disorder_region']\n", "\n", " if len(od_index) == 0:\n", " raise NoDisorderRegion(\"This protein does not have disorder region\")\n", " elif len(od_index) == 1:\n", " print(\"Only 1 disorder region {}, automatically use that\".format(od_index[0]))\n", " od_index_i = 0\n", " else:\n", " choose_hint = ''\n", " for index, element in enumerate(od_index):\n", " choose_hint = choose_hint + str(\"region {}: {}\".format(index, element)) + \"\\n\" + ' '\n", " od_index_i = int(\n", " input(\"Please choose disorder region: \\n {}\".format(choose_hint)))\n", "\n", " # get protein seq by chosen index\n", " start = od_index[od_index_i]['start']\n", " end = od_index[od_index_i]['end']\n", "\n", " if (end - start) > 512:\n", " print(\"sequence length longer than 512: please specify the start and end\")\n", " start = int(input(\"start: \\n\"))\n", " end = int(input(\"end: \\n\"))\n", "\n", " frag_seq = protein_sequence[start:end]\n", "\n", " # mod start index\n", " start_mod = start + 1\n", " end_mod = end\n", "\n", " return {\n", " \"title\":'{}, {}, {}~{}'.format(uniprot_id, gene_name, start_mod, end_mod),\n", " \"frag_seq\":frag_seq,\n", " \"start_mod\":start_mod,\n", " \"end_mod\":end_mod\n", " }\n", "\n", " def __get_sequence_online(self, uniprot_id, od_ident_algorithm, pondr_driver):\n", " # get seqeucne by uniprot id\n", " a_protein = EbiAPI(uniprot_id)\n", "\n", " # get sequence\n", " protein_sequence = a_protein.protein_sequence\n", " protein_sequence = seq_aa_check(protein_sequence)\n", "\n", " gene_name = a_protein.gene_name\n", " print('protein: {}, gene name: {}, length: {}'.format(uniprot_id, gene_name, len(protein_sequence)))\n", "\n", " # use pondr to get od_ident\n", " print(\"sending to PONDR by algorithm {}...\".format(od_ident_algorithm))\n", " pondr_driver.cut(protein_sequence, protein_name='aa', algorithm=od_ident_algorithm)\n", " od_ident = pondr_driver.get_od_ident()\n", "\n", " return {\n", " \"uniprot_id\": uniprot_id,\n", " \"gene_name\": gene_name,\n", " \"protein_sequence\": protein_sequence,\n", " \"od_ident\": od_ident\n", " }\n", "###########get seq###########\n", "\n", " # lable plot\n", " def __get_plot_label(self, frag_seq, start_mod, end_mod):\n", " label = []\n", " for index, element in enumerate(frag_seq):\n", " label.append(\"{}_{}\".format(index + start_mod, element))\n", " \n", " return label\n", "\n", " # get atten\n", " def __get_atten_from_decorator(self, encoded_seq):\n", " get_local.clear()\n", " _ = base_encoder(encoded_seq, return_all_atten=True)[1]\n", "\n", " # tidy attention_map from decorator to (layer, head, length, length)\n", " atten = get_local.cache\n", " atten = np.stack(atten['scaled_dot_product_attention']).squeeze()\n", " atten = atten.reshape([num_heads, depth, atten.shape[-2], atten.shape[-2]]) # head, layer, length, length\n", " atten = np.moveaxis(atten, [0, 1], [1, 0]) # layer, head, length, length\n", " \n", " return torch.tensor(atten)\n", "\n", "###########post precess###########\n", " # get last cls atten\n", " def __get_last_cls_atten_mean(self, atten):\n", " last_cls_atten = atten[-1, :, 0, 1:]\n", " last_cls_atten_mean = last_cls_atten.mean(dim=0, keepdim=False)\n", " last_cls_atten_mean = last_cls_atten_mean.detach().numpy()\n", " \n", " return last_cls_atten_mean\n", "\n", " # cut cls\n", " def __cut_cls(self, last_cls_atten_mean, seq_length):\n", " last_cls_atten_mean = last_cls_atten_mean[:seq_length]\n", " \n", " return last_cls_atten_mean\n", "\n", " # normalize\n", " def __normalize_cls(self, last_cls_atten_mean):\n", " last_cls_atten_mean = minmaxnormalize(last_cls_atten_mean) #change normalize way\n", " \n", " return last_cls_atten_mean\n", "\n", " # feature sep\n", " def __get_feature_sep(self, frag_seq, last_cls_atten_mean):\n", " # sep feature label\n", " sep_feature_label = [\n", " \"Hydrophobic (A I L M P V)\", \n", " 'C', \n", " 'G', \n", " 'S T', \n", " 'Prion like (Q N)',\n", " 'Negative charge (D E)', \n", " 'Positive charge (R K H)',\n", " 'Aromatic (W Y F)', \n", " 'All features'\n", " ]\n", "\n", " # many conditions\n", " frag_seq = list(frag_seq)\n", " df = pd.DataFrame(frag_seq, columns=['frag_seq'])\n", " df['wyf'] = df['frag_seq'].apply(lambda x: x in ['W', 'Y', 'F'])\n", " df['rkh'] = df['frag_seq'].apply(lambda x: x in ['R', 'K', 'H'])\n", " df['de'] = df['frag_seq'].apply(lambda x: x in ['D', 'E'])\n", " df['qn'] = df['frag_seq'].apply(lambda x: x in ['Q', 'N'])\n", " df['st'] = df['frag_seq'].apply(lambda x: x in ['S', 'T'])\n", " df['g'] = df['frag_seq'].apply(lambda x: x in ['G'])\n", " df['c'] = df['frag_seq'].apply(lambda x: x in ['C'])\n", " df['lpavmi'] = df['frag_seq'].apply(lambda x: x in ['L', 'P', 'A', 'V', 'M', 'I'])\n", "\n", " all_cond_array = []\n", " all_cond_array.append(np.where(df['lpavmi'], last_cls_atten_mean, np.nan))\n", " all_cond_array.append(np.where(df['c'], last_cls_atten_mean, np.nan))\n", " all_cond_array.append(np.where(df['g'], last_cls_atten_mean, np.nan))\n", " all_cond_array.append(np.where(df['st'], last_cls_atten_mean, np.nan))\n", " all_cond_array.append(np.where(df['qn'], last_cls_atten_mean, np.nan))\n", " all_cond_array.append(np.where(df['de'], last_cls_atten_mean, np.nan))\n", " all_cond_array.append(np.where(df['rkh'], last_cls_atten_mean, np.nan))\n", " all_cond_array.append(np.where(df['wyf'], last_cls_atten_mean, np.nan))\n", " all_cond_array.append(last_cls_atten_mean)\n", "\n", " all_cond_array = np.stack(all_cond_array)\n", "\n", " return {\n", " \"sep_cls_for_plot\": all_cond_array,\n", " \"sep_feature_label\": sep_feature_label\n", " }\n", "###########post precess###########" ] }, { "cell_type": "markdown", "id": "25095e7d", "metadata": {}, "source": [ "## 2.2 Get attention map" ] }, { "cell_type": "code", "execution_count": null, "id": "22e8fbda", "metadata": {}, "outputs": [], "source": [ "# usage case1: directly input custom sequence\n", "# seqinfo = SeqEncodeForPlot(\"custom\",\"REPNQAFGSGNNSYSGSNSGAAIGWGSASNAGSGSGFNGGFGSSMDSKSSGWGM\")\n", "\n", "# usage case2: by entering Uniprot Entry ID, \n", "# sending to PONDR by VSL2 to get disorder regions which fit the length criteria (>=40)\n", "# seqinfo = SeqEncodeForPlot(\"uniprot\", \"Q13148\")" ] }, { "cell_type": "markdown", "id": "e944b3e6", "metadata": {}, "source": [ "# 3. Plot feature maps" ] }, { "cell_type": "code", "execution_count": null, "id": "32bd877f", "metadata": {}, "outputs": [], "source": [ "def plot(seqinfo):\n", " custom_color_scale = ['rgb(220,220,220)','rgb(255,243,59)','rgb(253,199,12)','rgb(243,144,63)','rgb(237,104,60)','rgb(233,62,58)']\n", " fig = go.Figure(data=[go.Heatmap(hovertemplate='head: %{y}<br>aa: %{x}<br>value: %{z}<extra></extra>',\n", " z=seqinfo.sep_cls_for_plot, \n", " x=seqinfo.plot_label,\n", " ygap = 1,\n", " y=seqinfo.sep_feature_label,\n", " colorscale=custom_color_scale,\n", " zmin=0,zmax=1 #\n", " )],\n", "# layout = Layout(paper_bgcolor='rgba(255,255,255,1)',plot_bgcolor='rgba(255,255,255,1)') #for image\n", " layout = Layout(paper_bgcolor='rgba(0,0,0,0)',plot_bgcolor='rgba(0,0,0,0)') #for html\n", " )\n", "\n", " # layout\n", " fig.update_layout(\n", " title=seqinfo.title,\n", " width=seqinfo.seq_length*10,\n", " height=450,\n", " yaxis_showticklabels=True,\n", " yaxis = dict(tickfont=dict(size=12, color='black')),\n", " xaxis_showticklabels=True,\n", "\n", " xaxis_tickmode='linear',\n", " font=dict(\n", " size=8\n", " ))\n", "\n", " #hover method\n", " fig.update_layout(hovermode='x unified')\n", "\n", " #no scale bar\n", " fig.update_traces(showscale=False)\n", "\n", "\n", " fig.update_layout(shapes=[\n", " dict(type= 'line',\n", " yref= 'y', y0= i, y1= i,\n", " xref= 'x', x0= -0.5, x1= len(seqinfo.plot_label),\n", " line=dict(\n", " color='rgb(200, 200, 200)',\n", " width=0.5,\n", " dash=\"dash\"\n", " )\n", " ) for i in [0.5,1.5,2.5,3.5,4.5,5.5,6.5,7.5] ])\n", "\n", " return fig" ] }, { "cell_type": "code", "execution_count": null, "id": "76016514", "metadata": {}, "outputs": [], "source": [ "# plot(seqinfo)" ] }, { "cell_type": "code", "execution_count": null, "id": "d75d2bca", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }