{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Compartmental_model.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AJX1jgLnX70s"
      },
      "source": [
        "# Epidemiology model\n",
        "\n",
        "Reference: [Pyro epidemiology tutorial (regional models)](https://nbviewer.jupyter.org/github/pyro-ppl/pyro/blob/sir-tutorial-ii/tutorial/source/epi_regional.ipynb)\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z6UoAzRe1pMh"
      },
      "source": [
        "!git clone https://github.com/pyro-ppl/pyro.git"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ryMxWMvbD8Nc"
      },
      "source": [
        "%cd /content/pyro\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "w8MT-jR48mLX"
      },
      "source": [
        "!pip install .[extras]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7AUI1jmXcX4u"
      },
      "source": [
        "import os\n",
        "import logging\n",
        "import urllib.request\n",
        "from collections import OrderedDict\n",
        "\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "import torch\n",
        "import pyro\n",
        "import pyro.distributions as dist\n",
        "from pyro.contrib.epidemiology import CompartmentalModel, binomial_dist, infection_dist\n",
        "from pyro.ops.tensor_utils import convolve\n",
        "\n",
        "%matplotlib inline\n",
        "pyro.enable_validation(True)\n",
        "torch.set_default_dtype(torch.double)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hzpnb36feNgS"
      },
      "source": [
        "## Model without Policies\n",
        "\n",
        "Compartments `S`, `E`, `I` and `D` are tracked explicitly and `R` is implicit; individuals flow S -> E -> I -> R or D.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bsaJcjx6xVLo"
      },
      "source": [
        "class CovidModel(CompartmentalModel):\n",
        "    def __init__(self, population, new_cases, new_recovered, new_deaths):\n",
        "        '''\n",
        "        SEIRD-style compartmental model conditioned on observed daily counts.\n",
        "\n",
        "        population (int): total population = S + E + I + D + R (R is implicit).\n",
        "        new_cases, new_recovered, new_deaths: observed daily counts indexed by time\n",
        "            step; all three must have the same length, which sets the model duration.\n",
        "        '''\n",
        "        assert len(new_cases) == len(new_recovered) == len(new_deaths)\n",
        "\n",
        "        compartments = (\"S\", \"E\", \"I\", \"D\")  # R is implicit.\n",
        "        duration = len(new_cases)\n",
        "        super().__init__(compartments, duration, population)\n",
        "\n",
        "        self.new_cases = new_cases\n",
        "        self.new_deaths = new_deaths\n",
        "        self.new_recovered = new_recovered\n",
        "\n",
        "    def global_model(self):\n",
        "        # NOTE(review): these Normal priors give some mass to values where 1 / tau is not a\n",
        "        # valid probability (tau < 1) or R0 / tau_i is negative (R0 < 0); consider\n",
        "        # positive-support priors such as LogNormal (cf. the commented-out R0 prior).\n",
        "        tau_i = pyro.sample(\"rec_time\", dist.Normal(15.0, 3.0))\n",
        "        tau_e = pyro.sample(\"incub_time\", dist.Normal(5.0, 1.0))\n",
        "        # R0 = pyro.sample(\"R0\", dist.LogNormal(0., 1.))\n",
        "        R0 = pyro.sample(\"R0\", dist.Normal(2.5, 0.5))\n",
        "        rho = pyro.sample(\"rho\", dist.Beta(10, 10))  # About 50% response rate.\n",
        "        mort_rate = pyro.sample(\"mort_rate\", dist.Beta(2, 50))  # About 2% mortality rate.\n",
        "        # NOTE(review): rec_rate is unpacked in transition() but never used, so its\n",
        "        # posterior is just its prior.\n",
        "        rec_rate = pyro.sample(\"rec_rate\", dist.Beta(10, 10))  # About 50% recovery rate.\n",
        "        return R0, tau_e, tau_i, rho, mort_rate, rec_rate\n",
        "\n",
        "    def initialize(self, params):\n",
        "        # Start with a single infection.\n",
        "        return {\"S\": self.population - 1, \"E\": 0, \"I\": 1, \"D\": 0}\n",
        "\n",
        "    def transition(self, params, state, t):\n",
        "        R0, tau_e, tau_i, rho, mort_rate, rec_rate = params\n",
        "\n",
        "        # Sample flows between compartments.\n",
        "        S2E = pyro.sample(\"S2E_{}\".format(t),\n",
        "                          infection_dist(individual_rate=R0 / tau_i,\n",
        "                                         num_susceptible=state[\"S\"],\n",
        "                                         num_infectious=state[\"I\"],\n",
        "                                         population=self.population))\n",
        "        E2I = pyro.sample(\"E2I_{}\".format(t),\n",
        "                          binomial_dist(state[\"E\"], 1 / tau_e))\n",
        "        # NOTE(review): I2R and I2D are drawn independently from the same state[\"I\"], so\n",
        "        # their sum can exceed state[\"I\"] and drive I negative when sampling forward.\n",
        "        I2R = pyro.sample(\"I2R_{}\".format(t),\n",
        "                          binomial_dist(state[\"I\"], 1 / tau_i))\n",
        "        I2D = pyro.sample(\"I2D_{}\".format(t),\n",
        "                          binomial_dist(state[\"I\"], mort_rate / tau_i))\n",
        "\n",
        "        # Update compartments with flows.\n",
        "        state[\"S\"] = state[\"S\"] - S2E\n",
        "        state[\"E\"] = state[\"E\"] + S2E - E2I\n",
        "        state[\"I\"] = state[\"I\"] + E2I - I2R - I2D\n",
        "        state[\"D\"] = state[\"D\"] + I2D\n",
        "\n",
        "        # Condition on observations; steps at or beyond self.duration have no data.\n",
        "        # Cases and recoveries are reported with probability rho, deaths with probability 1.\n",
        "        t_is_observed = isinstance(t, slice) or t < self.duration\n",
        "        pyro.sample(\"new_cases_{}\".format(t),\n",
        "                    binomial_dist(S2E, rho),\n",
        "                    obs=self.new_cases[t] if t_is_observed else None)\n",
        "        pyro.sample(\"new_deaths_{}\".format(t),\n",
        "                    binomial_dist(I2D, 1),\n",
        "                    obs=self.new_deaths[t] if t_is_observed else None)\n",
        "        pyro.sample(\"new_recovered_{}\".format(t),\n",
        "                    binomial_dist(I2R, rho),\n",
        "                    obs=self.new_recovered[t] if t_is_observed else None)\n",
        "\n",
        "    def compute_flows(self, prev, curr, t):\n",
        "        S2E = prev[\"S\"] - curr[\"S\"]  # S can only go to E.\n",
        "        I2D = curr[\"D\"] - prev[\"D\"]  # D can only have come from I.\n",
        "        # We deduce the remaining flows by conservation of mass:\n",
        "        #   curr - prev = inflows - outflows\n",
        "        E2I = prev[\"E\"] - curr[\"E\"] + S2E\n",
        "        I2R = prev[\"I\"] - curr[\"I\"] + E2I - I2D\n",
        "        return {\n",
        "            \"S2E_{}\".format(t): S2E,\n",
        "            \"E2I_{}\".format(t): E2I,\n",
        "            \"I2D_{}\".format(t): I2D,\n",
        "            \"I2R_{}\".format(t): I2R,\n",
        "        }"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kgKCZvRfMi_3"
      },
      "source": [
        "## Create Country"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "koX5yGHrsuib"
      },
      "source": [
        "def create_country(country, start_date, end_date, state=False):\n",
        "    '''\n",
        "    Load daily COVID-19 time series for one country (or one province/state if state=True).\n",
        "\n",
        "    Rows are restricted to start_date <= date <= end_date (inclusive) and sorted by date.\n",
        "    Returns a dict with keys 'active', 'recovered', 'deaths', 'new_cases', 'new_recovered'\n",
        "    and 'new_deaths', each a float tensor of shape (num_days, 1).\n",
        "    '''\n",
        "    url = 'https://raw.githubusercontent.com/assemzh/ProbProg-COVID-19/master/full_grouped.csv'\n",
        "    data = pd.read_csv(url)\n",
        "    data.Date = pd.to_datetime(data.Date)\n",
        "\n",
        "    # NOTE(review): state=True assumes the CSV has a \"Province/State\" column - confirm.\n",
        "    region_col = \"Province/State\" if state else \"Country/Region\"\n",
        "    value_cols = [\"Confirmed\", \"Deaths\", \"Recovered\", \"Active\", \"New cases\", \"New deaths\", \"New recovered\"]\n",
        "    df = data.loc[data[region_col] == country, [region_col, \"Date\"] + value_cols]\n",
        "    df.columns = [\"country\", \"date\", \"confirmed\", \"deaths\", \"recovered\", \"active\", \"new_cases\", \"new_deaths\", \"new_recovered\"]\n",
        "\n",
        "    # Sum counts per (country, date). Multiple columns must be selected with a list:\n",
        "    # the old tuple-style df.groupby(...)['a', 'b'] raises a ValueError in pandas >= 2.0.\n",
        "    value_names = [\"confirmed\", \"deaths\", \"recovered\", \"active\", \"new_cases\", \"new_deaths\", \"new_recovered\"]\n",
        "    df = df.groupby([\"country\", \"date\"])[value_names].sum().reset_index()\n",
        "\n",
        "    df = df.sort_values(by=\"date\")\n",
        "    df = df[(df.date >= start_date) & (df.date <= end_date)]\n",
        "\n",
        "    def to_column(name):\n",
        "        # Convert one column to a float tensor of shape (num_days, 1).\n",
        "        return torch.tensor([float(v) for v in df[name]]).view(len(df), 1)\n",
        "\n",
        "    return {name: to_column(name)\n",
        "            for name in [\"active\", \"recovered\", \"deaths\", \"new_cases\", \"new_recovered\", \"new_deaths\"]}\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nxRxbNr8zt3O"
      },
      "source": [
        "## Get data for countries\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MzNNysenCiWz"
      },
      "source": [
        "# Parameters\n",
        "country = \"Japan\"\n",
        "start_date = \"2020-02-01\"\n",
        "end_date = \"2020-04-01\"\n",
        "population = 126500000  # total population of `country` (Japan, approx.)\n"
      ],
"execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "V4O6XUbM9Ff3" }, "source": [ "data = create_country(country, start_date, end_date)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "XTUPDWD9e_9o" }, "source": [ "##Train the model using MCMC.\n", "\n" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bX4Jq5qmC6Ke", "outputId": "dfce3586-740f-44e2-9324-25406d50cf8f" }, "source": [ "%%time\n", "model = CovidModel(population, data[\"new_cases\"], data[\"new_recovered\"], data[\"new_deaths\"] )\n", "mcmc = model.fit_mcmc(num_samples=500, warmup_steps = 200)\n", "mcmc.summary()" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "INFO \t Running inference...\n", "Warmup: 0%| | 0/700 [00:00, ?it/s]INFO \t Heuristic init: R0=1.96, incub_time=4.5, mort_rate=0.00865, rec_rate=0.55, rec_time=16.7, rho=0.176\n", "Sample: 