{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing\n",
    "Configuration, reference code for the (already applied) sample filtering, data loading and the train/validation split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "user = 'userName'\n",
    "model_name = 'metal_prediction_CNN'\n",
    "\n",
    "# Shared by the data generators and the model:\n",
    "# dimension = width of the FOFE input vector, cutoff = number of output classes.\n",
    "dimension = 800\n",
    "cutoff = 8\n",
    "\n",
    "# Seed numpy once, before any generator or layer is created, so numpy-based\n",
    "# randomness downstream is reproducible across Restart & Run All.\n",
    "np.random.seed(2017)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Removing invalid samples - already done\n",
    "# NOTE: kept for reference only (raw cell, never executed). The substring checks\n",
    "# below only work on the raw ligandId strings, i.e. before the targets are\n",
    "# mapped through label_dict.\n",
    "# Duplicate sequences whose label disagrees with the first occurrence are\n",
    "# dropped; the first occurrence itself is kept.\n",
    "\n",
    "duplicate_dict = {}\n",
    "rows_to_delete = []\n",
    "count = 0\n",
    "for i in range(seqs.shape[0]):\n",
    "    if 'X' in seqs[i] \\\n",
    "       or 'U' in seqs[i] \\\n",
    "       or '3CO' in target[i] \\\n",
    "       or '3NI' in target[i] \\\n",
    "       or 'FE2' in target[i] \\\n",
    "       or 'CU1' in target[i] \\\n",
    "       or 'MN3' in target[i] \\\n",
    "       or np.isnan(cluster_numbers[i]):\n",
    "        rows_to_delete.append(i)\n",
    "        count += 1\n",
    "    elif seqs[i] not in duplicate_dict.keys():\n",
    "        duplicate_dict[seqs[i]] = target[i]\n",
    "    else:\n",
    "        if target[i] != duplicate_dict[seqs[i]]:\n",
    "            rows_to_delete.append(i)\n",
    "            count += 1\n",
    "\n",
    "# df = df.drop(df.index[rows_to_delete])\n",
    "# df.to_parquet('Metal_all_20180601.parquet')\n",
    "seqs = np.delete(seqs, rows_to_delete, 0)\n",
    "target = np.delete(target, rows_to_delete)\n",
    "cluster_numbers = np.delete(cluster_numbers, rows_to_delete)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "print (\"Initializing global variables...\", end=' ')\n",
    "sys.stdout.flush()\n",
    "\n",
    "# Filepaths\n",
    "output_file = './logs/results.txt'\n",
    "hist_path = model_path = fig_path = './logs/'\n",
    "dict_path = './dictionaries/'\n",
    "\n",
    "print (\"Done\")\n",
    "print (\" Filepath set to ./logs/\")\n",
    "\n",
    "##################################################\n",
    "\n",
    "print (\"Importing modules...\", end=' ')\n",
    "import modules\n",
    "print (\"Done\")\n",
    "\n",
    "##################################################\n",
    "\n",
    "# Dictionaries are loaded BEFORE the data: label_dict is needed below to\n",
    "# encode the targets (it used to be loaded in a later cell, which broke\n",
    "# Restart & Run All with a NameError).\n",
    "print (\"Loading dictionaries...\", end=' ')\n",
    "sys.stdout.flush()\n",
    "\n",
    "# FOFE vocabulary, passed to the generators as their translator\n",
    "with open(dict_path + \"vocab_dict_fofe\", 'r') as fp:\n",
    "    vocab_dic_fofe = json.load(fp)\n",
    "\n",
    "# Maps each ligandId to the label value the generators consume\n",
    "# (presumably a class index -- confirm against modules.FOFEGenerator)\n",
    "with open(dict_path + \"metal_dict\", 'r') as fp:\n",
    "    label_dict = json.load(fp)\n",
    "\n",
    "print (\"Done\")\n",
    "\n",
    "##################################################\n",
    "\n",
    "print (\"Reading data from disk...\", end=' ')\n",
    "sys.stdout.flush()\n",
    "\n",
    "df = pd.read_parquet('./datasets/Metal_all_20180601.parquet')\n",
    "seqs = np.array(df.sequence)\n",
    "target = np.array(df.ligandId)\n",
    "cluster_numbers = np.array(df.clusterNumber90)\n",
    "\n",
    "# Replace each ligandId by a one-element list holding its mapped label.\n",
    "# A KeyError here means the dataset contains a ligand missing from metal_dict.\n",
    "for i in range(target.shape[0]):\n",
    "    target[i] = [label_dict[target[i]]]\n",
    "\n",
    "print (\"Done\")\n",
    "\n",
    "##################################################\n",
    "\n",
    "print (\"Performing cross validation split...\", end=' ')\n",
    "# NOTE: this is a single ordered hold-out split (no shuffling), not k-fold\n",
    "# cross validation. TODO(review): cluster_numbers is loaded but not used;\n",
    "# consider splitting by cluster so similar sequences do not land in both sets.\n",
    "ratio = 0.9\n",
    "split = int(ratio*len(seqs))\n",
    "train_seqs, val_seqs = seqs[:split], seqs[split:]\n",
    "train_label, val_label = target[:split], target[split:]\n",
    "print (\"Done\")\n",
    "print (\" Ratio :\", ratio)\n",
    "print (\" Train_range :\", 0, \"-\", split-1)\n",
    "print (\" Val_range :\", split, \"-\", len(seqs)-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "df.groupby('ligandId').count()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Generator\n",
    "- FOFE Encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_args = {'sequences': train_seqs,\n",
    "              'labels': train_label,\n",
    "              'translator': vocab_dic_fofe}\n",
    "val_args = {'sequences': val_seqs,\n",
    "            'labels': val_label,\n",
    "            'translator': vocab_dic_fofe}\n",
    "# Shapes come from the configuration cell so they always agree with the\n",
    "# model's Input layer (dimension) and softmax output (cutoff).\n",
    "common_args = {'batch_size': 100,\n",
    "               'input_shape': (dimension,),\n",
    "               'label_shape': (cutoff,),\n",
    "               'shuffle': True}\n",
    "\n",
    "train_gen = modules.FOFEGenerator(**train_args, **common_args)\n",
    "val_gen = modules.FOFEGenerator(**val_args, **common_args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model\n",
    "- CNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import keras\n",
    "from keras.models import Sequential, Model, load_model\n",
    "from keras.layers.convolutional import Conv2D, MaxPooling2D, Convolution1D, MaxPooling1D, AveragePooling2D\n",
    "from keras.layers import Activation, Flatten, Dense, Dropout, Reshape, Embedding, Input\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras.utils import np_utils\n",
    "from keras.optimizers import Adam, SGD, RMSprop\n",
    "# Visualization\n",
    "from keras.utils import plot_model\n",
    "\n",
    "# `dimension` and `cutoff` come from the configuration cell so the model\n",
    "# matches the generators' input_shape / label_shape.\n",
    "input_shape = (dimension,)\n",
    "\n",
    "input_0 = Input(shape=input_shape, dtype='float32')\n",
    "# View the flat vector as a 1 x dimension single-channel image so that Conv2D\n",
    "# kernels of shape (1, k) slide along it like 1-D convolutions.\n",
    "input_0_reshape = Reshape((1, dimension, 1), input_shape=(dimension,))(input_0)\n",
    "# Three parallel branches with kernel widths 3, 5 and 7 (2 filters each).\n",
    "conv2d_3 = Conv2D(2, (1, 3), padding='same')(input_0_reshape)\n",
    "conv2d_5 = Conv2D(2, (1, 5), padding='same')(input_0_reshape)\n",
    "conv2d_7 = Conv2D(2, (1, 7), padding='same')(input_0_reshape)\n",
    "\n",
    "x = keras.layers.concatenate([conv2d_3, conv2d_5, conv2d_7])\n",
    "x = Activation('relu')(x)\n",
    "x = Flatten()(x)\n",
    "x = Dense(cutoff, activation='relu')(x)\n",
    "output_0 = Dense(cutoff, activation='softmax')(x)\n",
    "\n",
    "model = Model(inputs=input_0, outputs=output_0)\n",
    "# end of the MODEL\n",
    "\n",
    "sgd = SGD(lr=0.01, momentum=0.9, decay=0, nesterov=False)\n",
    "model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])\n",
    "\n",
    "# model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "model_args = {'model': model,\n",
    "              'generators': [train_gen, val_gen],\n",
    "              'callbacks': [],\n",
    "              'post_train_args': {'user': user,\n",
    "                                  'model': model_name,\n",
    "                                  'result': output_file,\n",
    "                                  'fig_path': fig_path}}\n",
    "\n",
    "trainer = modules.Trainer(**model_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "# NOTE: this silences *all* warnings for the rest of the kernel session.\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "trainer.start(epoch=15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# open() below fails if the target directory is missing, so create it first.\n",
    "os.makedirs(\"./models\", exist_ok=True)\n",
    "\n",
    "# serialize model to JSON\n",
    "model_json = model.to_json()\n",
    "with open(\"./models/metal_predict.json\", \"w\") as json_file:\n",
    "    json_file.write(model_json)\n",
    "# serialize weights to HDF5\n",
    "model.save_weights(\"./models/metal_predict.h5\")\n",
    "print(\"Saved model to disk\")"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "from keras.models import model_from_json\n",
    "# load json and create model\n",
    "with open('./models/metal_predict.json', 'r') as json_file:\n",
    "    loaded_model_json = json_file.read()\n",
    "model = model_from_json(loaded_model_json)\n",
    "# load weights into new model\n",
    "model.load_weights(\"./models/metal_predict.h5\")\n",
    "print(\"Loaded model from disk\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}