{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "rotary-animation",
"metadata": {},
"outputs": [],
"source": [
"from itertools import combinations\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import scikit_posthocs as sp\n",
"import seaborn as sns\n",
"from bootstrap import bootstrap_error_estimate\n",
"from scipy.stats import gmean\n",
"# Model comparison imports\n",
"from delong_ci import calc_auc_ci\n",
"from mlxtend.evaluate import cochrans_q, mcnemar, mcnemar_table\n",
"from mlxtend.evaluate import paired_ttest_5x2cv\n",
"#RDKit imports\n",
"from rdkit import Chem\n",
"from rdkit import Chem\n",
"from rdkit.Chem import AllChem\n",
"# ML imports\n",
"from lightgbm import LGBMClassifier\n",
"from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import roc_auc_score, roc_curve\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from tqdm.notebook import tqdm\n",
"import chembl_downloader\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib_inline\n",
"matplotlib_inline.backend_inline.set_matplotlib_formats('svg')\n",
"sns.set_style('whitegrid')\n",
"sns.set_context('talk')"
]
},
{
"cell_type": "markdown",
"id": "appointed-basis",
"metadata": {},
"source": [
"Query the ChEMBL database for hERG data. "
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "referenced-momentum",
"metadata": {},
"outputs": [],
"source": [
"query = \"\"\"select canonical_smiles, cs.molregno, md.chembl_id as mol_chembl_id, standard_relation, standard_value,\n",
"standard_type, standard_units, description, td.organism, assay_type, confidence_score,\n",
"td.pref_name, td.chembl_id as tgt_chembl_id\n",
"from activities act\n",
"join assays ass on act.assay_id = ass.assay_id\n",
"join target_dictionary td on td.tid = ass.tid\n",
"join compound_structures cs on cs.molregno = act.molregno\n",
"join molecule_dictionary md on md.molregno = cs.molregno\n",
"where ass.tid = 165\n",
"and assay_type in ('B','F')\n",
"and standard_value is not null\n",
"and standard_units = 'nM'\n",
"and act.standard_relation is not null\n",
"and standard_type = 'IC50'\n",
"and standard_relation = '='\"\"\""
]
},
{
"cell_type": "markdown",
"id": "exterior-denmark",
"metadata": {},
"source": [
"Make a reproducible query ChEMBL database."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "retired-accounting",
"metadata": {},
"outputs": [],
"source": [
"df_ok = chembl_downloader.query(query)"
]
},
{
"cell_type": "markdown",
"id": "finnish-paradise",
"metadata": {},
"source": [
"A quick sanity check to ensure that we extracted the data correctly. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "recent-bunny",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" canonical_smiles | \n",
" molregno | \n",
" mol_chembl_id | \n",
" standard_relation | \n",
" standard_value | \n",
" standard_type | \n",
" standard_units | \n",
" description | \n",
" organism | \n",
" assay_type | \n",
" confidence_score | \n",
" pref_name | \n",
" tgt_chembl_id | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" O=C(CCCN1CC=C(n2c(=O)[nH]c3ccccc32)CC1)c1ccc(F... | \n",
" 152751 | \n",
" CHEMBL1108 | \n",
" = | \n",
" 32.2 | \n",
" IC50 | \n",
" nM | \n",
" K+ channel blocking activity in human embryoni... | \n",
" Homo sapiens | \n",
" F | \n",
" 9 | \n",
" HERG | \n",
" CHEMBL240 | \n",
"
\n",
" \n",
" 1 | \n",
" O=C(O[C@@H]1C[C@@H]2C[C@H]3C[C@H](C1)N2CC3=O)c... | \n",
" 1543376 | \n",
" CHEMBL2368925 | \n",
" = | \n",
" 5950.0 | \n",
" IC50 | \n",
" nM | \n",
" K+ channel blocking activity in human embryoni... | \n",
" Homo sapiens | \n",
" F | \n",
" 9 | \n",
" HERG | \n",
" CHEMBL240 | \n",
"
\n",
" \n",
" 2 | \n",
" COc1ccc(CCN(C)CCCC(C#N)(c2ccc(OC)c(OC)c2)C(C)C... | \n",
" 1219 | \n",
" CHEMBL6966 | \n",
" = | \n",
" 143.0 | \n",
" IC50 | \n",
" nM | \n",
" K+ channel blocking activity in human embryoni... | \n",
" Homo sapiens | \n",
" F | \n",
" 9 | \n",
" HERG | \n",
" CHEMBL240 | \n",
"
\n",
" \n",
" 3 | \n",
" CCCCN(CCCC)CCC(O)c1cc2c(Cl)cc(Cl)cc2c2cc(C(F)(... | \n",
" 152728 | \n",
" CHEMBL1107 | \n",
" = | \n",
" 196.0 | \n",
" IC50 | \n",
" nM | \n",
" K+ channel blocking activity in Chinese hamste... | \n",
" Homo sapiens | \n",
" F | \n",
" 8 | \n",
" HERG | \n",
" CHEMBL240 | \n",
"
\n",
" \n",
" 4 | \n",
" CCOC(=O)N1CCC(=C2c3ccc(Cl)cc3CCc3cccnc32)CC1 | \n",
" 110803 | \n",
" CHEMBL998 | \n",
" = | \n",
" 173.0 | \n",
" IC50 | \n",
" nM | \n",
" K+ channel blocking activity in human embryoni... | \n",
" Homo sapiens | \n",
" F | \n",
" 9 | \n",
" HERG | \n",
" CHEMBL240 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" canonical_smiles molregno mol_chembl_id \\\n",
"0 O=C(CCCN1CC=C(n2c(=O)[nH]c3ccccc32)CC1)c1ccc(F... 152751 CHEMBL1108 \n",
"1 O=C(O[C@@H]1C[C@@H]2C[C@H]3C[C@H](C1)N2CC3=O)c... 1543376 CHEMBL2368925 \n",
"2 COc1ccc(CCN(C)CCCC(C#N)(c2ccc(OC)c(OC)c2)C(C)C... 1219 CHEMBL6966 \n",
"3 CCCCN(CCCC)CCC(O)c1cc2c(Cl)cc(Cl)cc2c2cc(C(F)(... 152728 CHEMBL1107 \n",
"4 CCOC(=O)N1CCC(=C2c3ccc(Cl)cc3CCc3cccnc32)CC1 110803 CHEMBL998 \n",
"\n",
" standard_relation standard_value standard_type standard_units \\\n",
"0 = 32.2 IC50 nM \n",
"1 = 5950.0 IC50 nM \n",
"2 = 143.0 IC50 nM \n",
"3 = 196.0 IC50 nM \n",
"4 = 173.0 IC50 nM \n",
"\n",
" description organism assay_type \\\n",
"0 K+ channel blocking activity in human embryoni... Homo sapiens F \n",
"1 K+ channel blocking activity in human embryoni... Homo sapiens F \n",
"2 K+ channel blocking activity in human embryoni... Homo sapiens F \n",
"3 K+ channel blocking activity in Chinese hamste... Homo sapiens F \n",
"4 K+ channel blocking activity in human embryoni... Homo sapiens F \n",
"\n",
" confidence_score pref_name tgt_chembl_id \n",
"0 9 HERG CHEMBL240 \n",
"1 9 HERG CHEMBL240 \n",
"2 9 HERG CHEMBL240 \n",
"3 8 HERG CHEMBL240 \n",
"4 9 HERG CHEMBL240 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_ok.head()"
]
},
{
"cell_type": "markdown",
"id": "french-caution",
"metadata": {},
"source": [
"Aggregate the results by taking the geometric mean of replicates. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "eligible-pressing",
"metadata": {},
"outputs": [],
"source": [
"grouper = df_ok.groupby([\"canonical_smiles\",\"molregno\"])\n",
"data_df = grouper['standard_value'].apply(gmean).to_frame(name = 'IC50').reset_index()"
]
},
{
"cell_type": "markdown",
"id": "electoral-fossil",
"metadata": {},
"source": [
"Add a new column with the pIC50"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "satisfied-hamburg",
"metadata": {},
"outputs": [],
"source": [
"data_df['pIC50'] = -np.log10(data_df.IC50*1e-9)"
]
},
{
"cell_type": "markdown",
"id": "miniature-advocate",
"metadata": {},
"source": [
"Set the \"Active\" field to 1 if the pIC50 >= 5 (10uM), otherwise 0"
]
},
{
"cell_type": "markdown",
"id": "exciting-supervisor",
"metadata": {},
"source": [
"Look at counts of active and inactive molecules"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "abroad-measure",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1 4601\n",
"0 2274\n",
"Name: Active, dtype: int64"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_df['Active'] = [1 if x >= 5 else 0 for x in data_df.pIC50]\n",
"data_df['Active'].value_counts()"
]
},
{
"cell_type": "markdown",
"id": "geographic-annex",
"metadata": {},
"source": [
"Visualize the activity distribution."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "completed-influence",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sns.displot(data=data_df,x='pIC50',hue='Active')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "geological-agent",
"metadata": {},
"source": [
"Define a function to get a Morgan fingerprint from a SMILES string"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "lucky-former",
"metadata": {},
"outputs": [],
"source": [
"def gen_fp(smi):\n",
" mol = Chem.MolFromSmiles(smi)\n",
" fp = None\n",
" if mol:\n",
" fp = AllChem.GetMorganFingerprintAsBitVect(mol,2)\n",
" return fp"
]
},
{
"cell_type": "markdown",
"id": "fuzzy-study",
"metadata": {},
"source": [
"Enable the \"progress_apply\" function that lets us use a progress bar."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "further-drama",
"metadata": {},
"outputs": [],
"source": [
"tqdm.pandas()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "automotive-simulation",
"metadata": {},
"outputs": [],
"source": [
"data_df['fp'] = data_df.canonical_smiles.apply(gen_fp)"
]
},
{
"cell_type": "markdown",
"id": "variable-cherry",
"metadata": {},
"source": [
"Remove rows where we didn't successfully generate a fingerprint."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "indian-intranet",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(6875, 6875)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"num_orig_rows = len(data_df)\n",
"data_df.dropna(inplace=True)\n",
"num_filtered_rows = len(data_df)\n",
"num_orig_rows, num_filtered_rows"
]
},
{
"cell_type": "markdown",
"id": "sonic-elder",
"metadata": {},
"source": [
"Build a quick ML model as a test. "
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "becoming-eight",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.7055758053840365"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train, test = train_test_split(data_df)\n",
"train_X = np.stack(train.fp)\n",
"train_y = train.Active.values\n",
"test_X = np.stack(test.fp)\n",
"test_Y = test.Active.values\n",
"lgbm = LGBMClassifier()\n",
"lgbm.fit(train_X,train_y)\n",
"pred = lgbm.predict(test_X)\n",
"auc = roc_auc_score(test_Y,pred)\n",
"auc"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "turkish-charge",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"prob = lgbm.predict_proba(test_X)[:,1]\n",
"false_pos_rate, true_pos_rate, thresholds = roc_curve(test_Y,prob)\n",
"ax = sns.lineplot(x=false_pos_rate,y=true_pos_rate)\n",
"ax.set_xlabel(\"True Positive Rate\")\n",
"ax.set_ylabel(\"False Positive Rate\")\n",
"# add the unity line\n",
"linemin = 0\n",
"linemax = 1\n",
"ax.plot([linemin,linemax],[linemin,linemax],color=\"grey\",linewidth=2,linestyle=\"--\");"
]
},
{
"cell_type": "markdown",
"id": "chronic-filename",
"metadata": {},
"source": [
"For 10 folds of cross-validation loop over the different model types, train and test models. "
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "finnish-venice",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0b43b2dcfed845f786c421ccb6a42a42",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"method_list = [KNeighborsClassifier, RandomForestClassifier, LGBMClassifier]\n",
"method_name_list = [x.__name__.replace(\"Classifier\",\"\") for x in method_list]\n",
"truth_list = []\n",
"pred_list = []\n",
"prob_list = []\n",
"cv_cycles = 10\n",
"for i in tqdm(range(0,cv_cycles)):\n",
" train, test = train_test_split(data_df)\n",
" train_X = np.stack(train.fp)\n",
" train_y = train.Active.values\n",
" test_X = np.stack(test.fp)\n",
" test_y = test.Active.values\n",
" cycle_pred = []\n",
" cycle_prob = []\n",
" for method, method_name in zip(method_list, method_name_list):\n",
" if method_name == \"XGB\":\n",
" cls = method(use_label_encoder=False, eval_metric='logloss', n_jobs=-1)\n",
" else:\n",
" cls = method(n_jobs=-1)\n",
" cls.fit(train_X, train_y)\n",
" cycle_pred.append(cls.predict(test_X))\n",
" cycle_prob.append(cls.predict_proba(test_X))\n",
"\n",
" truth_list.append(test.Active.values) \n",
" pred_list.append(cycle_pred)\n",
" prob_list.append(cycle_prob)"
]
},
{
"cell_type": "markdown",
"id": "broadband-waste",
"metadata": {},
"source": [
"Build a dataframe with the AUC values collected above. "
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "instant-people",
"metadata": {},
"outputs": [],
"source": [
"auc_result = []\n",
"for truth, prob in zip(truth_list,prob_list):\n",
" for name, p in zip(method_name_list, prob):\n",
" auc_result.append([name,roc_auc_score(truth,p[:,1])])\n",
"auc_df = pd.DataFrame(auc_result,columns=[\"Method\",\"AUC\"])"
]
},
{
"cell_type": "markdown",
"id": "concrete-oklahoma",
"metadata": {},
"source": [
"Here's what people typically do in the literature. Construct a bar char with the mean AUC for each of the methods, add a \"whisker\" to show the standard deviation. **In my mind this is not a good way to present data**. It doesn't adequately reflect the performance across cross validation folds and doesn't show whether the differences between methods are statistically significant. "
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "popular-radical",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ax = sns.barplot(x=\"Method\",y=\"AUC\",data=auc_df)\n",
"labels = [x.get_text() for x in ax.get_xticklabels()]\n",
"ax.set(xticklabels=labels)\n",
"ax.set(ylim=[0,1])\n",
"ax.set(xlabel=\"\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "alone-shadow",
"metadata": {},
"source": [
"Here's a somewhat better approach where we represent the distribution of AUC values as box plots. This is somewhat better, but we're still not making an adequate comparison between methods. "
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "charming-transformation",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sns.boxplot(x=\"Method\",y=\"AUC\",data=auc_df)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "opposite-double",
"metadata": {},
"source": [
"We can calculate a 95% confidence interval around each AUC using [DeLong's method](https://github.com/yandexdataschool/roc_comparison/blob/master/compare_auc_delong_xu.py). "
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "macro-november",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Cycle | \n",
" Method | \n",
" AUC | \n",
" LB | \n",
" UB | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0 | \n",
" KNeighbors | \n",
" 0.788374 | \n",
" 0.766868 | \n",
" 0.809879 | \n",
"
\n",
" \n",
" 1 | \n",
" 0 | \n",
" RandomForest | \n",
" 0.820932 | \n",
" 0.800811 | \n",
" 0.841053 | \n",
"
\n",
" \n",
" 2 | \n",
" 0 | \n",
" LGBM | \n",
" 0.814247 | \n",
" 0.793943 | \n",
" 0.834551 | \n",
"
\n",
" \n",
" 3 | \n",
" 1 | \n",
" KNeighbors | \n",
" 0.803592 | \n",
" 0.782375 | \n",
" 0.824808 | \n",
"
\n",
" \n",
" 4 | \n",
" 1 | \n",
" RandomForest | \n",
" 0.830648 | \n",
" 0.810878 | \n",
" 0.850418 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Cycle Method AUC LB UB\n",
"0 0 KNeighbors 0.788374 0.766868 0.809879\n",
"1 0 RandomForest 0.820932 0.800811 0.841053\n",
"2 0 LGBM 0.814247 0.793943 0.834551\n",
"3 1 KNeighbors 0.803592 0.782375 0.824808\n",
"4 1 RandomForest 0.830648 0.810878 0.850418"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"auc_result = []\n",
"for cycle, [truth, prob] in enumerate(zip(truth_list,prob_list)):\n",
" for name, p in zip(method_name_list, prob):\n",
" truth = np.array([int(x) for x in truth])\n",
" auc, (lb, ub) = calc_auc_ci(truth,p[:,1])\n",
" auc_result.append([cycle,name, auc, lb, ub])\n",
"auc_ci_df = pd.DataFrame(auc_result,columns=[\"Cycle\",\"Method\",\"AUC\",\"LB\",\"UB\"])\n",
"auc_ci_df.head()"
]
},
{
"cell_type": "markdown",
"id": "south-element",
"metadata": {},
"source": [
"Define a routine for displaying the AUC values and the associated 95% confidence intervals."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "tired-wheat",
"metadata": {},
"outputs": [],
"source": [
"def ci_pointplot(input_df, x_col=\"Cycle\", y_col=\"AUC\", hue_col=\"Method\", lb_col=\"LB\", ub_col=\"UB\"):\n",
" dodge_val = 0.25\n",
" palette_name = \"deep\"\n",
" cv_cycles = len(input_df[x_col].unique())\n",
" fig, ax = plt.subplots(1,1,figsize=(10, 5))\n",
" g = sns.pointplot(\n",
" x=x_col, y=y_col, hue=hue_col, data=input_df, dodge=dodge_val, join=False, palettte=palette_name, ax=ax)\n",
" colors = sns.color_palette(palette_name, len(input_df.Method.unique())) * cv_cycles\n",
" ax.axvline(0.5, ls=\"--\", c=\"gray\")\n",
" for x in np.arange(0.5, cv_cycles, 1):\n",
" ax.axvline(x, ls=\"--\", c=\"gray\")\n",
" y_val = input_df[y_col]\n",
" lb = y_val - input_df[lb_col]\n",
" ub = input_df[ub_col] - y_val\n",
" x_pos = []\n",
" for i in range(0, cv_cycles):\n",
" x_pos += [i - dodge_val / 2, i, i + dodge_val / 2]\n",
" _ = ax.errorbar(x_pos, y_val, yerr=[lb, ub], fmt=\"none\", capsize=0, ecolor=colors)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "pharmaceutical-commerce",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"