{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "accelerator": "GPU", "colab": { "name": "China_changing_point.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "kqXl71t9suiG" }, "source": [ "# Estimating the Date of COVID-19 Changes\n", "\n", "https://nbviewer.jupyter.org/github/jramkiss/jramkiss.github.io/blob/master/_posts/notebooks/covid19-changes.ipynb " ] }, { "cell_type": "code", "metadata": { "id": "gFnvD8OysuiI", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "bf23d50c-7027-4c1d-d67e-4cf0525584da" }, "source": [ "import pandas as pd\n", "import numpy as np\n", "\n", "import seaborn as sns; sns.set()\n", "import matplotlib.pyplot as plt\n", "import matplotlib.dates as mdates\n", "\n", "\n", "from sklearn.linear_model import LinearRegression\n", "\n", "from scipy import stats\n", "import statsmodels.api as sm\n", "import pylab\n", "\n", "# from google.colab import files\n", "# from io import StringIO\n", "# uploaded = files.upload()\n", "\n", "url = 'https://raw.githubusercontent.com/assemzh/ProbProg-COVID-19/master/full_grouped.csv'\n", "data = pd.read_csv(url)\n", "\n", "data.Date = pd.to_datetime(data.Date)\n", "\n", "# for fancy python printing\n", "from IPython.display import Markdown, display\n", "def printmd(string):\n", " display(Markdown(string))\n", " \n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "import matplotlib as mpl\n", "mpl.rcParams['figure.dpi'] = 250" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n", " import pandas.util.testing as tm\n" ], "name": "stderr" } ] }, { "cell_type": "markdown", "metadata": { "id": "hzvPpvVvphTD" }, "source": [ "## Create country\n" ] }, { "cell_type": "code", "metadata": { "id": "koX5yGHrsuib" }, "source": [ "# function to make the time series of confirmed and daily confirmed cases for a specific country\n", "def create_country (country, end_date, state = False) : \n", " if state :\n", " df = data.loc[data[\"Province/State\"] == country, [\"Province/State\", \"Date\", \"Confirmed\", \"Deaths\", \"Recovered\"]]\n", " else : \n", " df = data.loc[data[\"Country/Region\"] == country, [\"Country/Region\", \"Date\", \"Confirmed\", \"Deaths\", \"Recovered\"]]\n", " df.columns = [\"country\", \"date\", \"confirmed\", \"deaths\", \"recovered\"]\n", "\n", " # group by country and date, sum(confirmed, deaths, recovered). do this because countries have multiple cities \n", " df = df.groupby(['country','date'])['confirmed', 'deaths', 'recovered'].sum().reset_index()\n", "\n", " # convert date string to datetime\n", " std_dateparser = lambda x: str(x)[5:10]\n", " df.date = pd.to_datetime(df.date)\n", " df['date_only'] = df.date.apply(std_dateparser)\n", " df = df.sort_values(by = \"date\")\n", " df = df[df.date <= end_date]\n", "\n", "\n", " # make new confirmed cases every day:\n", " cases_shifted = np.array([0] + list(df.confirmed[:-1]))\n", " daily_confirmed = np.array(df.confirmed) - cases_shifted\n", " df[\"daily_confirmed\"] = daily_confirmed \n", "\n", " fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 6))\n", " ax = [ax]\n", " sns.lineplot(x = df.date, \n", " y = df.daily_confirmed, \n", " ax = ax[0])\n", "\n", " ax[0].set(ylabel='Daily Confirmed Cases')\n", "\n", " ax[0].axvline(pd.to_datetime('2020-01-22'), \n", " linestyle = '--', linewidth = 1.5,\n", " label = \"Date of Lockdown: Jan 22, 2020\" ,\n", " color = \"red\") \n", "\n", " \n", " # ax[0].set_ylim([100,150000])\n", " ax[0].xaxis.get_label().set_fontsize(16)\n", " ax[0].yaxis.get_label().set_fontsize(16)\n", " ax[0].title.set_fontsize(20)\n", " ax[0].tick_params(labelsize=16)\n", " myFmt = mdates.DateFormatter('%b %-d')\n", " ax[0].xaxis.set_major_formatter(myFmt)\n", "\n", " ax[0].set(ylabel='Daily Confirmed Cases', xlabel='');\n", " ax[0].legend(loc = \"bottom right\", fontsize=12.8)\n", " sns.set_style(\"ticks\")\n", " plt.tight_layout()\n", " sns.despine()\n", " plt.savefig('/content/sample_data/china_daily.pdf')\n", " print(df.tail())\n", " return df\n", "\n", "\n", "def summary(samples):\n", " site_stats = {}\n", " for k, v in samples.items():\n", " site_stats[k] = {\n", " \"mean\": torch.mean(v, 0),\n", " \"std\": torch.std(v, 0),\n", " \"5%\": v.kthvalue(int(len(v) * 0.05), dim=0)[0],\n", " \"95%\": v.kthvalue(int(len(v) * 0.95), dim=0)[0],\n", " }\n", " return site_stats" ], "execution_count": 9, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 577 }, "id": "w_A0fd4Zsuiw", "outputId": "a8ec9956-f5df-44cc-959b-1c21d32ed580" }, "source": [ "cad = create_country(\"China\", end_date = \"2020-04-01\")" ], "execution_count": 10, "outputs": [ { "output_type": "stream", "text": [ " country date confirmed ... date_only daily_confirmed moving_avg\n", "66 China 2020-03-28 81999 ... 03-28 102 102.00\n", "67 China 2020-03-29 82122 ... 03-29 123 115.25\n", "68 China 2020-03-30 82198 ... 03-30 76 104.00\n", "69 China 2020-03-31 82279 ... 03-31 81 95.50\n", "70 China 2020-04-01 82361 ... 04-01 82 90.50\n", "\n", "[5 rows x 8 columns]\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "UR0BM7TysujG" }, "source": [ "cad_start = \"2020-01-22\" # 13 confirmed cases\n", "cad = cad[cad.date >= cad_start].reset_index(drop = True)\n", "cad[\"days_since_start\"] = np.arange(cad.shape[0]) + 1" ], "execution_count": 15, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "OaTHo6I2sujp", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "cfe9b980-3356-4e17-bb4c-a8a45c271a8d" }, "source": [ "cad.shape\n", "cad_tmp = cad[cad.date < \"2020-03-16\"]\n", "cad_tmp.shape" ], "execution_count": 16, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(54, 9)" ] }, "metadata": { "tags": [] }, "execution_count": 16 } ] }, { "cell_type": "markdown", "metadata": { "id": "loi3CtjSsuoz" }, "source": [ "## Data for Regression" ] }, { "cell_type": "code", "metadata": { "id": "Os_M7r4Tsuo4" }, "source": [ "# variable for data to easily swap it out:\n", "country_ = \"China\"\n", "reg_data = cad_tmp.copy()" ], "execution_count": 17, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 204 }, "id": "3RjDFEbA91X-", "outputId": "d6ce0220-1f77-491b-b3b9-67556f08e702" }, "source": [ "reg_data.head()" ], "execution_count": 18, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
countrydateconfirmeddeathsrecovereddate_onlydaily_confirmedmoving_avgdays_since_start
0China2020-01-22548172801-22548NaN1
1China2020-01-23643183001-2395NaN2
2China2020-01-24920263601-24277NaN3
3China2020-01-251406423901-25486351.504
4China2020-01-262075564901-26669381.755
\n", "
" ], "text/plain": [ " country date confirmed ... daily_confirmed moving_avg days_since_start\n", "0 China 2020-01-22 548 ... 548 NaN 1\n", "1 China 2020-01-23 643 ... 95 NaN 2\n", "2 China 2020-01-24 920 ... 277 NaN 3\n", "3 China 2020-01-25 1406 ... 486 351.50 4\n", "4 China 2020-01-26 2075 ... 669 381.75 5\n", "\n", "[5 rows x 9 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 18 } ] }, { "cell_type": "markdown", "metadata": { "id": "JkO0Z8M0supC" }, "source": [ "## Change Point Estimation in Pyro" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aIUed4Ny3-oq", "outputId": "f2980893-5246-40f0-b290-1f4cd5718aff" }, "source": [ "!pip install pyro-ppl\n", "!pip install numpyro" ], "execution_count": 19, "outputs": [ { "output_type": "stream", "text": [ "Requirement already satisfied: pyro-ppl in /usr/local/lib/python3.7/dist-packages (1.6.0)\n", "Requirement already satisfied: tqdm>=4.36 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (4.41.1)\n", "Requirement already satisfied: torch>=1.8.0 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (1.8.1+cu101)\n", "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (3.3.0)\n", "Requirement already satisfied: pyro-api>=0.1.1 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (0.1.2)\n", "Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.7/dist-packages (from pyro-ppl) (1.19.5)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.8.0->pyro-ppl) (3.7.4.3)\n", "Requirement already satisfied: numpyro in /usr/local/lib/python3.7/dist-packages (0.6.0)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from numpyro) (4.41.1)\n", "Requirement already satisfied: jaxlib==0.1.62 in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.1.62)\n", "Requirement already satisfied: jax==0.2.10 in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.2.10)\n", "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.1.62->numpyro) (1.12)\n", "Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.1.62->numpyro) (1.19.5)\n", "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.1.62->numpyro) (0.12.0)\n", "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib==0.1.62->numpyro) (1.4.1)\n", "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax==0.2.10->numpyro) (3.3.0)\n", "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jaxlib==0.1.62->numpyro) (1.15.0)\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "3ZS9fTPxsupD" }, "source": [ "import torch\n", "\n", "import pyro\n", "import pyro.distributions as dist\n", "from torch import nn\n", "from pyro.nn import PyroModule, PyroSample\n", "\n", "from pyro.infer import MCMC, NUTS, HMC\n", "from pyro.infer.autoguide import AutoGuide, AutoDiagonalNormal\n", "\n", "from pyro.infer import SVI, Trace_ELBO\n", "from pyro.infer import Predictive" ], "execution_count": 20, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "9gPrEaEJsupP" }, "source": [ "# we should be able to have an empirical estimate for the mean of the prior for the 2nd regression bias term\n", "# this will be something like b = log(max(daily_confirmed))\n", "\n", "# might be able to have 1 regression model but change the data so that we have new terms for (tau < t) \n", "# like an interaction term\n", "\n", "class COVID_change(PyroModule):\n", " def __init__(self, in_features, out_features, b1_mu, b2_mu):\n", " super().__init__()\n", " self.linear1 = PyroModule[nn.Linear](in_features, out_features, bias = False)\n", " self.linear1.weight = PyroSample(dist.Normal(0.5, 0.25).expand([1, 1]).to_event(1))\n", " self.linear1.bias = PyroSample(dist.Normal(b1_mu, 1.))\n", " \n", " # could possibly have stronger priors for the 2nd regression line, because we wont have as much data\n", " self.linear2 = PyroModule[nn.Linear](in_features, out_features, bias = False)\n", " self.linear2.weight = PyroSample(dist.Normal(0., 0.25).expand([1, 1])) #.to_event(1))\n", " self.linear2.bias = PyroSample(dist.Normal(b2_mu, b2_mu/4))\n", "\n", " def forward(self, x, y=None):\n", " tau = pyro.sample(\"tau\", dist.Beta(4, 3))\n", " sigma = pyro.sample(\"sigma\", dist.Uniform(0., 3.))\n", " # fit lm's to data based on tau\n", " sep = int(np.ceil(tau.detach().numpy() * len(x)))\n", " mean1 = self.linear1(x[:sep]).squeeze(-1)\n", " mean2 = self.linear2(x[sep:]).squeeze(-1)\n", " mean = torch.cat((mean1, mean2))\n", " obs = pyro.sample(\"obs\", dist.StudentT(2, mean, sigma), obs=y)\n", " return mean" ], "execution_count": 21, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "if0toOMysupU", "outputId": "64de9904-30aa-439f-d19a-9ca6970b5850" }, "source": [ "tensor_data = torch.tensor(reg_data[[\"confirmed\", \"days_since_start\", \"daily_confirmed\"]].values, dtype=torch.float)\n", "x_data = tensor_data[:, 1].unsqueeze_(1)\n", "y_data = np.log(tensor_data[:, 0])\n", "y_data_daily = np.log(tensor_data[:, 2])\n", "# prior hyper params\n", "# take log of the average of the 1st quartile to get the prior mean for the bias of the 2nd regression line\n", "q1 = np.quantile(y_data, q = 0.25)\n", "bias_1_mean = np.mean(y_data.numpy()[y_data <= q1])\n", "print(\"Prior mean for Bias 1: \", bias_1_mean)\n", "\n", "# take log of the average of the 4th quartile to get the prior mean for the bias of the 2nd regression line\n", "q4 = np.quantile(y_data, q = 0.75)\n", "bias_2_mean = np.mean(y_data.numpy()[y_data >= q4])\n", "print(\"Prior mean for Bias 2: \", bias_2_mean)" ], "execution_count": 22, "outputs": [ { "output_type": "stream", "text": [ "Prior mean for Bias 1: 8.359699\n", "Prior mean for Bias 2: 11.29878\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "nrm8RrFasupc" }, "source": [ "## Approximate Inference with Stochastic Variational Inference" ] }, { "cell_type": "markdown", "metadata": { "id": "x0nDLSPisupm" }, "source": [ "# HMC with NUTS" ] }, { "cell_type": "code", "metadata": { "id": "X1rSXXtKsupm" }, "source": [ "model = COVID_change(1, 1, \n", " b1_mu = bias_1_mean,\n", " b2_mu = bias_2_mean)\n", "# need more than 400 samples/chain if we want to use a flat prior on b_2 and w_2\n", "num_samples = 400 \n", "# mcmc \n", "nuts_kernel = NUTS(model)\n", "mcmc = MCMC(nuts_kernel, \n", " num_samples=num_samples,\n", " warmup_steps = 200, \n", " num_chains = 1)\n", "mcmc.run(x_data, y_data)" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "RJPSAgdbLTTJ" }, "source": [ "# Save the model:\n", "import dill\n", "# with open('china.pkl', 'wb') as f:\n", "# \tdill.dump(mcmc, f)\n", "with open('china.pkl', 'rb') as f:\n", "\tmcmc = dill.load(f)\n", " \n", "samples = mcmc.get_samples()" ], "execution_count": 24, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7Z968a5xsupv", "outputId": "eb4f34cd-e16a-4315-dcbc-40e23bc7d482" }, "source": [ "# extract individual posteriors\n", "weight_1_post = samples[\"linear1.weight\"].detach().numpy()\n", "weight_2_post = samples[\"linear2.weight\"].detach().numpy()\n", "bias_1_post = samples[\"linear1.bias\"].detach().numpy()\n", "bias_2_post = samples[\"linear2.bias\"].detach().numpy()\n", "tau_post = samples[\"tau\"].detach().numpy()\n", "sigma_post = samples[\"sigma\"].detach().numpy()\n", "\n", "# build likelihood distribution:\n", "tau_days = list(map(int, np.ceil(tau_post * len(x_data))))\n", "mean_ = torch.zeros(len(tau_days), len(x_data))\n", "obs_ = torch.zeros(len(tau_days), len(x_data))\n", "for i in range(len(tau_days)) : \n", " mean_[i, :] = torch.cat((x_data[:tau_days[i]] * weight_1_post[i] + bias_1_post[i],\n", " x_data[tau_days[i]:] * weight_2_post[i] + bias_2_post[i])).reshape(len(x_data))\n", " obs_[i, :] = dist.Normal(mean_[i, :], sigma_post[i]).sample()\n", "samples[\"_RETURN\"] = mean_\n", "samples[\"obs\"] = obs_\n", "pred_summary = summary(samples)\n", "mu = pred_summary[\"_RETURN\"] # mean\n", "y = pred_summary[\"obs\"] # samples from likelihood: mu + sigma\n", "y_shift = np.exp(y[\"mean\"]) - np.exp(torch.cat((y[\"mean\"][0:1], y[\"mean\"][:-1])))\n", "print(y_shift)\n", "predictions = pd.DataFrame({\n", " \"days_since_start\": x_data[:, 0],\n", " \"mu_mean\": mu[\"mean\"], # mean of likelihood\n", " \"mu_perc_5\": mu[\"5%\"],\n", " \"mu_perc_95\": mu[\"95%\"],\n", " \"y_mean\": y[\"mean\"], # mean of likelihood + noise\n", " \"y_perc_5\": y[\"5%\"],\n", " \"y_perc_95\": y[\"95%\"],\n", " \"true_confirmed\": y_data,\n", " \"true_daily_confirmed\": y_data_daily,\n", " \"y_daily_mean\": y_shift\n", "})\n", "\n", "w1_ = pred_summary[\"linear1.weight\"]\n", "w2_ = pred_summary[\"linear2.weight\"]\n", "\n", "b1_ = pred_summary[\"linear1.bias\"]\n", "b2_ = pred_summary[\"linear2.bias\"]\n", "\n", "tau_ = pred_summary[\"tau\"]\n", "sigma_ = pred_summary[\"sigma\"]\n", "\n", "ind = int(np.ceil(tau_[\"mean\"] * len(x_data)))" ], "execution_count": 25, "outputs": [ { "output_type": "stream", "text": [ "tensor([ 0.0000e+00, 2.2725e+02, 3.0094e+02, 4.0804e+02, 5.1011e+02,\n", " 6.8209e+02, 9.4119e+02, 1.2054e+03, 1.5588e+03, 2.1987e+03,\n", " 2.8797e+03, 3.6862e+03, 4.8820e+03, 6.5520e+03, 9.3445e+03,\n", " 1.1462e+04, 1.3558e+04, 5.6779e+03, 1.5542e+03, 1.3815e+02,\n", " 9.5227e+02, -2.4232e+02, 2.4934e+02, 7.3088e+02, 5.2862e+02,\n", " 7.5919e+02, -1.0573e+02, 3.9531e+00, 8.0497e+02, 9.6861e+02,\n", " -5.0664e+01, 5.5251e+02, 9.4084e+02, -2.7266e+02, 1.3263e+03,\n", " 3.4009e+02, 1.9668e+02, 1.1166e+03, -3.3041e+02, 7.6644e+02,\n", " 5.8625e+01, 1.3439e+03, -1.9008e+01, 7.6425e+02, 6.2384e+02,\n", " 1.4739e+02, 1.3451e+02, 9.4297e+02, 7.5945e+02, 1.1752e+03,\n", " -9.4339e+02, 1.7658e+03, 2.9969e+01, 4.2532e+02])\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "KegzMbLOsuqC" }, "source": [ "## Model Diagnostics\n", "\n", "- Residual plots: Should these be samples from the likelihood compared with the actual data? Or just the mean of the likelihood?\n", "- $\\hat{R}$: The factor that the scale of the current distribution will be reduced by if we were to run the simulations forever. As n tends to $\\inf$, $\\hat{R}$ tends to 1. So we want values close to 1.\n", "- Mixing and Stationarity: I sampled 4 chains. Do I then take these chains, split them in half and plot them. If they converge to the same stationary distribution, does that mean the MCMC converged? What do I do with more sampled chains?" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QUh6fBjtsuqV", "outputId": "a3304748-0d50-4071-d29d-f5cdcabcdadf" }, "source": [ "mcmc.summary()\n", "diag = mcmc.diagnostics()" ], "execution_count": 26, "outputs": [ { "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " tau 0.31 0.03 0.31 0.26 0.35 15.20 1.21\n", " sigma 0.08 0.02 0.08 0.05 0.11 55.90 1.01\n", "linear1.weight[0,0] 0.28 0.03 0.29 0.21 0.31 8.85 1.21\n", " linear1.bias 6.27 0.33 6.15 5.98 7.08 8.38 1.20\n", "linear2.weight[0,0] 0.01 0.00 0.01 0.00 0.01 22.25 1.10\n", " linear2.bias 11.00 0.07 11.00 10.90 11.10 23.06 1.10\n", "\n", "Number of divergences: 0\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "LJ2a6Epnsuqf" }, "source": [ "## Posterior Plots" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 331 }, "id": "bPpET7b6suqg", "outputId": "96bb9246-6358-4fa3-e4c6-ea104cb1635f" }, "source": [ "print(ind)\n", "print(reg_data.date[ind])\n", "\n", "sns.distplot(weight_1_post, \n", " kde_kws = {\"label\": \"Weight posterior before CP\"}, \n", " color = \"red\",\n", " norm_hist = True,\n", " kde = True)\n", "plt.axvline(x = w1_[\"mean\"], linestyle = '--',label = \"Mean weight before CP\" ,\n", " color = \"red\")\n", "\n", "sns.distplot(weight_2_post, \n", " kde_kws = {\"label\": \"Weight posterior after CP\"}, \n", " color = \"teal\",\n", " norm_hist = True,\n", " kde = True)\n", "plt.axvline(x = w2_[\"mean\"], linestyle = '--',label = \"Mean weight after CP\" ,\n", " color = \"teal\")\n", "\n", "legend = plt.legend(loc='upper right')\n", "legend.get_frame().set_alpha(1)\n", "sns.set_style(\"ticks\")\n", "plt.tight_layout()\n", "sns.despine()\n", "plt.savefig('/content/sample_data/china_weights.pdf')\n" ], "execution_count": 43, "outputs": [ { "output_type": "stream", "text": [ "17\n", "2020-02-08 00:00:00\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "cell_type": "code", "metadata": { "id": "A4glDCQRO9Ss" }, "source": [ "predictions['date'] = pd.to_datetime(reg_data.date)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "DAvMkdZagKu2" }, "source": [ "# Final plot" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 395 }, "id": "gejnd0Qxsuq1", "outputId": "90859b2d-0f2d-46ae-bd01-a298cab66b1c" }, "source": [ "start_date_ = str(reg_data.date[0]).split(' ')[0]\n", "change_date_ = str(reg_data.date[ind]).split(' ')[0]\n", "print(\"Date of change for {}: {}\".format(country_, change_date_))\n", "import seaborn as sns\n", "\n", "# plot data:\n", "fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))\n", "ax = [ax]\n", "# log regression model\n", "ax[0].scatter(y = np.exp(y_data[:ind]), x = x_data[:ind], s = 15);\n", "ax[0].scatter(y = np.exp(y_data[ind:]), x = x_data[ind:], s = 15, color = \"red\");\n", "\n", "ax[0].plot(predictions[\"days_since_start\"],\n", " np.exp(predictions[\"y_mean\"]), \n", " color = \"green\",\n", " label = \"Fitted line by MCMC-NUTS model\") \n", "ax[0].axvline(1, \n", " linestyle = '--', linewidth = 1.5,\n", " label = \"Date of Lockdown: Jan 22, 2020\" ,\n", " color = \"red\")\n", "\n", "ax[0].axvline(ind, \n", " linestyle = '--', linewidth = 1.5,\n", " label = \"Date of Change: Feb 8, 2020\",\n", " color = \"black\")\n", "\n", "ax[0].fill_between(predictions[\"days_since_start\"], \n", " np.exp(predictions[\"y_perc_5\"]), \n", " np.exp(predictions[\"y_perc_95\"]), \n", " alpha = 0.25,\n", " label = \"90% CI of predictions\",\n", " color = \"teal\");\n", "ax[0].fill_betweenx([0, 1], \n", " tau_[\"5%\"] * len(x_data), \n", " tau_[\"95%\"] * len(x_data), \n", " alpha = 0.25,\n", " label = \"90% CI of changing point\",\n", " color = \"lightcoral\",\n", " transform=ax[0].get_xaxis_transform());\n", "ax[0].set(ylabel = \"Total Cases\",)\n", " # xlabel = \"Days since %s\" % start_date_, \n", " # title = \"Confirmed Cases in China\") /\n", "ax[0].legend(loc = \"lower right\", fontsize=12.8)\n", "ax[0].set_ylim([100,150000])\n", "ax[0].xaxis.get_label().set_fontsize(16)\n", "ax[0].yaxis.get_label().set_fontsize(16)\n", "ax[0].title.set_fontsize(20)\n", "ax[0].tick_params(labelsize=16)\n", "\n", "plt.xticks(ticks=[1,17,34,51], labels=[\"Jan 22\",\n", " \"Feb 8\",\n", " \"Feb 25\",\n", " \"Mar 13\"], fontsize=15)\n", "ax[0].set_yscale('log')\n", "plt.setp(ax[0].get_xticklabels(), rotation=0, horizontalalignment='center')\n", "print(reg_data.columns)\n", "myFmt = mdates.DateFormatter('%m-%d')\n", "sns.set_style(\"ticks\")\n", "sns.despine()\n", "ax[0].figure.savefig('/content/sample_data/china_cp.pdf')\n" ], "execution_count": 44, "outputs": [ { "output_type": "stream", "text": [ "Date of change for China (Before March 16th): 2020-02-08\n", "Index(['country', 'date', 'confirmed', 'deaths', 'recovered', 'date_only',\n", " 'daily_confirmed', 'moving_avg', 'days_since_start'],\n", " dtype='object')\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] } ] }