# Epidemiology model

Reference: Pyro epidemiology tutorial notebook `epi_regional.ipynb` — https://nbviewer.jupyter.org/github/pyro-ppl/pyro/blob/sir-tutorial-ii/tutorial/source/epi_regional.ipynb


In [None]:
!git clone https://github.com/pyro-ppl/pyro.git

fatal: destination path 'pyro' already exists and is not an empty directory.


In [None]:
%cd /content/pyro


/content/pyro


In [None]:
!pip install .[extras]

In [None]:
import os
import logging
import urllib.request
from collections import OrderedDict

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.epidemiology import CompartmentalModel, binomial_dist, infection_dist
from pyro.ops.tensor_utils import convolve

%matplotlib inline
# Make float64 torch's default dtype for every tensor created afterwards.
torch.set_default_dtype(torch.double)
# Turn on Pyro's runtime validation of distribution arguments and sample values.
pyro.enable_validation(True)


 ## Model without Policies
 

In [None]:
class CovidModel(CompartmentalModel):
    """SEIRD compartmental model conditioned on daily new cases, recoveries and deaths.

    Compartments S, E, I and D are tracked explicitly; R is implicit
    (R = population - S - E - I - D).
    """

    def __init__(self, population, new_cases, new_recovered, new_deaths):
        '''
        population (int) – Total population = S + E + I + D + R.
        new_cases, new_recovered, new_deaths – observed counts, one entry per
            time step. A trailing singleton column (a ``(T, 1)`` tensor) is
            squeezed to ``(T,)``.

        Raises ValueError if the three observation series differ in length.
        '''
        if not (len(new_cases) == len(new_recovered) == len(new_deaths)):
            raise ValueError(
                "new_cases, new_recovered and new_deaths must have equal lengths, "
                "got {}, {} and {}".format(len(new_cases), len(new_recovered), len(new_deaths)))

        compartments = ("S", "E", "I", "D")  # R is implicit.
        duration = len(new_cases)
        super().__init__(compartments, duration, population)

        # transition() indexes the observations with either an int or a time
        # slice; with (T, 1) data the slice yields (T, 1), which would broadcast
        # against the per-step flows into a (T, T) likelihood term.
        self.new_cases = self._as_series(new_cases)
        self.new_deaths = self._as_series(new_deaths)
        self.new_recovered = self._as_series(new_recovered)

    @staticmethod
    def _as_series(counts):
        """Squeeze a trailing singleton column so ``counts[t]`` holds one value per step."""
        if torch.is_tensor(counts) and counts.dim() == 2 and counts.size(-1) == 1:
            return counts.squeeze(-1)
        return counts

    def global_model(self):
        """Sample global parameters.

        Returns (R0, tau_e, tau_i, rho, mort_rate, rec_rate).
        """
        # NOTE(review): Normal priors put mass on non-positive durations and R0,
        # which would make rates such as 1 / tau_i invalid; a positive-support
        # prior (e.g. LogNormal) may be safer.
        tau_i = pyro.sample("rec_time", dist.Normal(15.0, 3.0))  # mean infectious period (time steps)
        tau_e = pyro.sample("incub_time", dist.Normal(5.0, 1.0))  # mean incubation period (time steps)
        R0 = pyro.sample("R0", dist.Normal(2.5, 0.5))
        rho = pyro.sample("rho", dist.Beta(10, 10))  # About 50% response rate.
        mort_rate = pyro.sample("mort_rate", dist.Beta(2, 50))  # Mode 2%, mean ~3.8% mortality.
        # NOTE(review): rec_rate is returned but transition() never uses it.
        rec_rate = pyro.sample("rec_rate", dist.Beta(10, 10))  # About 50% recovery rate.
        return R0, tau_e, tau_i, rho, mort_rate, rec_rate

    def initialize(self, params):
        """Initial state: the whole population susceptible except one infection."""
        return {"S": self.population - 1, "E": 0, "I": 1, "D": 0}

    def transition(self, params, state, t):
        """Sample one step of flows, update ``state`` in place and condition on data.

        ``t`` is an int (sequential simulation) or a slice over all time steps
        (vectorized inference).
        """
        R0, tau_e, tau_i, rho, mort_rate, rec_rate = params

        # Sample flows between compartments.
        S2E = pyro.sample("S2E_{}".format(t),
                          infection_dist(individual_rate=R0 / tau_i,
                                         num_susceptible=state["S"],
                                         num_infectious=state["I"],
                                         population=self.population))
        E2I = pyro.sample("E2I_{}".format(t),
                          binomial_dist(state["E"], 1 / tau_e))
        # Deaths and recoveries compete for the same infectious individuals:
        # draw deaths first and recoveries only from the survivors, so the two
        # outflows can never remove more people than I holds.
        I2D = pyro.sample("I2D_{}".format(t),
                          binomial_dist(state["I"], mort_rate / tau_i))
        I2R = pyro.sample("I2R_{}".format(t),
                          binomial_dist(state["I"] - I2D, 1 / tau_i))

        # Update compartments with flows.
        state["S"] = state["S"] - S2E
        state["E"] = state["E"] + S2E - E2I
        state["I"] = state["I"] + E2I - I2R - I2D
        state["D"] = state["D"] + I2D

        # Condition on observations (only inside the observed window).
        t_is_observed = isinstance(t, slice) or t < self.duration
        # A fraction rho of new infections is reported as cases.
        pyro.sample("new_cases_{}".format(t),
                    binomial_dist(S2E, rho),
                    obs=self.new_cases[t] if t_is_observed else None)
        # Every death is observed (probability 1).
        pyro.sample("new_deaths_{}".format(t),
                    binomial_dist(I2D, 1),
                    obs=self.new_deaths[t] if t_is_observed else None)
        pyro.sample("new_recovered_{}".format(t),
                    binomial_dist(I2R, rho),
                    obs=self.new_recovered[t] if t_is_observed else None)

    def compute_flows(self, prev, curr, t):
        """Recover the per-step flows from consecutive compartment states."""
        S2E = prev["S"] - curr["S"]  # S can only go to E.
        I2D = curr["D"] - prev["D"]  # D can only have come from I.
        # We deduce the remaining flows by conservation of mass:
        # curr - prev = inflows - outflows
        E2I = prev["E"] - curr["E"] + S2E
        I2R = prev["I"] - curr["I"] + E2I - I2D
        return {
            "S2E_{}".format(t): S2E,
            "E2I_{}".format(t): E2I,
            "I2D_{}".format(t): I2D,
            "I2R_{}".format(t): I2R,
        }

## Create Country

In [None]:
# Build per-date time series (as (T, 1) tensors) for one country or province/state.
def _column_tensor(frame, column):
    """Return ``frame[column]`` as a ``(len(frame), 1)`` tensor of torch's default dtype."""
    values = frame[column].to_numpy(dtype=float)
    return torch.tensor(values, dtype=torch.get_default_dtype()).view(len(values), 1)


def create_country(country, start_date, end_date, state=False, data=None,
                   url="https://raw.githubusercontent.com/assemzh/ProbProg-COVID-19/master/full_grouped.csv"):
    """Build daily COVID-19 time series for one country (or province/state).

    Parameters
    ----------
    country : str
        Value matched against the "Country/Region" column, or against
        "Province/State" when ``state`` is true.
    start_date, end_date : str or datetime-like
        Inclusive date window.
    state : bool
        Match on "Province/State" instead of "Country/Region".
    data : pandas.DataFrame, optional
        Pre-loaded table in the ``full_grouped.csv`` layout. When omitted the
        CSV is read from ``url``; passing it avoids a download per call. The
        frame is not modified.
    url : str
        CSV location used when ``data`` is None.

    Returns
    -------
    dict
        Keys 'active', 'recovered', 'deaths', 'new_cases', 'new_recovered',
        'new_deaths', each a ``(T, 1)`` tensor of torch's default dtype with
        one row per date in ascending order; rows sharing a date are summed.
    """
    if data is None:
        data = pd.read_csv(url)

    region_col = "Province/State" if state else "Country/Region"
    df = data.loc[data[region_col] == country,
                  [region_col, "Date", "Confirmed", "Deaths", "Recovered", "Active",
                   "New cases", "New deaths", "New recovered"]].copy()
    df.columns = ["country", "date", "confirmed", "deaths", "recovered", "active",
                  "new_cases", "new_deaths", "new_recovered"]
    df["date"] = pd.to_datetime(df["date"])

    # Sum rows sharing a date (e.g. several provinces of one country). Columns
    # are selected with a list: tuple-style selection on a GroupBy was
    # deprecated in pandas 1.x and raises in pandas 2.0.
    value_cols = ["confirmed", "deaths", "recovered", "active",
                  "new_cases", "new_deaths", "new_recovered"]
    df = df.groupby(["country", "date"])[value_cols].sum().reset_index()

    df = df.sort_values(by="date")
    in_window = ((df["date"] >= pd.to_datetime(start_date))
                 & (df["date"] <= pd.to_datetime(end_date)))
    df = df[in_window]

    keys = ("active", "recovered", "deaths", "new_cases", "new_recovered", "new_deaths")
    return {key: _column_tensor(df, key) for key in keys}


## Get data for countries


In [None]:
# Daily series for both countries over one shared window (both dates inclusive).
_window = {"start_date": "2020-02-01", "end_date": "2020-04-01"}
Japan = create_country("Japan", **_window)
Sweden = create_country("Sweden", **_window)


 app.launch_new_instance()


## Train the model using MCMC



In [None]:
%%time
# Observation series in the order CovidModel expects: cases, recoveries, deaths.
japan_observations = (Japan["new_cases"], Japan["new_recovered"], Japan["new_deaths"])
Japan_model = CovidModel(126500000, *japan_observations)  # first arg: total population
pyro.set_rng_seed(20210521)  # fixed seed so this run is reproducible
# NOTE(review): the recorded summary for this cell reports n_eff below 10 at every
# site shown, so 500 samples look too few for converged estimates.
Japan_mcmc = Japan_model.fit_mcmc(num_samples=500, warmup_steps=200)
Japan_mcmc.summary()

INFO 	 Running inference...
Warmup: 0%| | 0/700 [00:00, ?it/s]INFO 	 Heuristic init: R0=2.64, incub_time=3.56, mort_rate=0.0147, rec_rate=0.675, rec_time=8.59, rho=0.178
Sample: 100%|██████████| 700/700 [00:58, 11.89it/s, step size=2.11e-03, acc. prob=0.731]



 mean std median 5.0% 95.0% n_eff r_hat
 R0 2.60 0.01 2.59 2.59 2.62 4.13 1.43
 auxiliary[0,0] 126499886.77 1.08 126499886.19 126499886.15 126499889.01 4.25 1.36
 auxiliary[0,1] 126499769.19 2.23 126499768.03 126499767.82 126499773.74 4.36 1.37
 auxiliary[0,2] 126499647.83 2.74 126499646.40 126499646.10 126499653.37 4.23 1.41
 auxiliary[0,3] 126499535.96 6.14 126499533.05 126499531.95 126499548.16 4.36 1.40
 auxiliary[0,4] 126499424.13 8.69 126499420.06 126499418.42 126499441.69 4.34 1.40
 auxiliary[0,5] 126499314.64 10.66 126499309.65 126499307.31 126499336.33 4.29 1.41
 auxiliary[0,6] 126499206.81 14.79 126499199.98 126499196.97 126499236.67 4.35 1.41
 auxiliary[0,7] 126499098.07 17.60 126499089.71 126499086.30 126499134.19 4.29 1.41
 auxiliary[0,8] 126498987.18 22.51 126498976.38 126498972.55 126499034.20 4.29 1.41
 auxiliary[0,9] 126498871.47 25.33 126498859.33 126498854.66 126498924.56 4.20 1.42
auxiliary[0,10] 126498749.30 28.85 126498735.30 126498728.94 126498808.63 4.11 1.45
a

In [None]:
%%time
# Re-fit without re-seeding, so each run draws different samples.
japan_observations = (Japan["new_cases"], Japan["new_recovered"], Japan["new_deaths"])
Japan_model = CovidModel(126500000, *japan_observations)  # first arg: total population
# NOTE(review): the recorded summary for this cell reports n_eff below 10 at every
# site shown, so 500 samples look too few for converged estimates.
Japan_mcmc = Japan_model.fit_mcmc(num_samples=500, warmup_steps=200)
Japan_mcmc.summary()

INFO 	 Running inference...
Warmup: 0%| | 0/700 [00:00, ?it/s]INFO 	 Heuristic init: R0=1.74, incub_time=5.76, mort_rate=0.0166, rec_rate=0.343, rec_time=14.8, rho=0.191
Sample: 100%|██████████| 700/700 [03:17, 3.55it/s, step size=1.52e-05, acc. prob=0.999]



 mean std median 5.0% 95.0% n_eff r_hat
 R0 1.74 0.00 1.74 1.74 1.74 3.20 1.61
 auxiliary[0,0] 126499865.75 0.00 126499865.75 126499865.74 126499865.75 4.63 1.45
 auxiliary[0,1] 126499752.70 0.14 126499752.73 126499752.50 126499752.87 3.01 1.99
 auxiliary[0,2] 126499637.31 0.30 126499637.23 126499636.92 126499637.75 2.94 2.03
 auxiliary[0,3] 126499516.28 0.06 126499516.28 126499516.20 126499516.36 2.51 2.60
 auxiliary[0,4] 126499394.94 0.01 126499394.94 126499394.92 126499394.97 3.27 2.07
 auxiliary[0,5] 126499290.14 0.01 126499290.14 126499290.12 126499290.15 7.21 1.00
 auxiliary[0,6] 126499183.37 0.03 126499183.36 126499183.32 126499183.42 2.37 3.46
 auxiliary[0,7] 126499071.08 0.04 126499071.08 126499071.03 126499071.14 2.65 2.38
 auxiliary[0,8] 126498947.86 0.01 126498947.86 126498947.84 126498947.87 5.63 1.26
 auxiliary[0,9] 126498832.96 0.03 126498832.96 126498832.92 126498833.00 4.10 2.01
auxiliary[0,10] 126498734.89 0.03 126498734.89 126498734.85 126498734.93 4.61 1.00
auxilia

In [None]:
%%time
# Re-fit without re-seeding, so each run draws different samples.
japan_observations = (Japan["new_cases"], Japan["new_recovered"], Japan["new_deaths"])
Japan_model = CovidModel(126500000, *japan_observations)  # first arg: total population
# NOTE(review): the recorded summary for this cell reports n_eff below 10 at every
# site shown, so 500 samples look too few for converged estimates.
Japan_mcmc = Japan_model.fit_mcmc(num_samples=500, warmup_steps=200)
Japan_mcmc.summary()

INFO 	 Running inference...
Warmup: 0%| | 0/700 [00:00, ?it/s]INFO 	 Heuristic init: R0=1.92, incub_time=6.15, mort_rate=0.019, rec_rate=0.571, rec_time=14.9, rho=0.185
Sample: 100%|██████████| 700/700 [03:14, 3.60it/s, step size=8.57e-05, acc. prob=0.897]



 mean std median 5.0% 95.0% n_eff r_hat
 R0 1.92 0.00 1.92 1.92 1.92 2.54 2.62
 auxiliary[0,0] 126499876.90 0.82 126499876.77 126499875.87 126499877.80 2.34 4.23
 auxiliary[0,1] 126499777.08 0.42 126499777.28 126499776.40 126499777.53 2.55 2.58
 auxiliary[0,2] 126499662.89 0.20 126499662.94 126499662.55 126499663.17 2.74 2.25
 auxiliary[0,3] 126499555.96 0.22 126499556.03 126499555.65 126499556.29 2.48 2.79
 auxiliary[0,4] 126499440.88 0.19 126499440.94 126499440.49 126499441.14 2.89 2.15
 auxiliary[0,5] 126499322.65 0.26 126499322.74 126499322.23 126499322.99 2.78 2.09
 auxiliary[0,6] 126499198.40 0.29 126499198.55 126499197.83 126499198.70 2.70 2.48
 auxiliary[0,7] 126499088.87 0.31 126499088.79 126499088.51 126499089.34 2.38 3.08
 auxiliary[0,8] 126498989.71 0.61 126498989.85 126498988.75 126498990.41 2.49 2.64
 auxiliary[0,9] 126498884.78 0.51 126498884.95 126498884.01 126498885.35 2.58 2.45
auxiliary[0,10] 126498769.17 0.41 126498769.29 126498768.53 126498769.73 2.67 2.30
auxilia

In [None]:
%%time
# Re-fit without re-seeding, so each run draws different samples.
japan_observations = (Japan["new_cases"], Japan["new_recovered"], Japan["new_deaths"])
Japan_model = CovidModel(126500000, *japan_observations)  # first arg: total population
# NOTE(review): the recorded summary for this cell reports n_eff below 10 at every
# site shown, so 500 samples look too few for converged estimates.
Japan_mcmc = Japan_model.fit_mcmc(num_samples=500, warmup_steps=200)
Japan_mcmc.summary()

INFO 	 Running inference...
Warmup: 0%| | 0/700 [00:00, ?it/s]INFO 	 Heuristic init: R0=1.99, incub_time=3.9, mort_rate=0.106, rec_rate=0.607, rec_time=16, rho=0.166
Sample: 100%|██████████| 700/700 [03:09, 3.70it/s, step size=4.61e-03, acc. prob=0.926]



 mean std median 5.0% 95.0% n_eff r_hat
 R0 1.86 0.08 1.85 1.73 1.97 2.50 2.66
 auxiliary[0,0] 126499854.58 1.51 126499855.07 126499851.89 126499856.13 3.32 1.69
 auxiliary[0,1] 126499717.84 5.86 126499716.51 126499710.19 126499725.92 2.58 2.41
 auxiliary[0,2] 126499576.41 10.26 126499576.34 126499560.94 126499591.70 2.51 2.65
 auxiliary[0,3] 126499432.41 18.40 126499431.39 126499407.07 126499460.12 2.57 2.48
 auxiliary[0,4] 126499285.99 29.02 126499281.96 126499247.42 126499330.47 2.65 2.34
 auxiliary[0,5] 126499138.43 40.51 126499130.06 126499087.96 126499202.12 2.68 2.28
 auxiliary[0,6] 126498989.77 53.18 126498974.02 126498922.86 126499074.76 2.69 2.26
 auxiliary[0,7] 126498841.09 66.55 126498820.29 126498762.06 126498953.46 2.73 2.21
 auxiliary[0,8] 126498692.21 80.89 126498665.32 126498594.94 126498825.49 2.77 2.17
 auxiliary[0,9] 126498542.69 95.65 126498511.57 126498429.29 126498702.61 2.80 2.14
auxiliary[0,10] 126498391.67 110.93 126498355.10 126498259.08 126498576.75 2.83 2.

In [None]:
%%time
# Re-fit without re-seeding, so each run draws different samples.
japan_observations = (Japan["new_cases"], Japan["new_recovered"], Japan["new_deaths"])
Japan_model = CovidModel(126500000, *japan_observations)  # first arg: total population
# NOTE(review): the recorded summary for this cell reports n_eff below 10 at every
# site shown, so 500 samples look too few for converged estimates.
Japan_mcmc = Japan_model.fit_mcmc(num_samples=500, warmup_steps=200)
Japan_mcmc.summary()

INFO 	 Running inference...
Warmup: 0%| | 0/700 [00:00, ?it/s]INFO 	 Heuristic init: R0=1.75, incub_time=5.27, mort_rate=0.0177, rec_rate=0.486, rec_time=9.02, rho=0.189
Sample: 100%|██████████| 700/700 [03:23, 3.43it/s, step size=3.40e-03, acc. prob=0.900]



 mean std median 5.0% 95.0% n_eff r_hat
 R0 1.58 0.05 1.57 1.51 1.65 2.50 2.67
 auxiliary[0,0] 126499861.04 0.32 126499861.05 126499860.57 126499861.55 4.43 1.23
 auxiliary[0,1] 126499725.56 3.72 126499725.38 126499720.45 126499731.75 2.60 2.41
 auxiliary[0,2] 126499582.23 7.36 126499582.32 126499570.46 126499593.10 2.65 2.35
 auxiliary[0,3] 126499430.67 14.02 126499428.50 126499408.16 126499451.40 2.63 2.39
 auxiliary[0,4] 126499274.14 20.12 126499269.55 126499244.83 126499306.75 2.86 2.08
 auxiliary[0,5] 126499111.52 25.77 126499104.47 126499072.74 126499154.02 2.88 2.11
 auxiliary[0,6] 126498946.10 32.21 126498937.63 126498896.77 126498995.47 2.86 2.11
 auxiliary[0,7] 126498777.60 39.33 126498767.76 126498719.14 126498838.36 2.83 2.14
 auxiliary[0,8] 126498608.27 45.02 126498597.00 126498540.61 126498677.24 2.81 2.16
 auxiliary[0,9] 126498437.75 51.00 126498424.66 126498360.03 126498515.88 2.78 2.21
auxiliary[0,10] 126498266.56 56.62 126498251.56 126498179.86 126498356.23 2.77 2.21

In [None]:
%%time
# Re-fit without re-seeding, so each run draws different samples.
japan_observations = (Japan["new_cases"], Japan["new_recovered"], Japan["new_deaths"])
Japan_model = CovidModel(126500000, *japan_observations)  # first arg: total population
# NOTE(review): the recorded summary for this cell reports n_eff below 10 at every
# site shown, so 500 samples look too few for converged estimates.
Japan_mcmc = Japan_model.fit_mcmc(num_samples=500, warmup_steps=200)
Japan_mcmc.summary()

INFO 	 Running inference...
Warmup: 0%| | 0/700 [00:00, ?it/s]INFO 	 Heuristic init: R0=2.23, incub_time=5.71, mort_rate=0.0162, rec_rate=0.545, rec_time=16.2, rho=0.167
Sample: 100%|██████████| 700/700 [01:07, 10.34it/s, step size=4.11e-04, acc. prob=0.746]



 mean std median 5.0% 95.0% n_eff r_hat
 R0 2.23 0.00 2.23 2.23 2.23 4.53 2.19
 auxiliary[0,0] 126499886.81 0.05 126499886.80 126499886.75 126499886.89 2.39 3.31
 auxiliary[0,1] 126499789.68 0.04 126499789.66 126499789.63 126499789.74 2.56 2.60
 auxiliary[0,2] 126499692.94 0.18 126499692.92 126499692.70 126499693.21 2.48 3.05
 auxiliary[0,3] 126499590.75 0.26 126499590.65 126499590.45 126499591.11 2.45 3.53
 auxiliary[0,4] 126499488.94 0.31 126499488.75 126499488.54 126499489.41 2.52 2.90
 auxiliary[0,5] 126499382.64 0.31 126499382.64 126499382.18 126499383.09 2.45 2.80
 auxiliary[0,6] 126499281.13 0.52 126499281.00 126499280.43 126499281.96 2.52 2.69
 auxiliary[0,7] 126499158.90 0.30 126499158.85 126499158.36 126499159.32 2.67 2.46
 auxiliary[0,8] 126499050.99 0.52 126499050.83 126499050.43 126499051.77 2.43 3.73
 auxiliary[0,9] 126498936.43 0.52 126498936.27 126498935.81 126498937.31 2.50 3.02
auxiliary[0,10] 126498839.71 0.72 126498839.34 126498838.97 126498841.01 2.67 2.41
auxilia

In [None]:
%%time
# Re-fit without re-seeding, so each run draws different samples.
japan_observations = (Japan["new_cases"], Japan["new_recovered"], Japan["new_deaths"])
Japan_model = CovidModel(126500000, *japan_observations)  # first arg: total population
# NOTE(review): the recorded summary for this cell reports n_eff below 10 at every
# site shown, so 500 samples look too few for converged estimates.
Japan_mcmc = Japan_model.fit_mcmc(num_samples=500, warmup_steps=200)
Japan_mcmc.summary()

INFO 	 Running inference...
Warmup: 0%| | 0/700 [00:00, ?it/s]INFO 	 Heuristic init: R0=1.82, incub_time=6.18, mort_rate=0.0269, rec_rate=0.522, rec_time=21.3, rho=0.179
Sample: 100%|██████████| 700/700 [03:14, 3.59it/s, step size=2.83e-03, acc. prob=0.902]



 mean std median 5.0% 95.0% n_eff r_hat
 R0 1.79 0.01 1.80 1.77 1.81 2.61 2.43
 auxiliary[0,0] 126499875.14 1.63 126499875.67 126499872.40 126499877.39 3.23 1.59
 auxiliary[0,1] 126499754.75 6.21 126499753.75 126499746.75 126499764.75 2.55 2.54
 auxiliary[0,2] 126499625.88 10.52 126499624.23 126499611.53 126499643.13 2.50 2.74
 auxiliary[0,3] 126499489.51 14.12 126499487.21 126499470.48 126499513.28 2.54 2.59
 auxiliary[0,4] 126499348.50 17.42 126499345.65 126499322.72 126499376.29 2.57 2.53
 auxiliary[0,5] 126499203.81 21.27 126499199.38 126499173.62 126499236.91 2.52 2.61
 auxiliary[0,6] 126499056.45 25.10 126499050.11 126499018.49 126499094.75 2.56 2.54
 auxiliary[0,7] 126498907.65 29.93 126498902.53 126498862.83 126498954.15 2.58 2.47
 auxiliary[0,8] 126498757.91 34.95 126498751.36 126498704.63 126498810.43 2.60 2.46
 auxiliary[0,9] 126498607.76 40.82 126498600.37 126498549.00 126498670.83 2.58 2.51
auxiliary[0,10] 126498458.77 46.22 126498451.22 126498393.21 126498528.66 2.57 2.5