{ "cells": [ { "cell_type": "markdown", "id": "90a6bd76-386e-484b-8a18-27a0fec48ca6", "metadata": {}, "source": [ "# Fit model to data\n", "We will fit a `Polyclonal` model to the RBD antibody mix we simulated.\n", "\n", "First, we read in the simulated data.\n", "Recall that we simulated both \"exact\" and \"noisy\" data, for libraries with several different average mutation rates, and at six different concentrations.\n", "Here we analyze the noisy data for the library with an average of 2 mutations per gene, measured at three of those concentrations (0.25, 1, and 4), as this is a fairly realistic approximation of a real experiment:" ] }, { "cell_type": "code", "execution_count": 1, "id": "18a5882b-3a28-4fdb-948c-2ab8dc9b82e7", "metadata": { "execution": { "iopub.execute_input": "2021-11-19T22:44:58.162010Z", "iopub.status.busy": "2021-11-19T22:44:58.161700Z", "iopub.status.idle": "2021-11-19T22:44:59.876110Z", "shell.execute_reply": "2021-11-19T22:44:59.875264Z", "shell.execute_reply.started": "2021-11-19T22:44:58.161877Z" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "  <thead>\n",
 "    <tr style=\"text-align: right;\"><th></th><th>library</th><th>aa_substitutions</th><th>concentration</th><th>prob_escape</th><th>IC90</th></tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr><th>0</th><td>avg2muts</td><td></td><td>0.25</td><td>0.050440</td><td>0.1128</td></tr>\n",
 "    <tr><th>1</th><td>avg2muts</td><td></td><td>0.25</td><td>0.143100</td><td>0.1128</td></tr>\n",
 "    <tr><th>2</th><td>avg2muts</td><td></td><td>0.25</td><td>0.054520</td><td>0.1128</td></tr>\n",
 "    <tr><th>3</th><td>avg2muts</td><td></td><td>0.25</td><td>0.084730</td><td>0.1128</td></tr>\n",
 "    <tr><th>4</th><td>avg2muts</td><td></td><td>0.25</td><td>0.041740</td><td>0.1128</td></tr>\n",
 "    <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
 "    <tr><th>89995</th><td>avg2muts</td><td>Y396T Y473L</td><td>4.00</td><td>0.000000</td><td>0.5832</td></tr>\n",
 "    <tr><th>89996</th><td>avg2muts</td><td>Y421W S359K</td><td>4.00</td><td>0.044600</td><td>0.5777</td></tr>\n",
 "    <tr><th>89997</th><td>avg2muts</td><td>Y449L V503T L335M</td><td>4.00</td><td>0.000000</td><td>1.0520</td></tr>\n",
 "    <tr><th>89998</th><td>avg2muts</td><td>Y473E L518F D427L</td><td>4.00</td><td>0.002918</td><td>1.1600</td></tr>\n",
 "    <tr><th>89999</th><td>avg2muts</td><td>Y505N H519T</td><td>4.00</td><td>0.000000</td><td>0.3505</td></tr>\n",
 "  </tbody>\n",
 "</table>\n",
 "<p>90000 rows × 5 columns</p>\n",
 "</div>
" ], "text/plain": [ " library aa_substitutions concentration prob_escape IC90\n", "0 avg2muts 0.25 0.050440 0.1128\n", "1 avg2muts 0.25 0.143100 0.1128\n", "2 avg2muts 0.25 0.054520 0.1128\n", "3 avg2muts 0.25 0.084730 0.1128\n", "4 avg2muts 0.25 0.041740 0.1128\n", "... ... ... ... ... ...\n", "89995 avg2muts Y396T Y473L 4.00 0.000000 0.5832\n", "89996 avg2muts Y421W S359K 4.00 0.044600 0.5777\n", "89997 avg2muts Y449L V503T L335M 4.00 0.000000 1.0520\n", "89998 avg2muts Y473E L518F D427L 4.00 0.002918 1.1600\n", "89999 avg2muts Y505N H519T 4.00 0.000000 0.3505\n", "\n", "[90000 rows x 5 columns]" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "import polyclonal\n", "\n", "noisy_data = (\n", " pd.read_csv('RBD_variants_escape_noisy.csv', na_filter=None)\n", " .query('library == \"avg2muts\"')\n", " .query('concentration in [0.25, 1, 4]')\n", " .reset_index(drop=True)\n", " )\n", "\n", "noisy_data" ] }, { "cell_type": "markdown", "id": "bc6da22e-9072-49d8-a32c-5c45c37a06aa", "metadata": {}, "source": [ "Initialize a `Polyclonal` model with these data, including three epitopes.\n", "We know from [prior work](https://www.nature.com/articles/s41467-021-24435-8) the three most important epitopes and a key mutation in each, so we use this prior knowledge to \"seed\" initial guesses that assign large escape values to a key site in each epitope:\n", "\n", " - site 417 for class 1 epitope, which is often the least important\n", " - site 484 for class 2 epitope, which is often the dominant one\n", " - site 444 for class 3 epitope, which is often the second most dominant one" ] }, { "cell_type": "code", "execution_count": 2, "id": "5298ee7a-028e-4a80-8d3b-89a721d894f8", "metadata": { "execution": { "iopub.execute_input": "2021-11-19T22:44:59.879499Z", "iopub.status.busy": "2021-11-19T22:44:59.879354Z", "iopub.status.idle": "2021-11-19T22:45:00.650138Z", "shell.execute_reply": 
"2021-11-19T22:45:00.649477Z", "shell.execute_reply.started": "2021-11-19T22:44:59.879479Z" }, "tags": [] }, "outputs": [], "source": [ "poly_abs = polyclonal.Polyclonal(data_to_fit=noisy_data,\n", " activity_wt_df=pd.DataFrame.from_records(\n", " [('1', 1.0),\n", " ('2', 3.0),\n", " ('3', 2.0),\n", " ],\n", " columns=['epitope', 'activity'],\n", " ),\n", " site_escape_df=pd.DataFrame.from_records(\n", " [('1', 417, 10.0),\n", " ('2', 484, 10.0),\n", " ('3', 444, 10.0),\n", " ],\n", " columns=['epitope', 'site', 'escape'],\n", " ),\n", " data_mut_escape_overlap='fill_to_data',\n", " )" ] }, { "cell_type": "markdown", "id": "8b75fb60-3317-4940-986a-dfdc9da10fc6", "metadata": {}, "source": [ "Now fit the `Polyclonal` model using the default optimization settings and logging output every 100 steps.\n", "Note how the fitting first just fits a site level model to estimate the average effects of mutations at each site, and then fits the full model:" ] }, { "cell_type": "code", "execution_count": 3, "id": "510d3eff-6f3e-4268-8cac-3fb4b5d37699", "metadata": { "execution": { "iopub.execute_input": "2021-11-19T22:45:00.652901Z", "iopub.status.busy": "2021-11-19T22:45:00.652787Z", "iopub.status.idle": "2021-11-19T22:48:17.361089Z", "shell.execute_reply": "2021-11-19T22:48:17.359984Z", "shell.execute_reply.started": "2021-11-19T22:45:00.652886Z" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# First fitting site-level model.\n", "# Starting optimization of 522 parameters at Tue Dec 14 09:51:18 2021.\n", " step time_sec loss fit_loss reg_escape regspread\n", " 0 0.058178 9144.4 9144.2 0.29701 0\n", " 100 6.9439 1337.1 1333.6 3.532 0\n", " 200 13.383 1313.1 1308.8 4.3331 0\n", " 300 20.263 1304.4 1299.2 5.1782 0\n", " 400 26.882 1301.1 1295 6.019 0\n", " 500 34.004 1297.9 1291.5 6.3803 0\n", " 600 40.905 1297.3 1290.7 6.5574 0\n", " 700 47.681 1296.5 1289.8 6.7586 0\n", " 800 54.673 1296.1 1289.2 6.8988 0\n", " 900 66.549 1295.6 1288.6 
6.9791 0\n", " 1000 73.827 1295.3 1288.3 7.0348 0\n", " 1100 81.151 1295.1 1288 7.1436 0\n", " 1200 88.188 1294.9 1287.6 7.326 0\n", " 1300 95.754 1294.6 1287.1 7.518 0\n", " 1357 99.936 1294.5 1287 7.5161 0\n", "# Successfully finished at Tue Dec 14 09:52:58 2021.\n", "# Starting optimization of 5799 parameters at Tue Dec 14 09:52:58 2021.\n", " step time_sec loss fit_loss reg_escape regspread\n", " 0 0.085596 1646.3 1551.7 94.634 2.2843e-29\n", " 100 8.8431 845.71 738.75 94.388 12.571\n", " 200 17.843 831.92 722.59 93.203 16.12\n", " 300 26.284 823.55 715.48 90.135 17.94\n", " 400 35.075 814.84 709.89 85.955 18.995\n", " 500 43.314 805.74 705.6 80.771 19.365\n", " 600 52.035 797.64 703.24 74.597 19.807\n", " 700 62.589 788.33 702.33 64.906 21.091\n", " 800 73.745 779.89 701.85 57.442 20.598\n", " 900 83.739 773.57 700.36 52.572 20.63\n", " 1000 96.357 769.17 698.58 49.943 20.654\n", " 1100 106.85 763.85 696.45 46.09 21.304\n", " 1200 117.55 756.77 691.6 43.319 21.852\n", " 1300 128.49 752.71 687.58 42.5 22.621\n", " 1400 139.88 748.93 682.52 42.569 23.844\n", " 1500 151.01 744.21 675.52 43.276 25.417\n", " 1600 162.39 737.87 665.84 44.426 27.603\n", " 1700 174.12 733.4 658.15 45.618 29.638\n", " 1800 186.24 728.47 650.56 46.77 31.138\n", " 1900 197.15 719.11 636.75 48.475 33.883\n", " 2000 208.01 705.27 618.49 50.204 36.569\n", " 2100 218.28 686.35 597.18 51.389 37.787\n", " 2200 229.13 673.08 584.12 51.914 37.044\n", " 2300 239.41 667.73 578.73 52.328 36.667\n", " 2400 251.85 666.16 577.31 52.558 36.291\n", " 2500 262.01 665.63 576.75 52.767 36.109\n", " 2600 272.64 665.39 576.32 52.995 36.073\n", " 2700 283.95 665.28 576.22 53.059 35.994\n", " 2723 287.47 665.26 576.22 53.063 35.979\n", "# Successfully finished at Tue Dec 14 09:57:45 2021.\n" ] } ], "source": [ "# NBVAL_IGNORE_OUTPUT\n", "opt_res, lossreg = poly_abs.fit(logfreq=100)" ] }, { "cell_type": "markdown", "id": "fdbff389", "metadata": {}, "source": [ "## Let's take some steps towards proximal gradient optimization" ] }, { 
"cell_type": "code", "execution_count": 4, "id": "fbae01f1", "metadata": {}, "outputs": [], "source": [ "# Copy so that later in-place edits cannot modify the fitted model's parameters\n", "previously_fit_params = poly_abs._params.copy()" ] }, { "cell_type": "markdown", "id": "f079f95c", "metadata": {}, "source": [ "We make some shim functions that will allow us to use Will's `AccProxGrad` code with the existing objective.\n", "\n", "Note that no new regularization is happening here: the extra penalty term is zero and its proximal operator is the identity, so we're just trying to use Will's code to fit the existing objective." ] }, { "cell_type": "code", "execution_count": 5, "id": "6219d5e4", "metadata": {}, "outputs": [], "source": [ "# lossreg.loss_reg(params) is indexed here as (objective value, gradient)\n", "def g_shim(params):\n", "    return lossreg.loss_reg(params)[0]\n", "\n", "def grad_shim(params):\n", "    return lossreg.loss_reg(params)[1]\n", "\n", "# The additional penalty is identically zero ...\n", "def zero_function(params):\n", "    return 0.\n", "\n", "# ... so its proximal operator leaves the parameters unchanged\n", "def trivial_prox(params, t):\n", "    return params" ] }, { "cell_type": "code", "execution_count": 6, "id": "4ce897fe", "metadata": {}, "outputs": [], "source": [ "from polyclonal import optimization\n", "\n", "prox_grad = optimization.AccProxGrad(g_shim, grad_shim, zero_function, trivial_prox, verbose=True)" ] }, { "cell_type": "markdown", "id": "fdde8872", "metadata": {}, "source": [ "When we start from the previously fit parameters, essentially nothing happens: the optimizer stops as soon as the relative change in the objective falls below its tolerance. That is a good thing." 
] }, { "cell_type": "code", "execution_count": 7, "id": "adb715a6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initial objective 6.652598e+02\n", "iteration 1, objective 6.653e+02, relative change 8.699e-08 \n", "relative change in objective function 8.7e-08 is within tolerance 1e-06 after 1 iterations\n" ] }, { "data": { "text/plain": [ "array([1.06698781, 3.22729169, 1.94937872, ..., 0.35142203, 0.69839784,\n", " 0.3097154 ])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prox_grad.run(previously_fit_params)" ] }, { "cell_type": "markdown", "id": "d588eb8f", "metadata": {}, "source": [ "Let's try with some naive starting parameters: all ones.\n", "\n", "This stops after 636 iterations at an objective of about 1108, well above the roughly 665 reached by the default fit. Perhaps it hit a local minimum." ] }, { "cell_type": "code", "execution_count": 9, "id": "c036a5ac", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initial objective 1.782483e+04\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/ematsen/re/polyclonal/polyclonal/polyclonal.py:1583: RuntimeWarning: overflow encountered in exp\n", " exp_minus_phi_e_v = numpy.exp(-phi_e_v)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "iteration 636, objective 1.108e+03, relative change 9.955e-07 \n", "relative change in objective function 1e-06 is within tolerance 1e-06 after 636 iterations\n" ] }, { "data": { "text/plain": [ "array([2.27645376, 2.27645376, 2.27645376, ..., 0.54204399, 0.54204399,\n", " 0.54204399])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "\n", "new_prox_grad = optimization.AccProxGrad(g_shim, grad_shim, zero_function, trivial_prox, verbose=True)\n", "new_params = np.ones(previously_fit_params.shape[0])\n", "new_prox_grad.run(new_params, max_iter=3000)" ] }, { "cell_type": "markdown", "id": "d9b7f7c3", "metadata": {}, "source": [ 
"OK, let's try something easier: start from a random perturbation (each parameter scaled by a factor between 0.9 and 1.1) of the fitted parameters. This works: the objective comes back down to 6.654e+02." ] }, { "cell_type": "code", "execution_count": 10, "id": "5dd5fab4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.97677739, 3.30086873, 1.76248978, ..., 0.32171498, 0.73367543,\n", " 0.28246869])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Caution: this is an alias, not a copy, so the in-place `*=` below\n", "# also modifies previously_fit_params\n", "new_params = previously_fit_params\n", "new_params *= np.random.uniform(0.9, 1.1, size=new_params.shape[0])\n", "new_params" ] }, { "cell_type": "code", "execution_count": 11, "id": "18916af4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initial objective 7.195510e+02\n", "iteration 766, objective 6.654e+02, relative change 9.976e-07 \n", "relative change in objective function 1e-06 is within tolerance 1e-06 after 766 iterations\n" ] }, { "data": { "text/plain": [ "array([1.06829617, 3.2274661 , 1.94871031, ..., 0.36086176, 0.69005368,\n", " 0.32927252])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_prox_grad.run(new_params, max_iter=3000)" ] }, { "cell_type": "markdown", "id": "7b18fc48", "metadata": {}, "source": [ "Interestingly, we don't seem to get the same parameters.\n", "\n", "Be careful interpreting the difference below, though: `new_params = previously_fit_params` binds a second name to the same array, so the in-place `*=` also perturbed `previously_fit_params`. The values shown are therefore the distance from the perturbed starting point, not from the originally fit parameters. Against the values displayed after the first `prox_grad.run` call above, the six visible entries differ by at most about 0.02." 
] }, { "cell_type": "code", "execution_count": 12, "id": "ab41a6b9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0.09151878, -0.07340263, 0.18622053, ..., 0.03914678,\n", " -0.04362175, 0.04680383])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_prox_grad.x - previously_fit_params" ] }, { "cell_type": "code", "execution_count": 13, "id": "986d3554", "metadata": {}, "outputs": [], "source": [ "import copy\n", "new_poly_abs = copy.deepcopy(poly_abs)" ] }, { "cell_type": "code", "execution_count": 14, "id": "e6e6b253", "metadata": {}, "outputs": [], "source": [ "new_poly_abs._params = new_prox_grad.x" ] }, { "cell_type": "markdown", "id": "11cd6f79", "metadata": {}, "source": [ "However, the fit is very close to optimal according to polyclonal: refitting from these parameters with `fit_site_level_first=False` lowers the loss only from 665.37 to 665.23." ] }, { "cell_type": "code", "execution_count": 15, "id": "0a6b44ff", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# Starting optimization of 5799 parameters at Tue Dec 14 10:03:40 2021.\n", " step time_sec loss fit_loss reg_escape regspread\n", " 0 0.11994 665.37 576.04 53.15 36.181\n", " 100 10.045 665.23 576.15 53.124 35.96\n", " 101 10.047 665.23 576.15 53.124 35.96\n", "# Successfully finished at Tue Dec 14 10:03:50 2021.\n" ] }, { "data": { "text/plain": [ "( fun: 665.2345519341873\n", " hess_inv: <5799x5799 LbfgsInvHessProduct with dtype=float64>\n", " jac: array([ 0.59013223, 0.93013372, 0.90565914, ..., -0.00113136,\n", " -0.00192544, -0.00261379])\n", " message: 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'\n", " nfev: 116\n", " nit: 100\n", " njev: 116\n", " status: 0\n", " success: True\n", " x: array([1.06887108, 3.22690596, 1.94935764, ..., 0.3569214 , 0.69360593,\n", " 0.31413904]),\n", " .LossReg at 0x79237ab84ca0>)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_poly_abs.fit(logfreq=100, fit_site_level_first=False)" ] }, { "cell_type": "markdown", 
"id": "fc42c153", "metadata": {}, "source": [ "I wonder what happens if we start with the site-level optimization first. I couldn't get this to work right away." ] }, { "cell_type": "code", "execution_count": 16, "id": "2d7fc66a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# First fitting site-level model.\n" ] }, { "ename": "ValueError", "evalue": "invalid mutation w331m", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0msite_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpoly_abs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msite_level_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0msite_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogfreq\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/re/polyclonal/polyclonal/polyclonal.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, loss_delta, reg_escape_weight, reg_escape_delta, reg_spread_weight, fit_site_level_first, scipy_minimize_kwargs, log, logfreq)\u001b[0m\n\u001b[1;32m 1193\u001b[0m \u001b[0mfit_kwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mkeys\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'self'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1194\u001b[0m 
\u001b[0mfit_kwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'fit_site_level_first'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1195\u001b[0;31m \u001b[0msite_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msite_level_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1196\u001b[0m \u001b[0msite_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mfit_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1197\u001b[0m self._params = self._params_from_dfs(\n", "\u001b[0;32m~/re/polyclonal/polyclonal/polyclonal.py\u001b[0m in \u001b[0;36msite_level_model\u001b[0;34m(self, aggregate_mut_escapes)\u001b[0m\n\u001b[1;32m 998\u001b[0m )\n\u001b[1;32m 999\u001b[0m site_escape_df = (\n\u001b[0;32m-> 1000\u001b[0;31m polyclonal.utils.site_level_variants(\n\u001b[0m\u001b[1;32m 1001\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmut_escape_df\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1002\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mrename\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m'mutation'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'aa_substitutions'\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/re/polyclonal/polyclonal/utils.py\u001b[0m in \u001b[0;36msite_level_variants\u001b[0;34m(df, original_alphabet, wt_char, mut_char)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0msite_subs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m 
\u001b[0msub\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msubs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0mwt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msite\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmutparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse_mut\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msub\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msite\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mwts\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mwts\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0msite\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mwt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m raise ValueError(f\"inconsistent wildtype at {site}: \"\n", "\u001b[0;32m~/re/polyclonal/polyclonal/utils.py\u001b[0m in \u001b[0;36mparse_mut\u001b[0;34m(self, mutation)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0mm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_mutation_regex\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfullmatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmutation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"invalid mutation {mutation}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m 
\u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'wt'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'site'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'mut'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: invalid mutation w331m" ] } ], "source": [ "site_model = poly_abs.site_level_model()\n", "site_model.fit(logfreq=100)" ] }, { "cell_type": "code", "execution_count": null, "id": "231e0a3f", "metadata": {}, "outputs": [], "source": [ "site_fit_params = poly_abs._params_from_dfs(\n", " activity_wt_df=site_model.activity_wt_df,\n", " mut_escape_df=(\n", " site_model.mut_escape_df\n", " [['epitope', 'site', 'escape']]\n", " .merge(poly_abs.mut_escape_df.drop(columns='escape'),\n", " on=['epitope', 'site'],\n", " how='right',\n", " validate='one_to_many',\n", " )\n", " ),\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "9603e020", "metadata": {}, "outputs": [], "source": [ "new_prox_grad.run(new_params)" ] }, { "cell_type": "markdown", "id": "31c2c9c9-8c6b-40d1-94d7-711e42f4bdd4", "metadata": {}, "source": [ "We can now visualize the resulting fits for the activities and escape values, and they can be compared back to the earlier \"true\" results used to simulate the data:" ] }, { "cell_type": "code", "execution_count": null, "id": "c8d114a7-cb14-4a0f-86a8-a5b9b69cdd42", "metadata": { "execution": { "iopub.execute_input": "2021-11-19T22:48:17.363285Z", 
"iopub.status.busy": "2021-11-19T22:48:17.363038Z", "iopub.status.idle": "2021-11-19T22:48:17.678161Z", "shell.execute_reply": "2021-11-19T22:48:17.677437Z", "shell.execute_reply.started": "2021-11-19T22:48:17.363267Z" }, "tags": [] }, "outputs": [], "source": [ "# NBVAL_IGNORE_OUTPUT\n", "poly_abs.activity_wt_barplot()" ] }, { "cell_type": "code", "execution_count": null, "id": "e88070ba-ecc3-43f4-8660-4564e235d89f", "metadata": { "execution": { "iopub.execute_input": "2021-11-19T22:48:17.679208Z", "iopub.status.busy": "2021-11-19T22:48:17.679048Z", "iopub.status.idle": "2021-11-19T22:48:21.928677Z", "shell.execute_reply": "2021-11-19T22:48:21.927705Z", "shell.execute_reply.started": "2021-11-19T22:48:17.679190Z" }, "tags": [] }, "outputs": [], "source": [ "# NBVAL_IGNORE_OUTPUT\n", "poly_abs.mut_escape_lineplot()" ] }, { "cell_type": "code", "execution_count": null, "id": "b564e1ed-3734-4059-acbb-e45e40626997", "metadata": { "execution": { "iopub.execute_input": "2021-11-19T22:48:21.930034Z", "iopub.status.busy": "2021-11-19T22:48:21.929819Z", "iopub.status.idle": "2021-11-19T22:48:25.576577Z", "shell.execute_reply": "2021-11-19T22:48:25.575765Z", "shell.execute_reply.started": "2021-11-19T22:48:21.930009Z" } }, "outputs": [], "source": [ "# NBVAL_IGNORE_OUTPUT\n", "poly_abs.mut_escape_heatmap()" ] }, { "cell_type": "markdown", "id": "dc0d7894-36f5-4e0b-85f2-9f96f4e22775", "metadata": {}, "source": [ "For these simulated data, we can also see how well the fit model does on the \"true\" simulated values from a library with a different (higher) mutation rate.\n", "We therefore read in the \"exact\" simulated data from a library with a different mutation rate:" ] }, { "cell_type": "code", "execution_count": null, "id": "70b08e11-2cfc-4978-8543-aa3fba9542be", "metadata": { "execution": { "iopub.execute_input": "2021-11-19T22:48:25.577612Z", "iopub.status.busy": "2021-11-19T22:48:25.577373Z", "iopub.status.idle": "2021-11-19T22:48:25.876463Z", "shell.execute_reply": 
"2021-11-19T22:48:25.875578Z", "shell.execute_reply.started": "2021-11-19T22:48:25.577593Z" } }, "outputs": [], "source": [ "exact_data = (\n", " pd.read_csv('RBD_variants_escape_exact.csv', na_filter=None)\n", " .query('library == \"avg3muts\"')\n", " .query('concentration in [0.25, 1, 0.5]')\n", " .reset_index(drop=True)\n", " )" ] }, { "cell_type": "markdown", "id": "93664ba9-2b0c-4ca2-a73c-008ab954cc13", "metadata": {}, "source": [ "First, we will compare the true simulated IC90 values to those predicted by the fit model.\n", "We make the comparison on a log scale, and clip IC90s at values >50 as likely to be way outside the dynamic range given the concentrations used:" ] }, { "cell_type": "code", "execution_count": null, "id": "eecc461e-dae7-42ec-b3b8-c2b2ac8d9a63", "metadata": { "execution": { "iopub.execute_input": "2021-11-19T22:48:25.877632Z", "iopub.status.busy": "2021-11-19T22:48:25.877427Z", "iopub.status.idle": "2021-11-19T22:48:30.891084Z", "shell.execute_reply": "2021-11-19T22:48:30.890219Z", "shell.execute_reply.started": "2021-11-19T22:48:25.877612Z" }, "tags": [] }, "outputs": [], "source": [ "import numpy\n", "\n", "from plotnine import *\n", "\n", "max_ic90 = 50\n", "\n", "# we only need the variants, not the concentration for the IC90 comparison\n", "ic90s = (exact_data[['aa_substitutions', 'IC90']]\n", " .assign(IC90=lambda x: x['IC90'].clip(upper=max_ic90))\n", " .drop_duplicates()\n", " )\n", "\n", "ic90s = poly_abs.icXX(ic90s, x=0.9, col='predicted_IC90', max_c=max_ic90)\n", "\n", "ic90s = (\n", " ic90s\n", " .assign(log_IC90=lambda x: numpy.log10(x['IC90']),\n", " predicted_log_IC90=lambda x: numpy.log10(x['predicted_IC90']),\n", " )\n", " )\n", "\n", "corr = ic90s['log_IC90'].corr(ic90s['predicted_log_IC90'])\n", "print(f\"Correlation is {corr:.2f}\")\n", "\n", "ic90_corr_plot = (\n", " ggplot(ic90s) +\n", " aes('log_IC90', 'predicted_log_IC90') +\n", " geom_point(alpha=0.1, size=1) +\n", " theme_classic() +\n", " theme(figure_size=(3, 
3))\n", " )\n", "\n", "_ = ic90_corr_plot.draw()" ] }, { "cell_type": "markdown", "id": "3bc4ee4b-91f4-4d7a-b02b-0ab59b52f259", "metadata": {}, "source": [ "Next we see how well the model predicts the variant-level escape probabilities $p_v\\left(c\\right)$, by reading in exact data from the simulations, and then making predictions of escape probabilities.\n", "We both examine and plot the correlations:" ] }, { "cell_type": "code", "execution_count": null, "id": "327b04a6-680a-4711-b1de-1a23dc1f9983", "metadata": { "execution": { "iopub.execute_input": "2021-11-19T22:48:30.892402Z", "iopub.status.busy": "2021-11-19T22:48:30.892107Z", "iopub.status.idle": "2021-11-19T22:48:34.267781Z", "shell.execute_reply": "2021-11-19T22:48:34.266933Z", "shell.execute_reply.started": "2021-11-19T22:48:30.892381Z" }, "tags": [] }, "outputs": [], "source": [ "exact_vs_pred = poly_abs.prob_escape(variants_df=exact_data)\n", "\n", "print(f\"Correlations at each concentration:\")\n", "display(exact_vs_pred\n", " .groupby('concentration')\n", " .apply(lambda x: x['prob_escape'].corr(x['predicted_prob_escape']))\n", " .rename('correlation')\n", " .reset_index()\n", " .round(2)\n", " )\n", "\n", "pv_corr_plot = (\n", " ggplot(exact_vs_pred) +\n", " aes('prob_escape', 'predicted_prob_escape') +\n", " geom_point(alpha=0.1, size=1) +\n", " facet_wrap('~ concentration', nrow=1) +\n", " theme_classic() +\n", " theme(figure_size=(3 * exact_vs_pred['concentration'].nunique(), 3))\n", " )\n", "\n", "_ = pv_corr_plot.draw()" ] }, { "cell_type": "markdown", "id": "b6964286-e1f3-4eb7-aab0-816c12343124", "metadata": { "execution": { "iopub.execute_input": "2021-11-18T14:00:53.858336Z", "iopub.status.busy": "2021-11-18T14:00:53.857951Z", "iopub.status.idle": "2021-11-18T14:00:54.830072Z", "shell.execute_reply": "2021-11-18T14:00:54.829519Z", "shell.execute_reply.started": "2021-11-18T14:00:53.858294Z" }, "tags": [] }, "source": [ "We also examine the correlation between the \"true\" and inferred 
mutation-escape values, $\\beta_{m,e}$.\n", "In general, the epitopes must be matched up before this type of comparison, because it is arbitrary which epitope in the model is given which name.\n", "But above we seeded the epitopes at the site level using `site_escape_df` when we initialized the `Polyclonal` object, so they match up with class 1, 2, and 3:" ] }, { "cell_type": "code", "execution_count": null, "id": "653928f4-fed2-4fd2-9d52-0c9411f7e9a5", "metadata": { "execution": { "iopub.execute_input": "2021-11-19T22:48:34.269060Z", "iopub.status.busy": "2021-11-19T22:48:34.268765Z", "iopub.status.idle": "2021-11-19T22:48:34.755956Z", "shell.execute_reply": "2021-11-19T22:48:34.755148Z", "shell.execute_reply.started": "2021-11-19T22:48:34.269036Z" }, "tags": [] }, "outputs": [], "source": [ "# NBVAL_IGNORE_OUTPUT\n", "\n", "import altair as alt\n", "\n", "mut_escape_pred = (\n", " pd.read_csv('RBD_mut_escape_df.csv')\n", " .merge((poly_abs.mut_escape_df\n", " .assign(epitope=lambda x: 'class ' + x['epitope'].astype(str))\n", " .rename(columns={'escape': 'predicted escape'})\n", " ),\n", " on=['mutation', 'epitope'],\n", " validate='one_to_one',\n", " )\n", " )\n", "\n", "print('Correlation between predicted and true values:')\n", "corr = (mut_escape_pred\n", " .groupby('epitope')\n", " .apply(lambda x: x['escape'].corr(x['predicted escape']))\n", " .rename('correlation')\n", " .reset_index()\n", " )\n", "display(corr.round(2))\n", "\n", "# for testing since we nbval ignore cell output; assert so that a mismatch actually fails\n", "assert numpy.allclose(corr['correlation'], numpy.array([0.82, 0.96, 0.93]), atol=0.02)\n", "\n", "corr_chart = (\n", " alt.Chart(mut_escape_pred)\n", " .encode(x='escape',\n", " y='predicted escape',\n", " color='epitope',\n", " tooltip=['mutation', 'epitope'],\n", " )\n", " .mark_point(opacity=0.5)\n", " .properties(width=250, height=250)\n", " .facet(column='epitope')\n", " .resolve_scale(x='independent',\n", " y='independent',\n", " )\n", " )\n", "\n", "corr_chart" ] }, { 
"cell_type": "markdown", "id": "595a73ce-aaa7-439b-bbdc-6d3c7c23ab1c", "metadata": {}, "source": [ "The correlations are strongest for the dominant epitope (class 2), which makes sense as this will drive the highest escape signal." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }