{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Suppose 2 R0: SEIR model with one $R_0$ before and one after the lockdown\n",
    "\n",
    "SEIR model of the COVID-19 epidemic in Belgium, for the general population (with hospital and ICU compartments) and for nursing homes (MR/MRS). Each population gets one $R_0$ before and one after an intervention date; both are fitted to daily hospital, ICU and death counts with stochastic variational inference (Pyro SVI)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyro.infer.mcmc import MCMC, NUTS\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "import pyro\n",
    "import pyro.infer\n",
    "import pyro.optim\n",
    "import pyro.distributions as dist\n",
    "from torch.distributions import constraints\n",
    "\n",
    "import seaborn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 197,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model parameters. Each one is wrapped in a function so that it can later be\n",
    "# made stochastic (e.g. `return pyro.sample(...)`) without changing its callers.\n",
    "\n",
    "# deaths in hospitals / total deaths\n",
    "def frac_dh():\n",
    "    return torch.tensor(3470. / 7594.)\n",
    "\n",
    "# fraction of symptomatic infected people who get hospitalized\n",
    "def hh():\n",
    "    return torch.tensor(0.05)\n",
    "\n",
    "# inverse recovery time (per day)\n",
    "def gamma():\n",
    "    return torch.tensor(1. / 12.4)\n",
    "\n",
    "# inverse incubation time (per day)\n",
    "def epsilon():\n",
    "    return torch.tensor(1. / 5.2)\n",
    "\n",
    "# fatality rate in ICU\n",
    "def dea():\n",
    "    return torch.tensor(.5)\n",
    "\n",
    "# population size\n",
    "def n0():\n",
    "    return torch.tensor(11000000.)\n",
    "\n",
    "# population living in nursing homes (MR/MRS) plus care staff\n",
    "def n0_MRS():\n",
    "    return torch.tensor(400000.)\n",
    "\n",
    "# e0 = i0 * factor\n",
    "def e0_factor():\n",
    "    return torch.tensor(37.)\n",
    "\n",
    "# e0_MRS = i0_MRS * factor\n",
    "def e0_MRS_factor():\n",
    "    return torch.tensor(20.)\n",
    "\n",
    "# size of the window for fitting Re's\n",
    "def window():\n",
    "    return torch.tensor(6.)\n",
    "\n",
    "# initial number of symptomatic infected people\n",
    "def i0():\n",
    "    #return pyro.sample(\"i0\", dist.Poisson(10))\n",
    "    return torch.tensor(3.)\n",
    "\n",
    "# R0 hook; currently deterministic: returns its argument unchanged.\n",
    "def r0_model(r0_mean):\n",
    "    return r0_mean\n",
    "\n",
    "# daily death rate of ICU patients: dea() spread over 5 days\n",
    "def drea():\n",
    "    return dea() / 5.\n",
    "\n",
    "# daily recovery rate of ICU patients: (1 - dea()) spread over 20 days\n",
    "def rrea():\n",
    "    return (1 - dea()) / 20.\n",
    "\n",
    "# not used by SEIR, which computes its own hospitalization rate\n",
    "def hospi():\n",
    "    return torch.tensor(0.0)\n",
    "\n",
    "# fraction of ward patients who recover without being transferred to ICU\n",
    "def gg():\n",
    "    return torch.tensor(.75)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 198,
   "metadata": {},
   "outputs": [],
   "source": [
    "def SEIR(r0, icu_capacity=1895.):\n",
    "    \"\"\"Simulate the general-population SEIR model with hospital and ICU compartments.\n",
    "\n",
    "    Explicit Euler integration with one step per entry of `r0` (one step = one day).\n",
    "\n",
    "    Args:\n",
    "        r0: 1-D sequence (e.g. tensor) of daily reproduction numbers.\n",
    "        icu_capacity: number of ICU beds. Patients that would exceed it are\n",
    "            counted as deaths on the same day.\n",
    "\n",
    "    Returns:\n",
    "        Lists (s, e, i, h, l, m, r), each of length len(r0) + 1 (initial state\n",
    "        included): susceptible, exposed (asymptomatic), infected (symptomatic),\n",
    "        occupied hospital beds, occupied ICU beds, cumulative deaths, recovered.\n",
    "    \"\"\"\n",
    "    # Initial conditions\n",
    "    n = [n0()]                           # total living population\n",
    "    i = [i0()]                           # symptomatic\n",
    "    e = [i[0] * e0_factor()]             # asymptomatic\n",
    "    h = [torch.tensor(.0)]               # occupied hospital beds\n",
    "    l = [torch.tensor(.0)]               # occupied ICU beds\n",
    "    r = [torch.tensor(.0)]               # recovered (immune)\n",
    "    m = [torch.tensor(.0)]               # deaths (cumulative)\n",
    "    s = [n[-1] - e[-1] - i[-1] - r[-1]]  # susceptible\n",
    "\n",
    "    # Daily hospitalization rate of symptomatic cases; zero until day 14.\n",
    "    hosp_rate = 0.\n",
    "    for day in range(len(r0)):\n",
    "        lam = gamma() * r0[day]\n",
    "        if day == 14:\n",
    "            hosp_rate = hh() / 7\n",
    "\n",
    "        # New infections; symptomatic cases are weighted by 1/2 relative to exposed ones.\n",
    "        infections = lam * (i[-1] / 2 + e[-1]) * s[-1] / n[-1]\n",
    "        # Ward -> ICU transfers. The transfer time lies between 2 and 6 days and\n",
    "        # increases with ICU occupancy (tanh centred on 500 occupied beds).\n",
    "        to_icu = (1 - gg()) * h[-1] / (4. + 2 * torch.tanh((l[-1] - 500.) / 300.))\n",
    "\n",
    "        ds = -infections\n",
    "        de = infections - epsilon() * e[-1]\n",
    "        di = epsilon() * e[-1] - gamma() * i[-1] - hosp_rate * i[-1]\n",
    "        dh = hosp_rate * i[-1] - gg() * h[-1] / 7 - to_icu\n",
    "        dl = to_icu - drea() * l[-1] - rrea() * l[-1]\n",
    "        dr = gamma() * i[-1] + rrea() * l[-1] + gg() * h[-1] / 7\n",
    "        dm = drea() * l[-1]\n",
    "\n",
    "        s.append(s[-1] + ds)\n",
    "        e.append(e[-1] + de)\n",
    "        i.append(i[-1] + di)\n",
    "        h.append(h[-1] + dh)\n",
    "        # Cap ICU occupancy at icu_capacity and count the overflow as deaths.\n",
    "        # Written without a Python `if` on tensor values so that it stays\n",
    "        # correct under torch.jit tracing (JitTrace_ELBO).\n",
    "        l_uncapped = l[-1] + dl\n",
    "        l_capped = torch.clamp(l_uncapped, max=icu_capacity)\n",
    "        dm = dm + (l_uncapped - l_capped)\n",
    "        l.append(l_capped)\n",
    "        r.append(r[-1] + dr)\n",
    "        m.append(m[-1] + dm)\n",
    "        n.append(s[-1] + e[-1] + i[-1] + h[-1] + l[-1] + r[-1])\n",
    "    return s, e, i, h, l, m, r\n",
    "\n",
    "\n",
    "def SEIR_MRS(r0_mrs, n_futures=0, window=6):\n",
    "    \"\"\"Simulate the nursing-home (MR/MRS) SEIR model, where deaths come directly from I.\n",
    "\n",
    "    Explicit Euler integration with one step per entry of `r0_mrs` (one step = one day).\n",
    "\n",
    "    Args:\n",
    "        r0_mrs: 1-D sequence (e.g. tensor) of daily reproduction numbers.\n",
    "        n_futures, window: unused; kept so that existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        Lists (s, e, i, m, r), each of length len(r0_mrs) + 1: susceptible,\n",
    "        exposed, infected, cumulative deaths, recovered.\n",
    "    \"\"\"\n",
    "    # Daily death rate of infected people (0.15 / 10: presumably a 15% fatality\n",
    "    # spread over ~10 days -- TODO confirm).\n",
    "    alpha = 0.15 / 10\n",
    "\n",
    "    # Initial conditions\n",
    "    n = [n0_MRS()]\n",
    "    i = [torch.tensor(1.)]  # float, like every other compartment\n",
    "    e = [i[-1] * e0_MRS_factor()]\n",
    "    r = [torch.tensor(0.0)]\n",
    "    s = [n[-1] - e[-1] - i[-1] - r[-1]]\n",
    "    m = [torch.tensor(0.0)]\n",
    "\n",
    "    for day in range(len(r0_mrs)):\n",
    "        lam = gamma() * r0_mrs[day]\n",
    "        infections = lam * (i[-1] / 2 + e[-1]) * s[-1] / n[-1]\n",
    "\n",
    "        ds = -infections\n",
    "        de = infections - epsilon() * e[-1]\n",
    "        di = epsilon() * e[-1] - (gamma() + alpha) * i[-1]\n",
    "        dr = gamma() * i[-1]\n",
    "        dm = alpha * i[-1]\n",
    "\n",
    "        s.append(s[-1] + ds)\n",
    "        e.append(e[-1] + de)\n",
    "        i.append(i[-1] + di)\n",
    "        r.append(r[-1] + dr)\n",
    "        m.append(m[-1] + dm)\n",
    "        # Deaths leave the living population.\n",
    "        n.append(s[-1] + e[-1] + i[-1] + r[-1])\n",
    "\n",
    "    return s, e, i, m, r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from covid19be import load_data\n",
    "data_df = load_data()\n",
    "n_days = len(data_df)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(source : [Wikipédia, Pandémie de Covid-19 en Belgique](https://fr.wikipedia.org/wiki/Pand%C3%A9mie_de_Covid-19_en_Belgique))\n",
    "- 11 mars : arrêt des visites en MR/MRS (en Wallonie uniquement) (jour 13)\n",
    "- 16 mars : arrêt des cours (jour 18)\n",
    "- 18 mars : confinement (jour 20)\n",
    "- 20 avril : réouverture des magasins de bricolage et des pépiniéristes (jour 53)\n",
    "\n",
    "Nos données débutent le 28 février, qui est le jour 1 ; les numéros de jour entre parenthèses sont ceux utilisés dans le code (`date_r0_switch_mrs = 13`, `date_r0_switch = 20`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 200,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Day numbers (28 February = day 1) at which R0 switches to its second value;\n",
    "# see the timeline above. The first `date - 1` simulated days use the first R0.\n",
    "date_r0_switch = 20      # 18 March: lockdown\n",
    "date_r0_switch_mrs = 13  # 11 March: visits to nursing homes stopped\n",
    "# Most recent days left out of the fit (presumably not yet consolidated -- TODO confirm).\n",
    "nb_dirty_data = 4\n",
    "n_useful_days = n_days - nb_dirty_data\n",
    "# First data row (0-based) used as an observation; earlier rows are ignored.\n",
    "first_fit_day = 17\n",
    "# Upper bounds of the Uniform(0, .) priors on R0 before / after the switch.\n",
    "r0_max_before = 6.\n",
    "r0_max_after = 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Also model deaths in nursing homes (MR/MRS) with SEIR_MRS.\n",
    "mrs = True\n",
    "\n",
    "\n",
    "def noiser(mu):\n",
    "    \"\"\"Observation distribution around a model prediction `mu`.\n",
    "\n",
    "    Zero-inflated Poisson with a tiny first parameter (1e-4) and rate mu + 1, so a\n",
    "    predicted count of 0 still gives a strictly positive rate.\n",
    "    NOTE(review): the first positional argument is the zero-inflation gate in the\n",
    "    Pyro version this notebook was written with; newer Pyro releases take `rate`\n",
    "    first -- check against the installed version.\n",
    "    \"\"\"\n",
    "    return dist.ZeroInflatedPoisson(torch.tensor(.0001), mu + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 216,
   "metadata": {},
   "outputs": [],
   "source": [
    "def piecewise_r0(r0_before, r0_after, switch_day, n_days, jitter=.2):\n",
    "    \"\"\"Draw a daily R0 vector of length `n_days` around two mean values.\n",
    "\n",
    "    The first `switch_day - 1` days are drawn uniformly in r0_before +/- jitter,\n",
    "    the remaining days uniformly in r0_after +/- jitter. These draws are not\n",
    "    registered as pyro sample sites.\n",
    "    \"\"\"\n",
    "    n_before = switch_day - 1\n",
    "    n_after = n_days - n_before\n",
    "    before = dist.Uniform(torch.ones(n_before) * r0_before - jitter,\n",
    "                          torch.ones(n_before) * r0_before + jitter)()\n",
    "    after = dist.Uniform(torch.ones(n_after) * r0_after - jitter,\n",
    "                         torch.ones(n_after) * r0_after + jitter)()\n",
    "    return torch.cat((before, after))\n",
    "\n",
    "\n",
    "def SEIR_full_model_bis(ret_data=False):\n",
    "    \"\"\"Pyro model: R0 priors, SEIR simulation(s) and noisy daily observations.\n",
    "\n",
    "    Samples r0_1 / r0_2 (and r0_mrs_1 / r0_mrs_2 when `mrs`), simulates the epidemic\n",
    "    and emits observation sites h_<idx>, l_<idx>, m_<idx> (and m_mrs_<idx>) for\n",
    "    idx in [first_fit_day, n_useful_days), to be conditioned on `data`.\n",
    "\n",
    "    Returns the simulated trajectories and R0 vectors when `ret_data` is True,\n",
    "    otherwise None.\n",
    "    \"\"\"\n",
    "    r0_1 = pyro.sample(\"r0_1\", dist.Uniform(torch.tensor(0.), torch.tensor(r0_max_before)))\n",
    "    r0_2 = pyro.sample(\"r0_2\", dist.Uniform(torch.tensor(0.), torch.tensor(r0_max_after)))\n",
    "    # R0 fluctuates around its means, given by r0_1 and r0_2\n",
    "    r0 = piecewise_r0(r0_1, r0_2, date_r0_switch, n_useful_days)\n",
    "    s_T, e_T, i_T, h_T, l_T, m_T, r_T = SEIR(r0)\n",
    "    if mrs:\n",
    "        r0_mrs_1 = pyro.sample(\"r0_mrs_1\", dist.Uniform(.0, r0_max_before))\n",
    "        r0_mrs_2 = pyro.sample(\"r0_mrs_2\", dist.Uniform(.0, r0_max_after))\n",
    "        r0_mrs = piecewise_r0(r0_mrs_1, r0_mrs_2, date_r0_switch_mrs, n_useful_days)\n",
    "        _, _, _, m_mrs_T, _ = SEIR_MRS(r0_mrs)\n",
    "    for idx in range(first_fit_day, n_useful_days):\n",
    "        pyro.sample(\"h_%d\" % idx, noiser(h_T[idx]))\n",
    "        pyro.sample(\"l_%d\" % idx, noiser(l_T[idx]))\n",
    "        pyro.sample(\"m_%d\" % idx, noiser(m_T[idx]))\n",
    "        if mrs:\n",
    "            pyro.sample(\"m_mrs_%d\" % idx, noiser(m_mrs_T[idx]))\n",
    "    if mrs and ret_data:\n",
    "        return s_T, e_T, i_T, h_T, l_T, m_T, m_mrs_T, r_T, r0, r0_mrs\n",
    "    if ret_data:\n",
    "        return s_T, e_T, i_T, h_T, l_T, m_T, r_T, r0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 241,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observed counts keyed by the model's observation site names.\n",
    "# h: n_hospitalized minus n_icu (the model's ward compartment excludes ICU beds).\n",
    "# Deaths are split between hospitals (frac_dh) and nursing homes (1 - frac_dh).\n",
    "# NOTE: `hospi` here shadows the hospi() parameter function.\n",
    "data = {}\n",
    "hospi = data_df['n_hospitalized']\n",
    "l = data_df['n_icu']\n",
    "m = data_df['n_deaths']\n",
    "for idx in range(first_fit_day, n_useful_days):\n",
    "    data[\"h_%d\" % idx] = torch.tensor(hospi.iloc[idx] - l.iloc[idx], dtype=torch.float)\n",
    "    data[\"l_%d\" % idx] = torch.tensor(l.iloc[idx], dtype=torch.float)\n",
    "    data[\"m_%d\" % idx] = m.iloc[idx] * frac_dh()\n",
    "    if mrs:\n",
    "        data[\"m_mrs_%d\" % idx] = m.iloc[idx] * (1 - frac_dh())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 218,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _uniform_bounds(low_name, width_name, low_init, width_init, upper):\n",
    "    \"\"\"Learnable bounds (low, high) of a Uniform guide kept inside the prior support [0, upper].\n",
    "\n",
    "    `low` is constrained to (0, upper) and `high = min(low + width, upper)`, so guide\n",
    "    samples never fall outside the model's Uniform(0, upper) prior; outside it the\n",
    "    model log-probability is -inf and the ELBO becomes infinite.\n",
    "    \"\"\"\n",
    "    low = pyro.param(low_name, torch.tensor(low_init), constraint=constraints.interval(0., upper))\n",
    "    width = pyro.param(width_name, torch.tensor(width_init), constraint=constraints.positive)\n",
    "    high = torch.clamp(low + width, max=upper)\n",
    "    return low, high\n",
    "\n",
    "\n",
    "def r0_guide(ret_data=False):\n",
    "    \"\"\"Variational guide: a learnable Uniform(low, high) for each R0 sample site.\n",
    "\n",
    "    Params a_* are the lower bounds and b_* the widths (before clipping at the\n",
    "    prior's upper bound). With ret_data=True it also simulates trajectories from\n",
    "    the sampled R0s, samples the observation sites (mirroring SEIR_full_model_bis)\n",
    "    and returns the same tuple as the model.\n",
    "    \"\"\"\n",
    "    low_1, high_1 = _uniform_bounds(\"a_r0_1\", \"b_r0_1\", 2., 4., r0_max_before)\n",
    "    low_2, high_2 = _uniform_bounds(\"a_r0_2\", \"b_r0_2\", .25, 1.25, r0_max_after)\n",
    "    r0_1 = pyro.sample(\"r0_1\", dist.Uniform(low_1, high_1))\n",
    "    r0_2 = pyro.sample(\"r0_2\", dist.Uniform(low_2, high_2))\n",
    "    if mrs:\n",
    "        low_mrs_1, high_mrs_1 = _uniform_bounds(\"a_r0_mrs_1\", \"b_r0_mrs_1\", 2., 4., r0_max_before)\n",
    "        low_mrs_2, high_mrs_2 = _uniform_bounds(\"a_r0_mrs_2\", \"b_r0_mrs_2\", .25, 1.25, r0_max_after)\n",
    "        r0_mrs_1 = pyro.sample(\"r0_mrs_1\", dist.Uniform(low_mrs_1, high_mrs_1))\n",
    "        r0_mrs_2 = pyro.sample(\"r0_mrs_2\", dist.Uniform(low_mrs_2, high_mrs_2))\n",
    "\n",
    "    if ret_data:\n",
    "        # R0 fluctuates around its means, given by r0_1 and r0_2\n",
    "        r0 = piecewise_r0(r0_1, r0_2, date_r0_switch, n_useful_days)\n",
    "        s_T, e_T, i_T, h_T, l_T, m_T, r_T = SEIR(r0)\n",
    "        if mrs:\n",
    "            r0_mrs = piecewise_r0(r0_mrs_1, r0_mrs_2, date_r0_switch_mrs, n_useful_days)\n",
    "            _, _, _, m_mrs_T, _ = SEIR_MRS(r0_mrs)\n",
    "        for idx in range(first_fit_day, n_useful_days):\n",
    "            pyro.sample(\"h_%d\" % idx, noiser(h_T[idx]))\n",
    "            pyro.sample(\"l_%d\" % idx, noiser(l_T[idx]))\n",
    "            pyro.sample(\"m_%d\" % idx, noiser(m_T[idx]))\n",
    "            if mrs:\n",
    "                pyro.sample(\"m_mrs_%d\" % idx, noiser(m_mrs_T[idx]))\n",
    "    if mrs and ret_data:\n",
    "        return s_T, e_T, i_T, h_T, l_T, m_T, m_mrs_T, r_T, r0, r0_mrs\n",
    "    if ret_data:\n",
    "        return s_T, e_T, i_T, h_T, l_T, m_T, r_T, r0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 219,
   "metadata": {},
   "outputs": [],
   "source": [
    "pyro.clear_param_store()\n",
    "# setup the optimizer\n",
    "adam_params = {\"lr\": 0.001, \"betas\": (0.90, 0.999)}\n",
    "optimizer = pyro.optim.Adam(adam_params)\n",
    "# NOTE: JitTrace_ELBO uses torch.jit tracing, which records Python control flow\n",
    "# once; the model and guide must not branch in Python on tensor values.\n",
    "svi = pyro.infer.SVI(model=pyro.condition(SEIR_full_model_bis, data=data),\n",
    "                     guide=r0_guide,\n",
    "                     optim=optimizer,\n",
    "                     loss=pyro.infer.JitTrace_ELBO())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
"output_type": "stream", "text": [ "r0 before lock-down tensor(2.1217, grad_fn=) tensor(6.3455, grad_fn=)\n", "r0 after lock-down tensor(0.2653, grad_fn=) tensor(1.5792, grad_fn=)\n", "r0 mrs before lock-down tensor(2.1240, grad_fn=) tensor(6.3358, grad_fn=)\n", "r0 mrs after lock-down tensor(0.2651, grad_fn=) tensor(1.5749, grad_fn=)\n", "275472.34375\n", "r0 before lock-down tensor(2.3210, grad_fn=) tensor(6.8859, grad_fn=)\n", "r0 after lock-down tensor(0.2899, grad_fn=) tensor(1.7015, grad_fn=)\n", "r0 mrs before lock-down tensor(2.3515, grad_fn=) tensor(6.9531, grad_fn=)\n", "r0 mrs after lock-down tensor(0.2941, grad_fn=) tensor(1.7203, grad_fn=)\n", "585381.0\n", "r0 before lock-down tensor(2.5228, grad_fn=) tensor(7.3841, grad_fn=)\n", "r0 after lock-down tensor(0.3144, grad_fn=) tensor(1.8260, grad_fn=)\n", "r0 mrs before lock-down tensor(2.6057, grad_fn=) tensor(7.6259, grad_fn=)\n", "r0 mrs after lock-down tensor(0.3276, grad_fn=) tensor(1.8828, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(2.6768, grad_fn=) tensor(7.6010, grad_fn=)\n", "r0 after lock-down tensor(0.3284, grad_fn=) tensor(1.8700, grad_fn=)\n", "r0 mrs before lock-down tensor(2.8704, grad_fn=) tensor(8.3393, grad_fn=)\n", "r0 mrs after lock-down tensor(0.3652, grad_fn=) tensor(2.0614, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(2.7621, grad_fn=) tensor(7.5925, grad_fn=)\n", "r0 after lock-down tensor(0.3320, grad_fn=) tensor(1.8527, grad_fn=)\n", "r0 mrs before lock-down tensor(3.1170, grad_fn=) tensor(8.8443, grad_fn=)\n", "r0 mrs after lock-down tensor(0.3969, grad_fn=) tensor(2.1780, grad_fn=)\n", "269947.3125\n", "r0 before lock-down tensor(2.8481, grad_fn=) tensor(7.6349, grad_fn=)\n", "r0 after lock-down tensor(0.3363, grad_fn=) tensor(1.8502, grad_fn=)\n", "r0 mrs before lock-down tensor(3.3403, grad_fn=) tensor(9.1441, grad_fn=)\n", "r0 mrs after lock-down tensor(0.4254, grad_fn=) tensor(2.2847, grad_fn=)\n", "699314.6875\n", "r0 before lock-down tensor(2.9784, 
grad_fn=) tensor(7.7202, grad_fn=)\n", "r0 after lock-down tensor(0.3446, grad_fn=) tensor(1.8654, grad_fn=)\n", "r0 mrs before lock-down tensor(3.5205, grad_fn=) tensor(9.3419, grad_fn=)\n", "r0 mrs after lock-down tensor(0.4456, grad_fn=) tensor(2.3310, grad_fn=)\n", "442595.46875\n", "r0 before lock-down tensor(3.0332, grad_fn=) tensor(7.6380, grad_fn=)\n", "r0 after lock-down tensor(0.3448, grad_fn=) tensor(1.8491, grad_fn=)\n", "r0 mrs before lock-down tensor(3.6650, grad_fn=) tensor(9.4813, grad_fn=)\n", "r0 mrs after lock-down tensor(0.4620, grad_fn=) tensor(2.3666, grad_fn=)\n", "679968.25\n", "r0 before lock-down tensor(3.1139, grad_fn=) tensor(7.6639, grad_fn=)\n", "r0 after lock-down tensor(0.3471, grad_fn=) tensor(1.8357, grad_fn=)\n", "r0 mrs before lock-down tensor(3.7358, grad_fn=) tensor(9.3489, grad_fn=)\n", "r0 mrs after lock-down tensor(0.4634, grad_fn=) tensor(2.3145, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(3.2497, grad_fn=) tensor(7.7905, grad_fn=)\n", "r0 after lock-down tensor(0.3547, grad_fn=) tensor(1.8390, grad_fn=)\n", "r0 mrs before lock-down tensor(3.8456, grad_fn=) tensor(9.3543, grad_fn=)\n", "r0 mrs after lock-down tensor(0.4705, grad_fn=) tensor(2.2961, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(3.3035, grad_fn=) tensor(7.7239, grad_fn=)\n", "r0 after lock-down tensor(0.3544, grad_fn=) tensor(1.8207, grad_fn=)\n", "r0 mrs before lock-down tensor(4.0151, grad_fn=) tensor(9.5026, grad_fn=)\n", "r0 mrs after lock-down tensor(0.4881, grad_fn=) tensor(2.3413, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(3.4018, grad_fn=) tensor(7.7564, grad_fn=)\n", "r0 after lock-down tensor(0.3590, grad_fn=) tensor(1.8232, grad_fn=)\n", "r0 mrs before lock-down tensor(4.1135, grad_fn=) tensor(9.5064, grad_fn=)\n", "r0 mrs after lock-down tensor(0.4954, grad_fn=) tensor(2.3303, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(3.4485, grad_fn=) tensor(7.6905, grad_fn=)\n", "r0 after lock-down tensor(0.3578, grad_fn=) 
tensor(1.7971, grad_fn=)\n", "r0 mrs before lock-down tensor(4.2458, grad_fn=) tensor(9.5687, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5060, grad_fn=) tensor(2.3280, grad_fn=)\n", "606454.625\n", "r0 before lock-down tensor(3.5405, grad_fn=) tensor(7.7237, grad_fn=)\n", "r0 after lock-down tensor(0.3607, grad_fn=) tensor(1.7914, grad_fn=)\n", "r0 mrs before lock-down tensor(4.3395, grad_fn=) tensor(9.5335, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5123, grad_fn=) tensor(2.3158, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(3.6269, grad_fn=) tensor(7.7347, grad_fn=)\n", "r0 after lock-down tensor(0.3625, grad_fn=) tensor(1.7733, grad_fn=)\n", "r0 mrs before lock-down tensor(4.4140, grad_fn=) tensor(9.4917, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5161, grad_fn=) tensor(2.2932, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(3.7902, grad_fn=) tensor(7.8886, grad_fn=)\n", "r0 after lock-down tensor(0.3718, grad_fn=) tensor(1.7968, grad_fn=)\n", "r0 mrs before lock-down tensor(4.4862, grad_fn=) tensor(9.4809, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5208, grad_fn=) tensor(2.2773, grad_fn=)\n", "99171.1171875\n", "r0 before lock-down tensor(3.8596, grad_fn=) tensor(7.8772, grad_fn=)\n", "r0 after lock-down tensor(0.3722, grad_fn=) tensor(1.7768, grad_fn=)\n", "r0 mrs before lock-down tensor(4.6468, grad_fn=) tensor(9.5957, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5343, grad_fn=) tensor(2.2880, grad_fn=)\n", "284389.46875\n", "r0 before lock-down tensor(3.8637, grad_fn=) tensor(7.7356, grad_fn=)\n", "r0 after lock-down tensor(0.3664, grad_fn=) tensor(1.7266, grad_fn=)\n", "r0 mrs before lock-down tensor(4.7763, grad_fn=) tensor(9.6623, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5441, grad_fn=) tensor(2.2839, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(3.9665, grad_fn=) tensor(7.7796, grad_fn=)\n", "r0 after lock-down tensor(0.3708, grad_fn=) tensor(1.7283, grad_fn=)\n", "r0 mrs before lock-down tensor(4.8806, grad_fn=) 
tensor(9.6937, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5539, grad_fn=) tensor(2.2912, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.0282, grad_fn=) tensor(7.7445, grad_fn=)\n", "r0 after lock-down tensor(0.3713, grad_fn=) tensor(1.7128, grad_fn=)\n", "r0 mrs before lock-down tensor(4.9176, grad_fn=) tensor(9.5853, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5541, grad_fn=) tensor(2.2581, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.0511, grad_fn=) tensor(7.6540, grad_fn=)\n", "r0 after lock-down tensor(0.3684, grad_fn=) tensor(1.6838, grad_fn=)\n", "r0 mrs before lock-down tensor(4.9512, grad_fn=) tensor(9.4826, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5547, grad_fn=) tensor(2.2283, grad_fn=)\n", "inf\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "r0 before lock-down tensor(4.1463, grad_fn=) tensor(7.6991, grad_fn=)\n", "r0 after lock-down tensor(0.3712, grad_fn=) tensor(1.6776, grad_fn=)\n", "r0 mrs before lock-down tensor(5.0984, grad_fn=) tensor(9.5834, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5675, grad_fn=) tensor(2.2388, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.2824, grad_fn=) tensor(7.8071, grad_fn=)\n", "r0 after lock-down tensor(0.3782, grad_fn=) tensor(1.6893, grad_fn=)\n", "r0 mrs before lock-down tensor(5.1681, grad_fn=) tensor(9.5776, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5734, grad_fn=) tensor(2.2305, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.3046, grad_fn=) tensor(7.7147, grad_fn=)\n", "r0 after lock-down tensor(0.3747, grad_fn=) tensor(1.6533, grad_fn=)\n", "r0 mrs before lock-down tensor(5.1957, grad_fn=) tensor(9.4872, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5753, grad_fn=) tensor(2.2071, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.3161, grad_fn=) tensor(7.6247, grad_fn=)\n", "r0 after lock-down tensor(0.3708, grad_fn=) tensor(1.6190, grad_fn=)\n", "r0 mrs before lock-down tensor(5.2315, grad_fn=) tensor(9.4081, grad_fn=)\n", "r0 mrs after 
lock-down tensor(0.5750, grad_fn=) tensor(2.1691, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.4106, grad_fn=) tensor(7.6751, grad_fn=)\n", "r0 after lock-down tensor(0.3739, grad_fn=) tensor(1.6150, grad_fn=)\n", "r0 mrs before lock-down tensor(5.3078, grad_fn=) tensor(9.4137, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5816, grad_fn=) tensor(2.1633, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.4673, grad_fn=) tensor(7.6733, grad_fn=)\n", "r0 after lock-down tensor(0.3749, grad_fn=) tensor(1.6054, grad_fn=)\n", "r0 mrs before lock-down tensor(5.3244, grad_fn=) tensor(9.3372, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5815, grad_fn=) tensor(2.1351, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.5157, grad_fn=) tensor(7.6411, grad_fn=)\n", "r0 after lock-down tensor(0.3747, grad_fn=) tensor(1.5875, grad_fn=)\n", "r0 mrs before lock-down tensor(5.3960, grad_fn=) tensor(9.3271, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5860, grad_fn=) tensor(2.1179, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.5603, grad_fn=) tensor(7.6109, grad_fn=)\n", "r0 after lock-down tensor(0.3737, grad_fn=) tensor(1.5639, grad_fn=)\n", "r0 mrs before lock-down tensor(5.4188, grad_fn=) tensor(9.2699, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5874, grad_fn=) tensor(2.0955, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.6094, grad_fn=) tensor(7.6020, grad_fn=)\n", "r0 after lock-down tensor(0.3744, grad_fn=) tensor(1.5513, grad_fn=)\n", "r0 mrs before lock-down tensor(5.4940, grad_fn=) tensor(9.2765, grad_fn=)\n", "r0 mrs after lock-down tensor(0.5934, grad_fn=) tensor(2.0831, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.6947, grad_fn=) tensor(7.6427, grad_fn=)\n", "r0 after lock-down tensor(0.3773, grad_fn=) tensor(1.5515, grad_fn=)\n", "r0 mrs before lock-down tensor(5.6054, grad_fn=) tensor(9.3464, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6039, grad_fn=) tensor(2.0893, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.7545, 
grad_fn=) tensor(7.6501, grad_fn=)\n", "r0 after lock-down tensor(0.3786, grad_fn=) tensor(1.5431, grad_fn=)\n", "r0 mrs before lock-down tensor(5.7446, grad_fn=) tensor(9.4408, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6165, grad_fn=) tensor(2.0957, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.8139, grad_fn=) tensor(7.6450, grad_fn=)\n", "r0 after lock-down tensor(0.3787, grad_fn=) tensor(1.5231, grad_fn=)\n", "r0 mrs before lock-down tensor(5.7966, grad_fn=) tensor(9.4186, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6213, grad_fn=) tensor(2.0836, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.8619, grad_fn=) tensor(7.6299, grad_fn=)\n", "r0 after lock-down tensor(0.3787, grad_fn=) tensor(1.5092, grad_fn=)\n", "r0 mrs before lock-down tensor(5.8324, grad_fn=) tensor(9.3880, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6259, grad_fn=) tensor(2.0733, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.9347, grad_fn=) tensor(7.6492, grad_fn=)\n", "r0 after lock-down tensor(0.3811, grad_fn=) tensor(1.5030, grad_fn=)\n", "r0 mrs before lock-down tensor(5.9439, grad_fn=) tensor(9.4724, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6387, grad_fn=) tensor(2.0847, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(4.9506, grad_fn=) tensor(7.5931, grad_fn=)\n", "r0 after lock-down tensor(0.3794, grad_fn=) tensor(1.4841, grad_fn=)\n", "r0 mrs before lock-down tensor(5.9631, grad_fn=) tensor(9.4336, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6432, grad_fn=) tensor(2.0767, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.0323, grad_fn=) tensor(7.6293, grad_fn=)\n", "r0 after lock-down tensor(0.3818, grad_fn=) tensor(1.4764, grad_fn=)\n", "r0 mrs before lock-down tensor(5.9468, grad_fn=) tensor(9.3164, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6418, grad_fn=) tensor(2.0485, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.0249, grad_fn=) tensor(7.5511, grad_fn=)\n", "r0 after lock-down tensor(0.3790, grad_fn=) tensor(1.4553, 
grad_fn=)\n", "r0 mrs before lock-down tensor(6.0562, grad_fn=) tensor(9.3952, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6554, grad_fn=) tensor(2.0601, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.1122, grad_fn=) tensor(7.6020, grad_fn=)\n", "r0 after lock-down tensor(0.3829, grad_fn=) tensor(1.4597, grad_fn=)\n", "r0 mrs before lock-down tensor(6.0487, grad_fn=) tensor(9.3014, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6563, grad_fn=) tensor(2.0379, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.1341, grad_fn=) tensor(7.5669, grad_fn=)\n", "r0 after lock-down tensor(0.3822, grad_fn=) tensor(1.4463, grad_fn=)\n", "r0 mrs before lock-down tensor(6.0460, grad_fn=) tensor(9.2202, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6572, grad_fn=) tensor(2.0138, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.1566, grad_fn=) tensor(7.5227, grad_fn=)\n", "r0 after lock-down tensor(0.3809, grad_fn=) tensor(1.4291, grad_fn=)\n", "r0 mrs before lock-down tensor(6.0945, grad_fn=) tensor(9.2184, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6640, grad_fn=) tensor(2.0058, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.1776, grad_fn=) tensor(7.4874, grad_fn=)\n", "r0 after lock-down tensor(0.3794, grad_fn=) tensor(1.4058, grad_fn=)\n", "r0 mrs before lock-down tensor(6.1965, grad_fn=) tensor(9.2874, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6766, grad_fn=) tensor(2.0096, grad_fn=)\n", "inf\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "r0 before lock-down tensor(5.2427, grad_fn=) tensor(7.5158, grad_fn=)\n", "r0 after lock-down tensor(0.3815, grad_fn=) tensor(1.4027, grad_fn=)\n", "r0 mrs before lock-down tensor(6.2766, grad_fn=) tensor(9.3105, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6874, grad_fn=) tensor(2.0101, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.2932, grad_fn=) tensor(7.5220, grad_fn=)\n", "r0 after lock-down tensor(0.3843, grad_fn=) tensor(1.4006, grad_fn=)\n", "r0 mrs before lock-down 
tensor(6.2882, grad_fn=) tensor(9.2752, grad_fn=)\n", "r0 mrs after lock-down tensor(0.6926, grad_fn=) tensor(2.0000, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.3195, grad_fn=) tensor(7.5058, grad_fn=)\n", "r0 after lock-down tensor(0.3846, grad_fn=) tensor(1.3893, grad_fn=)\n", "r0 mrs before lock-down tensor(6.3529, grad_fn=) tensor(9.2974, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7026, grad_fn=) tensor(1.9981, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.3047, grad_fn=) tensor(7.4313, grad_fn=)\n", "r0 after lock-down tensor(0.3808, grad_fn=) tensor(1.3608, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4064, grad_fn=) tensor(9.2978, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7105, grad_fn=) tensor(1.9899, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.3572, grad_fn=) tensor(7.4454, grad_fn=)\n", "r0 after lock-down tensor(0.3824, grad_fn=) tensor(1.3502, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4181, grad_fn=) tensor(9.2384, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7135, grad_fn=) tensor(1.9692, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.4101, grad_fn=) tensor(7.4614, grad_fn=)\n", "r0 after lock-down tensor(0.3837, grad_fn=) tensor(1.3394, grad_fn=)\n", "r0 mrs before lock-down tensor(6.3734, grad_fn=) tensor(9.1185, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7104, grad_fn=) tensor(1.9348, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.3950, grad_fn=) tensor(7.4002, grad_fn=)\n", "r0 after lock-down tensor(0.3807, grad_fn=) tensor(1.3166, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4458, grad_fn=) tensor(9.1603, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7241, grad_fn=) tensor(1.9479, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.4868, grad_fn=) tensor(7.4670, grad_fn=)\n", "r0 after lock-down tensor(0.3854, grad_fn=) tensor(1.3203, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4375, grad_fn=) tensor(9.0879, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7281, grad_fn=) 
tensor(1.9368, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.5074, grad_fn=) tensor(7.4531, grad_fn=)\n", "r0 after lock-down tensor(0.3858, grad_fn=) tensor(1.3114, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5361, grad_fn=) tensor(9.1546, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7424, grad_fn=) tensor(1.9414, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.5077, grad_fn=) tensor(7.4053, grad_fn=)\n", "r0 after lock-down tensor(0.3843, grad_fn=) tensor(1.2921, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5226, grad_fn=) tensor(9.0680, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7441, grad_fn=) tensor(1.9167, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.5588, grad_fn=) tensor(7.4219, grad_fn=)\n", "r0 after lock-down tensor(0.3870, grad_fn=) tensor(1.2840, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5425, grad_fn=) tensor(9.0474, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7523, grad_fn=) tensor(1.9152, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.5650, grad_fn=) tensor(7.3910, grad_fn=)\n", "r0 after lock-down tensor(0.3869, grad_fn=) tensor(1.2746, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5510, grad_fn=) tensor(9.0098, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7594, grad_fn=) tensor(1.9116, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.6180, grad_fn=) tensor(7.4155, grad_fn=)\n", "r0 after lock-down tensor(0.3890, grad_fn=) tensor(1.2695, grad_fn=)\n", "r0 mrs before lock-down tensor(6.6255, grad_fn=) tensor(9.0530, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7740, grad_fn=) tensor(1.9209, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.6414, grad_fn=) tensor(7.4061, grad_fn=)\n", "r0 after lock-down tensor(0.3908, grad_fn=) tensor(1.2661, grad_fn=)\n", "r0 mrs before lock-down tensor(6.6465, grad_fn=) tensor(9.0227, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7819, grad_fn=) tensor(1.9127, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.6350, grad_fn=) tensor(7.3524, 
grad_fn=)\n", "r0 after lock-down tensor(0.3900, grad_fn=) tensor(1.2531, grad_fn=)\n", "r0 mrs before lock-down tensor(6.7254, grad_fn=) tensor(9.0730, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7957, grad_fn=) tensor(1.9113, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.6801, grad_fn=) tensor(7.3648, grad_fn=)\n", "r0 after lock-down tensor(0.3930, grad_fn=) tensor(1.2501, grad_fn=)\n", "r0 mrs before lock-down tensor(6.6769, grad_fn=) tensor(8.9600, grad_fn=)\n", "r0 mrs after lock-down tensor(0.7970, grad_fn=) tensor(1.8929, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.7099, grad_fn=) tensor(7.3627, grad_fn=)\n", "r0 after lock-down tensor(0.3949, grad_fn=) tensor(1.2428, grad_fn=)\n", "r0 mrs before lock-down tensor(6.7486, grad_fn=) tensor(9.0012, grad_fn=)\n", "r0 mrs after lock-down tensor(0.8132, grad_fn=) tensor(1.9040, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.7582, grad_fn=) tensor(7.3883, grad_fn=)\n", "r0 after lock-down tensor(0.3985, grad_fn=) tensor(1.2435, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8419, grad_fn=) tensor(9.0684, grad_fn=)\n", "r0 mrs after lock-down tensor(0.8310, grad_fn=) tensor(1.9119, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.7637, grad_fn=) tensor(7.3571, grad_fn=)\n", "r0 after lock-down tensor(0.3990, grad_fn=) tensor(1.2345, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8310, grad_fn=) tensor(9.0023, grad_fn=)\n", "r0 mrs after lock-down tensor(0.8375, grad_fn=) tensor(1.9012, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.8168, grad_fn=) tensor(7.3787, grad_fn=)\n", "r0 after lock-down tensor(0.4024, grad_fn=) tensor(1.2327, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8897, grad_fn=) tensor(9.0270, grad_fn=)\n", "r0 mrs after lock-down tensor(0.8544, grad_fn=) tensor(1.9116, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.8489, grad_fn=) tensor(7.3829, grad_fn=)\n", "r0 after lock-down tensor(0.4050, grad_fn=) tensor(1.2289, grad_fn=)\n", "r0 mrs before 
lock-down tensor(6.9148, grad_fn=) tensor(9.0078, grad_fn=)\n", "r0 mrs after lock-down tensor(0.8663, grad_fn=) tensor(1.9103, grad_fn=)\n", "inf\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "r0 before lock-down tensor(5.8927, grad_fn=) tensor(7.3990, grad_fn=)\n", "r0 after lock-down tensor(0.4087, grad_fn=) tensor(1.2287, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9354, grad_fn=) tensor(8.9885, grad_fn=)\n", "r0 mrs after lock-down tensor(0.8789, grad_fn=) tensor(1.9107, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.7923, grad_fn=) tensor(7.2550, grad_fn=)\n", "r0 after lock-down tensor(0.4026, grad_fn=) tensor(1.2040, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8612, grad_fn=) tensor(8.8570, grad_fn=)\n", "r0 mrs after lock-down tensor(0.8786, grad_fn=) tensor(1.8918, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.8673, grad_fn=) tensor(7.3099, grad_fn=)\n", "r0 after lock-down tensor(0.4078, grad_fn=) tensor(1.2067, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9435, grad_fn=) tensor(8.9132, grad_fn=)\n", "r0 mrs after lock-down tensor(0.9006, grad_fn=) tensor(1.9083, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.9365, grad_fn=) tensor(7.3562, grad_fn=)\n", "r0 after lock-down tensor(0.4115, grad_fn=) tensor(1.2021, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9072, grad_fn=) tensor(8.8336, grad_fn=)\n", "r0 mrs after lock-down tensor(0.9061, grad_fn=) tensor(1.8958, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.9049, grad_fn=) tensor(7.2926, grad_fn=)\n", "r0 after lock-down tensor(0.4102, grad_fn=) tensor(1.1902, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9028, grad_fn=) tensor(8.7982, grad_fn=)\n", "r0 mrs after lock-down tensor(0.9173, grad_fn=) tensor(1.8959, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.9641, grad_fn=) tensor(7.3298, grad_fn=)\n", "r0 after lock-down tensor(0.4159, grad_fn=) tensor(1.1943, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9127, grad_fn=) tensor(8.7696, 
grad_fn=)\n", "r0 mrs after lock-down tensor(0.9299, grad_fn=) tensor(1.8960, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.8905, grad_fn=) tensor(7.2157, grad_fn=)\n", "r0 after lock-down tensor(0.4116, grad_fn=) tensor(1.1722, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8899, grad_fn=) tensor(8.7047, grad_fn=)\n", "r0 mrs after lock-down tensor(0.9375, grad_fn=) tensor(1.8876, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.9600, grad_fn=) tensor(7.2680, grad_fn=)\n", "r0 after lock-down tensor(0.4175, grad_fn=) tensor(1.1742, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8991, grad_fn=) tensor(8.6814, grad_fn=)\n", "r0 mrs after lock-down tensor(0.9515, grad_fn=) tensor(1.8912, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.9730, grad_fn=) tensor(7.2627, grad_fn=)\n", "r0 after lock-down tensor(0.4200, grad_fn=) tensor(1.1713, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9329, grad_fn=) tensor(8.6828, grad_fn=)\n", "r0 mrs after lock-down tensor(0.9691, grad_fn=) tensor(1.9006, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.9948, grad_fn=) tensor(7.2586, grad_fn=)\n", "r0 after lock-down tensor(0.4235, grad_fn=) tensor(1.1683, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9822, grad_fn=) tensor(8.7103, grad_fn=)\n", "r0 mrs after lock-down tensor(0.9893, grad_fn=) tensor(1.9109, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(5.9847, grad_fn=) tensor(7.2217, grad_fn=)\n", "r0 after lock-down tensor(0.4240, grad_fn=) tensor(1.1568, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9827, grad_fn=) tensor(8.6817, grad_fn=)\n", "r0 mrs after lock-down tensor(1.0031, grad_fn=) tensor(1.9114, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.0480, grad_fn=) tensor(7.2672, grad_fn=)\n", "r0 after lock-down tensor(0.4302, grad_fn=) tensor(1.1589, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9161, grad_fn=) tensor(8.5761, grad_fn=)\n", "r0 mrs after lock-down tensor(1.0063, grad_fn=) tensor(1.8964, grad_fn=)\n", "inf\n", "r0 
before lock-down tensor(6.0272, grad_fn=) tensor(7.2201, grad_fn=)\n", "r0 after lock-down tensor(0.4310, grad_fn=) tensor(1.1519, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9519, grad_fn=) tensor(8.5898, grad_fn=)\n", "r0 mrs after lock-down tensor(1.0261, grad_fn=) tensor(1.9047, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.0376, grad_fn=) tensor(7.2062, grad_fn=)\n", "r0 after lock-down tensor(0.4347, grad_fn=) tensor(1.1460, grad_fn=)\n", "r0 mrs before lock-down tensor(7.0041, grad_fn=) tensor(8.6208, grad_fn=)\n", "r0 mrs after lock-down tensor(1.0489, grad_fn=) tensor(1.9180, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.0252, grad_fn=) tensor(7.1685, grad_fn=)\n", "r0 after lock-down tensor(0.4352, grad_fn=) tensor(1.1345, grad_fn=)\n", "r0 mrs before lock-down tensor(7.0316, grad_fn=) tensor(8.6228, grad_fn=)\n", "r0 mrs after lock-down tensor(1.0698, grad_fn=) tensor(1.9294, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.0755, grad_fn=) tensor(7.2031, grad_fn=)\n", "r0 after lock-down tensor(0.4421, grad_fn=) tensor(1.1347, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9545, grad_fn=) tensor(8.5080, grad_fn=)\n", "r0 mrs after lock-down tensor(1.0719, grad_fn=) tensor(1.9136, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.1123, grad_fn=) tensor(7.2191, grad_fn=)\n", "r0 after lock-down tensor(0.4478, grad_fn=) tensor(1.1348, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8992, grad_fn=) tensor(8.4136, grad_fn=)\n", "r0 mrs after lock-down tensor(1.0770, grad_fn=) tensor(1.9003, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.1383, grad_fn=) tensor(7.2226, grad_fn=)\n", "r0 after lock-down tensor(0.4522, grad_fn=) tensor(1.1333, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9451, grad_fn=) tensor(8.4367, grad_fn=)\n", "r0 mrs after lock-down tensor(1.1006, grad_fn=) tensor(1.9151, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.1365, grad_fn=) tensor(7.1969, grad_fn=)\n", "r0 after lock-down tensor(0.4545, 
grad_fn=) tensor(1.1266, grad_fn=)\n", "r0 mrs before lock-down tensor(7.0044, grad_fn=) tensor(8.4736, grad_fn=)\n", "r0 mrs after lock-down tensor(1.1275, grad_fn=) tensor(1.9312, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.1672, grad_fn=) tensor(7.2100, grad_fn=)\n", "r0 after lock-down tensor(0.4597, grad_fn=) tensor(1.1260, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9732, grad_fn=) tensor(8.4098, grad_fn=)\n", "r0 mrs after lock-down tensor(1.1403, grad_fn=) tensor(1.9324, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.1419, grad_fn=) tensor(7.1600, grad_fn=)\n", "r0 after lock-down tensor(0.4612, grad_fn=) tensor(1.1170, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9266, grad_fn=) tensor(8.3284, grad_fn=)\n", "r0 mrs after lock-down tensor(1.1478, grad_fn=) tensor(1.9226, grad_fn=)\n", "inf\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "r0 before lock-down tensor(6.1725, grad_fn=) tensor(7.1730, grad_fn=)\n", "r0 after lock-down tensor(0.4663, grad_fn=) tensor(1.1163, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9168, grad_fn=) tensor(8.2979, grad_fn=)\n", "r0 mrs after lock-down tensor(1.1662, grad_fn=) tensor(1.9329, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.1691, grad_fn=) tensor(7.1494, grad_fn=)\n", "r0 after lock-down tensor(0.4692, grad_fn=) tensor(1.1140, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8943, grad_fn=) tensor(8.2480, grad_fn=)\n", "r0 mrs after lock-down tensor(1.1795, grad_fn=) tensor(1.9333, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.1469, grad_fn=) tensor(7.1061, grad_fn=)\n", "r0 after lock-down tensor(0.4711, grad_fn=) tensor(1.1083, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9341, grad_fn=) tensor(8.2733, grad_fn=)\n", "r0 mrs after lock-down tensor(1.2066, grad_fn=) tensor(1.9525, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.1777, grad_fn=) tensor(7.1171, grad_fn=)\n", "r0 after lock-down tensor(0.4760, grad_fn=) tensor(1.1082, grad_fn=)\n", "r0 mrs 
before lock-down tensor(6.9456, grad_fn=) tensor(8.2633, grad_fn=)\n", "r0 mrs after lock-down tensor(1.2297, grad_fn=) tensor(1.9688, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.1727, grad_fn=) tensor(7.0946, grad_fn=)\n", "r0 after lock-down tensor(0.4800, grad_fn=) tensor(1.1040, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9592, grad_fn=) tensor(8.2574, grad_fn=)\n", "r0 mrs after lock-down tensor(1.2519, grad_fn=) tensor(1.9798, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.2303, grad_fn=) tensor(7.1378, grad_fn=)\n", "r0 after lock-down tensor(0.4877, grad_fn=) tensor(1.1069, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9289, grad_fn=) tensor(8.1987, grad_fn=)\n", "r0 mrs after lock-down tensor(1.2669, grad_fn=) tensor(1.9831, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.2487, grad_fn=) tensor(7.1439, grad_fn=)\n", "r0 after lock-down tensor(0.4939, grad_fn=) tensor(1.1063, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8895, grad_fn=) tensor(8.1335, grad_fn=)\n", "r0 mrs after lock-down tensor(1.2771, grad_fn=) tensor(1.9767, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.2538, grad_fn=) tensor(7.1347, grad_fn=)\n", "r0 after lock-down tensor(0.4990, grad_fn=) tensor(1.1071, grad_fn=)\n", "r0 mrs before lock-down tensor(6.9037, grad_fn=) tensor(8.1271, grad_fn=)\n", "r0 mrs after lock-down tensor(1.3002, grad_fn=) tensor(1.9883, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.2083, grad_fn=) tensor(7.0670, grad_fn=)\n", "r0 after lock-down tensor(0.4984, grad_fn=) tensor(1.0949, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8176, grad_fn=) tensor(8.0052, grad_fn=)\n", "r0 mrs after lock-down tensor(1.2997, grad_fn=) tensor(1.9671, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.2145, grad_fn=) tensor(7.0515, grad_fn=)\n", "r0 after lock-down tensor(0.5029, grad_fn=) tensor(1.0890, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8702, grad_fn=) tensor(8.0402, grad_fn=)\n", "r0 mrs after lock-down tensor(1.3301, 
grad_fn=) tensor(1.9905, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.2318, grad_fn=) tensor(7.0532, grad_fn=)\n", "r0 after lock-down tensor(0.5078, grad_fn=) tensor(1.0857, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8760, grad_fn=) tensor(8.0325, grad_fn=)\n", "r0 mrs after lock-down tensor(1.3532, grad_fn=) tensor(2.0062, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.2595, grad_fn=) tensor(7.0653, grad_fn=)\n", "r0 after lock-down tensor(0.5150, grad_fn=) tensor(1.0862, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8679, grad_fn=) tensor(8.0026, grad_fn=)\n", "r0 mrs after lock-down tensor(1.3706, grad_fn=) tensor(2.0102, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3316, grad_fn=) tensor(7.1250, grad_fn=)\n", "r0 after lock-down tensor(0.5257, grad_fn=) tensor(1.0951, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8420, grad_fn=) tensor(7.9536, grad_fn=)\n", "r0 mrs after lock-down tensor(1.3866, grad_fn=) tensor(2.0160, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.2690, grad_fn=) tensor(7.0449, grad_fn=)\n", "r0 after lock-down tensor(0.5262, grad_fn=) tensor(1.0873, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8318, grad_fn=) tensor(7.9241, grad_fn=)\n", "r0 mrs after lock-down tensor(1.4055, grad_fn=) tensor(2.0267, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3078, grad_fn=) tensor(7.0678, grad_fn=)\n", "r0 after lock-down tensor(0.5340, grad_fn=) tensor(1.0883, grad_fn=)\n", "r0 mrs before lock-down tensor(6.8077, grad_fn=) tensor(7.8812, grad_fn=)\n", "r0 mrs after lock-down tensor(1.4221, grad_fn=) tensor(2.0342, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3255, grad_fn=) tensor(7.0733, grad_fn=)\n", "r0 after lock-down tensor(0.5410, grad_fn=) tensor(1.0912, grad_fn=)\n", "r0 mrs before lock-down tensor(6.7652, grad_fn=) tensor(7.8206, grad_fn=)\n", "r0 mrs after lock-down tensor(1.4330, grad_fn=) tensor(2.0334, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.2944, grad_fn=) tensor(7.0272, 
grad_fn=)\n", "r0 after lock-down tensor(0.5424, grad_fn=) tensor(1.0861, grad_fn=)\n", "r0 mrs before lock-down tensor(6.7282, grad_fn=) tensor(7.7608, grad_fn=)\n", "r0 mrs after lock-down tensor(1.4436, grad_fn=) tensor(2.0314, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.2807, grad_fn=) tensor(6.9991, grad_fn=)\n", "r0 after lock-down tensor(0.5455, grad_fn=) tensor(1.0828, grad_fn=)\n", "r0 mrs before lock-down tensor(6.7993, grad_fn=) tensor(7.8207, grad_fn=)\n", "r0 mrs after lock-down tensor(1.4823, grad_fn=) tensor(2.0655, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.2691, grad_fn=) tensor(6.9740, grad_fn=)\n", "r0 after lock-down tensor(0.5481, grad_fn=) tensor(1.0763, grad_fn=)\n", "r0 mrs before lock-down tensor(6.7140, grad_fn=) tensor(7.7111, grad_fn=)\n", "r0 mrs after lock-down tensor(1.4818, grad_fn=) tensor(2.0504, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.2818, grad_fn=) tensor(6.9732, grad_fn=)\n", "r0 after lock-down tensor(0.5533, grad_fn=) tensor(1.0751, grad_fn=)\n", "r0 mrs before lock-down tensor(6.7315, grad_fn=) tensor(7.7108, grad_fn=)\n", "r0 mrs after lock-down tensor(1.5038, grad_fn=) tensor(2.0616, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3599, grad_fn=) tensor(7.0430, grad_fn=)\n", "r0 after lock-down tensor(0.5665, grad_fn=) tensor(1.0843, grad_fn=)\n", "r0 mrs before lock-down tensor(6.7049, grad_fn=) tensor(7.6672, grad_fn=)\n", "r0 mrs after lock-down tensor(1.5169, grad_fn=) tensor(2.0632, grad_fn=)\n", "inf\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "r0 before lock-down tensor(6.3517, grad_fn=) tensor(7.0202, grad_fn=)\n", "r0 after lock-down tensor(0.5718, grad_fn=) tensor(1.0814, grad_fn=)\n", "r0 mrs before lock-down tensor(6.7438, grad_fn=) tensor(7.6954, grad_fn=)\n", "r0 mrs after lock-down tensor(1.5490, grad_fn=) tensor(2.0901, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3565, grad_fn=) tensor(7.0118, grad_fn=)\n", "r0 after lock-down tensor(0.5766, 
grad_fn=) tensor(1.0788, grad_fn=)\n", "r0 mrs before lock-down tensor(6.6702, grad_fn=) tensor(7.6027, grad_fn=)\n", "r0 mrs after lock-down tensor(1.5490, grad_fn=) tensor(2.0768, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3529, grad_fn=) tensor(6.9935, grad_fn=)\n", "r0 after lock-down tensor(0.5821, grad_fn=) tensor(1.0762, grad_fn=)\n", "r0 mrs before lock-down tensor(6.6703, grad_fn=) tensor(7.5873, grad_fn=)\n", "r0 mrs after lock-down tensor(1.5685, grad_fn=) tensor(2.0880, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3939, grad_fn=) tensor(7.0255, grad_fn=)\n", "r0 after lock-down tensor(0.5918, grad_fn=) tensor(1.0810, grad_fn=)\n", "r0 mrs before lock-down tensor(6.6395, grad_fn=) tensor(7.5389, grad_fn=)\n", "r0 mrs after lock-down tensor(1.5796, grad_fn=) tensor(2.0880, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3869, grad_fn=) tensor(7.0062, grad_fn=)\n", "r0 after lock-down tensor(0.5965, grad_fn=) tensor(1.0798, grad_fn=)\n", "r0 mrs before lock-down tensor(6.6551, grad_fn=) tensor(7.5412, grad_fn=)\n", "r0 mrs after lock-down tensor(1.6025, grad_fn=) tensor(2.1033, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3492, grad_fn=) tensor(6.9552, grad_fn=)\n", "r0 after lock-down tensor(0.5981, grad_fn=) tensor(1.0733, grad_fn=)\n", "r0 mrs before lock-down tensor(6.6152, grad_fn=) tensor(7.4815, grad_fn=)\n", "r0 mrs after lock-down tensor(1.6091, grad_fn=) tensor(2.0991, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3543, grad_fn=) tensor(6.9491, grad_fn=)\n", "r0 after lock-down tensor(0.6029, grad_fn=) tensor(1.0727, grad_fn=)\n", "r0 mrs before lock-down tensor(6.6288, grad_fn=) tensor(7.4813, grad_fn=)\n", "r0 mrs after lock-down tensor(1.6311, grad_fn=) tensor(2.1144, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3455, grad_fn=) tensor(6.9288, grad_fn=)\n", "r0 after lock-down tensor(0.6081, grad_fn=) tensor(1.0710, grad_fn=)\n", "r0 mrs before lock-down tensor(6.6016, grad_fn=) tensor(7.4425, 
grad_fn=)\n", "r0 mrs after lock-down tensor(1.6420, grad_fn=) tensor(2.1173, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3818, grad_fn=) tensor(6.9567, grad_fn=)\n", "r0 after lock-down tensor(0.6175, grad_fn=) tensor(1.0764, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5866, grad_fn=) tensor(7.4107, grad_fn=)\n", "r0 mrs after lock-down tensor(1.6541, grad_fn=) tensor(2.1184, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3583, grad_fn=) tensor(6.9195, grad_fn=)\n", "r0 after lock-down tensor(0.6208, grad_fn=) tensor(1.0715, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5834, grad_fn=) tensor(7.3972, grad_fn=)\n", "r0 mrs after lock-down tensor(1.6703, grad_fn=) tensor(2.1262, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3848, grad_fn=) tensor(6.9375, grad_fn=)\n", "r0 after lock-down tensor(0.6292, grad_fn=) tensor(1.0748, grad_fn=)\n", "r0 mrs before lock-down tensor(6.6024, grad_fn=) tensor(7.4007, grad_fn=)\n", "r0 mrs after lock-down tensor(1.6911, grad_fn=) tensor(2.1385, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3819, grad_fn=) tensor(6.9234, grad_fn=)\n", "r0 after lock-down tensor(0.6335, grad_fn=) tensor(1.0724, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5697, grad_fn=) tensor(7.3510, grad_fn=)\n", "r0 mrs after lock-down tensor(1.6967, grad_fn=) tensor(2.1338, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3872, grad_fn=) tensor(6.9200, grad_fn=)\n", "r0 after lock-down tensor(0.6389, grad_fn=) tensor(1.0737, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5565, grad_fn=) tensor(7.3237, grad_fn=)\n", "r0 mrs after lock-down tensor(1.7087, grad_fn=) tensor(2.1373, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3560, grad_fn=) tensor(6.8787, grad_fn=)\n", "r0 after lock-down tensor(0.6402, grad_fn=) tensor(1.0654, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5006, grad_fn=) tensor(7.2549, grad_fn=)\n", "r0 mrs after lock-down tensor(1.7072, grad_fn=) tensor(2.1259, grad_fn=)\n", "inf\n", "r0 
before lock-down tensor(6.3947, grad_fn=) tensor(6.9093, grad_fn=)\n", "r0 after lock-down tensor(0.6493, grad_fn=) tensor(1.0687, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5242, grad_fn=) tensor(7.2700, grad_fn=)\n", "r0 mrs after lock-down tensor(1.7294, grad_fn=) tensor(2.1436, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3748, grad_fn=) tensor(6.8790, grad_fn=)\n", "r0 after lock-down tensor(0.6516, grad_fn=) tensor(1.0626, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5050, grad_fn=) tensor(7.2374, grad_fn=)\n", "r0 mrs after lock-down tensor(1.7365, grad_fn=) tensor(2.1427, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4098, grad_fn=) tensor(6.9058, grad_fn=)\n", "r0 after lock-down tensor(0.6592, grad_fn=) tensor(1.0658, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5119, grad_fn=) tensor(7.2325, grad_fn=)\n", "r0 mrs after lock-down tensor(1.7519, grad_fn=) tensor(2.1503, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3969, grad_fn=) tensor(6.8837, grad_fn=)\n", "r0 after lock-down tensor(0.6642, grad_fn=) tensor(1.0648, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5343, grad_fn=) tensor(7.2436, grad_fn=)\n", "r0 mrs after lock-down tensor(1.7724, grad_fn=) tensor(2.1628, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4073, grad_fn=) tensor(6.8839, grad_fn=)\n", "r0 after lock-down tensor(0.6702, grad_fn=) tensor(1.0621, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5209, grad_fn=) tensor(7.2151, grad_fn=)\n", "r0 mrs after lock-down tensor(1.7822, grad_fn=) tensor(2.1623, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4347, grad_fn=) tensor(6.9025, grad_fn=)\n", "r0 after lock-down tensor(0.6780, grad_fn=) tensor(1.0648, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4769, grad_fn=) tensor(7.1555, grad_fn=)\n", "r0 mrs after lock-down tensor(1.7818, grad_fn=) tensor(2.1533, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3991, grad_fn=) tensor(6.8571, grad_fn=)\n", "r0 after lock-down tensor(0.6792, 
grad_fn=) tensor(1.0612, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5010, grad_fn=) tensor(7.1696, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8004, grad_fn=) tensor(2.1661, grad_fn=)\n", "inf\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "r0 before lock-down tensor(6.3906, grad_fn=) tensor(6.8395, grad_fn=)\n", "r0 after lock-down tensor(0.6835, grad_fn=) tensor(1.0591, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4621, grad_fn=) tensor(7.1196, grad_fn=)\n", "r0 mrs after lock-down tensor(1.7979, grad_fn=) tensor(2.1553, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4135, grad_fn=) tensor(6.8555, grad_fn=)\n", "r0 after lock-down tensor(0.6913, grad_fn=) tensor(1.0600, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4378, grad_fn=) tensor(7.0825, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8020, grad_fn=) tensor(2.1522, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3824, grad_fn=) tensor(6.8155, grad_fn=)\n", "r0 after lock-down tensor(0.6932, grad_fn=) tensor(1.0549, grad_fn=)\n", "r0 mrs before lock-down tensor(6.5092, grad_fn=) tensor(7.1502, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8388, grad_fn=) tensor(2.1865, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4231, grad_fn=) tensor(6.8489, grad_fn=)\n", "r0 after lock-down tensor(0.7002, grad_fn=) tensor(1.0565, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4615, grad_fn=) tensor(7.0870, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8344, grad_fn=) tensor(2.1738, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.3873, grad_fn=) tensor(6.8029, grad_fn=)\n", "r0 after lock-down tensor(0.6993, grad_fn=) tensor(1.0489, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4952, grad_fn=) tensor(7.1110, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8531, grad_fn=) tensor(2.1866, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4494, grad_fn=) tensor(6.8608, grad_fn=)\n", "r0 after lock-down tensor(0.7099, grad_fn=) tensor(1.0559, grad_fn=)\n", "r0 mrs 
before lock-down tensor(6.4994, grad_fn=) tensor(7.1054, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8646, grad_fn=) tensor(2.1925, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4400, grad_fn=) tensor(6.8429, grad_fn=)\n", "r0 after lock-down tensor(0.7138, grad_fn=) tensor(1.0529, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4656, grad_fn=) tensor(7.0570, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8613, grad_fn=) tensor(2.1804, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4325, grad_fn=) tensor(6.8282, grad_fn=)\n", "r0 after lock-down tensor(0.7170, grad_fn=) tensor(1.0496, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4368, grad_fn=) tensor(7.0170, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8620, grad_fn=) tensor(2.1745, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4558, grad_fn=) tensor(6.8454, grad_fn=)\n", "r0 after lock-down tensor(0.7231, grad_fn=) tensor(1.0510, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4335, grad_fn=) tensor(7.0039, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8702, grad_fn=) tensor(2.1770, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4311, grad_fn=) tensor(6.8118, grad_fn=)\n", "r0 after lock-down tensor(0.7263, grad_fn=) tensor(1.0477, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4426, grad_fn=) tensor(7.0055, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8812, grad_fn=) tensor(2.1819, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4648, grad_fn=) tensor(6.8399, grad_fn=)\n", "r0 after lock-down tensor(0.7359, grad_fn=) tensor(1.0537, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4208, grad_fn=) tensor(6.9747, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8814, grad_fn=) tensor(2.1758, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4381, grad_fn=) tensor(6.8063, grad_fn=)\n", "r0 after lock-down tensor(0.7361, grad_fn=) tensor(1.0498, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4393, grad_fn=) tensor(6.9813, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8953, 
grad_fn=) tensor(2.1845, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4519, grad_fn=) tensor(6.8139, grad_fn=)\n", "r0 after lock-down tensor(0.7451, grad_fn=) tensor(1.0553, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4030, grad_fn=) tensor(6.9339, grad_fn=)\n", "r0 mrs after lock-down tensor(1.8908, grad_fn=) tensor(2.1730, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4568, grad_fn=) tensor(6.8124, grad_fn=)\n", "r0 after lock-down tensor(0.7509, grad_fn=) tensor(1.0571, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4250, grad_fn=) tensor(6.9487, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9068, grad_fn=) tensor(2.1848, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4141, grad_fn=) tensor(6.7621, grad_fn=)\n", "r0 after lock-down tensor(0.7492, grad_fn=) tensor(1.0498, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4186, grad_fn=) tensor(6.9331, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9107, grad_fn=) tensor(2.1834, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4535, grad_fn=) tensor(6.7953, grad_fn=)\n", "r0 after lock-down tensor(0.7557, grad_fn=) tensor(1.0510, grad_fn=)\n", "r0 mrs before lock-down tensor(6.3960, grad_fn=) tensor(6.9023, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9099, grad_fn=) tensor(2.1779, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4249, grad_fn=) tensor(6.7587, grad_fn=)\n", "r0 after lock-down tensor(0.7523, grad_fn=) tensor(1.0410, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4037, grad_fn=) tensor(6.9021, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9181, grad_fn=) tensor(2.1810, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4523, grad_fn=) tensor(6.7807, grad_fn=)\n", "r0 after lock-down tensor(0.7590, grad_fn=) tensor(1.0436, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4081, grad_fn=) tensor(6.8998, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9248, grad_fn=) tensor(2.1835, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4853, grad_fn=) tensor(6.8086, 
grad_fn=)\n", "r0 after lock-down tensor(0.7679, grad_fn=) tensor(1.0491, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4069, grad_fn=) tensor(6.8904, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9310, grad_fn=) tensor(2.1846, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4489, grad_fn=) tensor(6.7652, grad_fn=)\n", "r0 after lock-down tensor(0.7660, grad_fn=) tensor(1.0407, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4045, grad_fn=) tensor(6.8810, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9375, grad_fn=) tensor(2.1876, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4660, grad_fn=) tensor(6.7763, grad_fn=)\n", "r0 after lock-down tensor(0.7690, grad_fn=) tensor(1.0391, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4064, grad_fn=) tensor(6.8743, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9418, grad_fn=) tensor(2.1863, grad_fn=)\n", "inf\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "r0 before lock-down tensor(6.4589, grad_fn=) tensor(6.7636, grad_fn=)\n", "r0 after lock-down tensor(0.7691, grad_fn=) tensor(1.0347, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4081, grad_fn=) tensor(6.8686, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9467, grad_fn=) tensor(2.1865, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4566, grad_fn=) tensor(6.7551, grad_fn=)\n", "r0 after lock-down tensor(0.7707, grad_fn=) tensor(1.0323, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4120, grad_fn=) tensor(6.8638, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9523, grad_fn=) tensor(2.1877, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4877, grad_fn=) tensor(6.7820, grad_fn=)\n", "r0 after lock-down tensor(0.7803, grad_fn=) tensor(1.0384, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4279, grad_fn=) tensor(6.8739, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9642, grad_fn=) tensor(2.1964, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4828, grad_fn=) tensor(6.7720, grad_fn=)\n", "r0 after lock-down tensor(0.7828, 
grad_fn=) tensor(1.0370, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4131, grad_fn=) tensor(6.8521, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9645, grad_fn=) tensor(2.1920, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4770, grad_fn=) tensor(6.7608, grad_fn=)\n", "r0 after lock-down tensor(0.7876, grad_fn=) tensor(1.0377, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4103, grad_fn=) tensor(6.8440, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9667, grad_fn=) tensor(2.1896, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4837, grad_fn=) tensor(6.7626, grad_fn=)\n", "r0 after lock-down tensor(0.7921, grad_fn=) tensor(1.0384, grad_fn=)\n", "r0 mrs before lock-down tensor(6.3741, grad_fn=) tensor(6.7984, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9577, grad_fn=) tensor(2.1756, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4723, grad_fn=) tensor(6.7465, grad_fn=)\n", "r0 after lock-down tensor(0.7905, grad_fn=) tensor(1.0313, grad_fn=)\n", "r0 mrs before lock-down tensor(6.3924, grad_fn=) tensor(6.8114, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9672, grad_fn=) tensor(2.1815, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4776, grad_fn=) tensor(6.7466, grad_fn=)\n", "r0 after lock-down tensor(0.7941, grad_fn=) tensor(1.0306, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4158, grad_fn=) tensor(6.8297, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9802, grad_fn=) tensor(2.1921, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4852, grad_fn=) tensor(6.7500, grad_fn=)\n", "r0 after lock-down tensor(0.7985, grad_fn=) tensor(1.0318, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4070, grad_fn=) tensor(6.8133, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9800, grad_fn=) tensor(2.1873, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5095, grad_fn=) tensor(6.7696, grad_fn=)\n", "r0 after lock-down tensor(0.8040, grad_fn=) tensor(1.0334, grad_fn=)\n", "r0 mrs before lock-down tensor(6.3988, grad_fn=) tensor(6.7983, 
grad_fn=)\n", "r0 mrs after lock-down tensor(1.9801, grad_fn=) tensor(2.1831, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4713, grad_fn=) tensor(6.7261, grad_fn=)\n", "r0 after lock-down tensor(0.7993, grad_fn=) tensor(1.0238, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4134, grad_fn=) tensor(6.8056, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9879, grad_fn=) tensor(2.1871, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5044, grad_fn=) tensor(6.7551, grad_fn=)\n", "r0 after lock-down tensor(0.8078, grad_fn=) tensor(1.0295, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4091, grad_fn=) tensor(6.7942, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9882, grad_fn=) tensor(2.1831, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4814, grad_fn=) tensor(6.7272, grad_fn=)\n", "r0 after lock-down tensor(0.8084, grad_fn=) tensor(1.0283, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4060, grad_fn=) tensor(6.7843, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9883, grad_fn=) tensor(2.1792, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4890, grad_fn=) tensor(6.7307, grad_fn=)\n", "r0 after lock-down tensor(0.8120, grad_fn=) tensor(1.0281, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4138, grad_fn=) tensor(6.7851, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9928, grad_fn=) tensor(2.1800, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4621, grad_fn=) tensor(6.6985, grad_fn=)\n", "r0 after lock-down tensor(0.8097, grad_fn=) tensor(1.0216, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4089, grad_fn=) tensor(6.7745, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9931, grad_fn=) tensor(2.1761, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5078, grad_fn=) tensor(6.7407, grad_fn=)\n", "r0 after lock-down tensor(0.8165, grad_fn=) tensor(1.0256, grad_fn=)\n", "r0 mrs before lock-down tensor(6.3797, grad_fn=) tensor(6.7377, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9846, grad_fn=) tensor(2.1632, grad_fn=)\n", "inf\n", "r0 
before lock-down tensor(6.4992, grad_fn=) tensor(6.7286, grad_fn=)\n", "r0 after lock-down tensor(0.8184, grad_fn=) tensor(1.0245, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4168, grad_fn=) tensor(6.7708, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9997, grad_fn=) tensor(2.1761, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4814, grad_fn=) tensor(6.7064, grad_fn=)\n", "r0 after lock-down tensor(0.8184, grad_fn=) tensor(1.0210, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4141, grad_fn=) tensor(6.7611, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9996, grad_fn=) tensor(2.1723, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4904, grad_fn=) tensor(6.7122, grad_fn=)\n", "r0 after lock-down tensor(0.8212, grad_fn=) tensor(1.0210, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4243, grad_fn=) tensor(6.7663, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0044, grad_fn=) tensor(2.1738, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5318, grad_fn=) tensor(6.7499, grad_fn=)\n", "r0 after lock-down tensor(0.8263, grad_fn=) tensor(1.0234, grad_fn=)\n", "r0 mrs before lock-down tensor(6.3948, grad_fn=) tensor(6.7295, grad_fn=)\n", "r0 mrs after lock-down tensor(1.9972, grad_fn=) tensor(2.1630, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5240, grad_fn=) tensor(6.7374, grad_fn=)\n", "r0 after lock-down tensor(0.8277, grad_fn=) tensor(1.0209, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4029, grad_fn=) tensor(6.7319, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0028, grad_fn=) tensor(2.1662, grad_fn=)\n", "inf\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "r0 before lock-down tensor(6.5061, grad_fn=) tensor(6.7142, grad_fn=)\n", "r0 after lock-down tensor(0.8285, grad_fn=) tensor(1.0177, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4349, grad_fn=) tensor(6.7607, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0140, grad_fn=) tensor(2.1752, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5661, grad_fn=) 
tensor(6.7718, grad_fn=)\n", "r0 after lock-down tensor(0.8358, grad_fn=) tensor(1.0228, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4000, grad_fn=) tensor(6.7201, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0045, grad_fn=) tensor(2.1624, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5232, grad_fn=) tensor(6.7246, grad_fn=)\n", "r0 after lock-down tensor(0.8339, grad_fn=) tensor(1.0184, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4189, grad_fn=) tensor(6.7346, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0141, grad_fn=) tensor(2.1695, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.4882, grad_fn=) tensor(6.6855, grad_fn=)\n", "r0 after lock-down tensor(0.8312, grad_fn=) tensor(1.0121, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4272, grad_fn=) tensor(6.7375, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0195, grad_fn=) tensor(2.1729, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5329, grad_fn=) tensor(6.7266, grad_fn=)\n", "r0 after lock-down tensor(0.8390, grad_fn=) tensor(1.0167, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4261, grad_fn=) tensor(6.7319, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0209, grad_fn=) tensor(2.1722, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5190, grad_fn=) tensor(6.7086, grad_fn=)\n", "r0 after lock-down tensor(0.8402, grad_fn=) tensor(1.0146, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4349, grad_fn=) tensor(6.7356, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0256, grad_fn=) tensor(2.1741, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5088, grad_fn=) tensor(6.6947, grad_fn=)\n", "r0 after lock-down tensor(0.8418, grad_fn=) tensor(1.0139, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4140, grad_fn=) tensor(6.7091, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0217, grad_fn=) tensor(2.1674, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5158, grad_fn=) tensor(6.6983, grad_fn=)\n", "r0 after lock-down tensor(0.8432, grad_fn=) tensor(1.0123, grad_fn=)\n", "r0 
mrs before lock-down tensor(6.4098, grad_fn=) tensor(6.6986, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0205, grad_fn=) tensor(2.1629, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5168, grad_fn=) tensor(6.6962, grad_fn=)\n", "r0 after lock-down tensor(0.8434, grad_fn=) tensor(1.0100, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4146, grad_fn=) tensor(6.6983, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0238, grad_fn=) tensor(2.1639, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5213, grad_fn=) tensor(6.6973, grad_fn=)\n", "r0 after lock-down tensor(0.8448, grad_fn=) tensor(1.0077, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4334, grad_fn=) tensor(6.7130, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0314, grad_fn=) tensor(2.1690, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5305, grad_fn=) tensor(6.7036, grad_fn=)\n", "r0 after lock-down tensor(0.8445, grad_fn=) tensor(1.0048, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4136, grad_fn=) tensor(6.6892, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0268, grad_fn=) tensor(2.1620, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5169, grad_fn=) tensor(6.6865, grad_fn=)\n", "r0 after lock-down tensor(0.8441, grad_fn=) tensor(1.0012, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4230, grad_fn=) tensor(6.6941, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0303, grad_fn=) tensor(2.1634, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5137, grad_fn=) tensor(6.6799, grad_fn=)\n", "r0 after lock-down tensor(0.8450, grad_fn=) tensor(0.9989, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4164, grad_fn=) tensor(6.6838, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0275, grad_fn=) tensor(2.1581, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5465, grad_fn=) tensor(6.7103, grad_fn=)\n", "r0 after lock-down tensor(0.8514, grad_fn=) tensor(1.0035, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4263, grad_fn=) tensor(6.6891, grad_fn=)\n", "r0 mrs after lock-down 
tensor(2.0323, grad_fn=) tensor(2.1607, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5279, grad_fn=) tensor(6.6887, grad_fn=)\n", "r0 after lock-down tensor(0.8521, grad_fn=) tensor(1.0020, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4255, grad_fn=) tensor(6.6839, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0340, grad_fn=) tensor(2.1602, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5370, grad_fn=) tensor(6.6943, grad_fn=)\n", "r0 after lock-down tensor(0.8561, grad_fn=) tensor(1.0031, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4120, grad_fn=) tensor(6.6648, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0307, grad_fn=) tensor(2.1542, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5333, grad_fn=) tensor(6.6879, grad_fn=)\n", "r0 after lock-down tensor(0.8560, grad_fn=) tensor(0.9995, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4214, grad_fn=) tensor(6.6693, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0343, grad_fn=) tensor(2.1550, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5453, grad_fn=) tensor(6.6969, grad_fn=)\n", "r0 after lock-down tensor(0.8585, grad_fn=) tensor(1.0000, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4125, grad_fn=) tensor(6.6561, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0316, grad_fn=) tensor(2.1499, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5277, grad_fn=) tensor(6.6766, grad_fn=)\n", "r0 after lock-down tensor(0.8575, grad_fn=) tensor(0.9962, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4448, grad_fn=) tensor(6.6849, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0422, grad_fn=) tensor(2.1588, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5608, grad_fn=) tensor(6.7076, grad_fn=)\n", "r0 after lock-down tensor(0.8609, grad_fn=) tensor(0.9983, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4604, grad_fn=) tensor(6.6969, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0480, grad_fn=) tensor(2.1627, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5200, grad_fn=) 
tensor(6.6637, grad_fn=)\n", "r0 after lock-down tensor(0.8583, grad_fn=) tensor(0.9934, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4436, grad_fn=) tensor(6.6761, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0438, grad_fn=) tensor(2.1563, grad_fn=)\n", "inf\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "r0 before lock-down tensor(6.5266, grad_fn=) tensor(6.6676, grad_fn=)\n", "r0 after lock-down tensor(0.8624, grad_fn=) tensor(0.9951, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4384, grad_fn=) tensor(6.6676, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0441, grad_fn=) tensor(2.1547, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5331, grad_fn=) tensor(6.6712, grad_fn=)\n", "r0 after lock-down tensor(0.8657, grad_fn=) tensor(0.9958, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4421, grad_fn=) tensor(6.6674, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0453, grad_fn=) tensor(2.1539, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5428, grad_fn=) tensor(6.6785, grad_fn=)\n", "r0 after lock-down tensor(0.8674, grad_fn=) tensor(0.9949, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4578, grad_fn=) tensor(6.6793, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0504, grad_fn=) tensor(2.1573, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5399, grad_fn=) tensor(6.6730, grad_fn=)\n", "r0 after lock-down tensor(0.8695, grad_fn=) tensor(0.9947, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4522, grad_fn=) tensor(6.6708, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0478, grad_fn=) tensor(2.1527, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5428, grad_fn=) tensor(6.6735, grad_fn=)\n", "r0 after lock-down tensor(0.8680, grad_fn=) tensor(0.9909, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4602, grad_fn=) tensor(6.6761, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0512, grad_fn=) tensor(2.1545, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5887, grad_fn=) tensor(6.7176, grad_fn=)\n", "r0 after lock-down 
tensor(0.8755, grad_fn=) tensor(0.9969, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4525, grad_fn=) tensor(6.6650, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0499, grad_fn=) tensor(2.1516, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5495, grad_fn=) tensor(6.6756, grad_fn=)\n", "r0 after lock-down tensor(0.8738, grad_fn=) tensor(0.9928, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4600, grad_fn=) tensor(6.6687, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0534, grad_fn=) tensor(2.1538, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5569, grad_fn=) tensor(6.6814, grad_fn=)\n", "r0 after lock-down tensor(0.8753, grad_fn=) tensor(0.9930, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4363, grad_fn=) tensor(6.6415, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0483, grad_fn=) tensor(2.1471, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5408, grad_fn=) tensor(6.6630, grad_fn=)\n", "r0 after lock-down tensor(0.8725, grad_fn=) tensor(0.9877, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4705, grad_fn=) tensor(6.6725, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0582, grad_fn=) tensor(2.1553, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5592, grad_fn=) tensor(6.6801, grad_fn=)\n", "r0 after lock-down tensor(0.8743, grad_fn=) tensor(0.9880, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4510, grad_fn=) tensor(6.6493, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0536, grad_fn=) tensor(2.1493, grad_fn=)\n", "inf\n", "r0 before lock-down tensor(6.5583, grad_fn=) tensor(6.6773, grad_fn=)\n", "r0 after lock-down tensor(0.8742, grad_fn=) tensor(0.9860, grad_fn=)\n", "r0 mrs before lock-down tensor(6.4586, grad_fn=) tensor(6.6533, grad_fn=)\n", "r0 mrs after lock-down tensor(2.0568, grad_fn=) tensor(2.1509, grad_fn=)\n", "inf\n", "CPU times: user 23min 45s, sys: 2.65 s, total: 23min 48s\n", "Wall time: 23min 50s\n" ] } ], "source": [ "%%time\n", "losses, a,b = [], [], []\n", "num_steps = 20000\n", "for t in range(num_steps):\n", " 
#print(svi.step(guess_R0, guess_R0_RMS))\n", " #print('r0 = ',pyro.param(\"a_r0\"), pyro.param(\"b_r0\"))\n", " losses.append(svi.step())\n", " if t % 100 == 0:\n", " print('r0 before lock-down ',pyro.param(\"a_r0_1\"), pyro.param(\"a_r0_1\") + pyro.param(\"b_r0_1\"))\n", " print('r0 after lock-down ', pyro.param(\"a_r0_2\"), pyro.param(\"a_r0_2\") + pyro.param(\"b_r0_2\"))\n", " print('r0 mrs before lock-down ',pyro.param(\"a_r0_mrs_1\"), pyro.param(\"a_r0_mrs_1\") + pyro.param(\"b_r0_mrs_1\"))\n", " print('r0 mrs after lock-down ', pyro.param(\"a_r0_mrs_2\"), pyro.param(\"a_r0_mrs_2\") + pyro.param(\"b_r0_mrs_2\"))\n", " print(losses[-1])" ] }, { "cell_type": "code", "execution_count": 230, "metadata": { "scrolled": false }, "outputs": [], "source": [
"# Draw n_points posterior trajectories from the trained guide and stack them\n",
"# column-wise: every array built here has one row per time step and one\n",
"# column per draw, which is the layout get_percentiles() expects.\n",
"n_points = 100\n",
"\n",
"\n",
"def to_column(values):\n",
"    # Detach a sequence of scalar tensors into a (len(values), 1) numpy column.\n",
"    return np.array([v.detach().numpy() for v in values]).reshape(-1, 1)\n",
"\n",
"\n",
"h_draws, l_draws, m_draws, r0_draws = [], [], [], []\n",
"m_mrs_draws, r0_mrs_draws = [], []\n",
"for _ in range(n_points):\n",
"    if mrs:\n",
"        s_T, e_T, i_T, h_T, l_T, m_T, m_mrs_T, r_T, r0_T, r0_mrs_T = svi.guide(True)\n",
"    else:\n",
"        s_T, e_T, i_T, h_T, l_T, m_T, r_T, r0_T = svi.guide(True)\n",
"    # Occupancy plotted as 'Hospitalized' = hospital beds (h) + ICU beds (l).\n",
"    h_draws.append(to_column([h_t + l_t for h_t, l_t in zip(h_T, l_T)]))\n",
"    l_draws.append(to_column(l_T))\n",
"    m_draws.append(to_column(m_T))\n",
"    r0_draws.append(to_column(r0_T))\n",
"    # The guide only returns the MRS series when mrs is True.\n",
"    if mrs:\n",
"        m_mrs_draws.append(to_column(m_mrs_T))\n",
"        r0_mrs_draws.append(to_column(r0_mrs_T))\n",
"\n",
"# Concatenate once after the loop instead of growing the arrays per draw,\n",
"# which would copy the accumulated data on every iteration.\n",
"h = np.hstack(h_draws)\n",
"l = np.hstack(l_draws)\n",
"m = np.hstack(m_draws)\n",
"r0 = np.hstack(r0_draws)\n",
"if mrs:\n",
"    m_mrs = np.hstack(m_mrs_draws)\n",
"    r0_mrs = np.hstack(r0_mrs_draws)"
] }, { "cell_type": "code", "execution_count": 239, "metadata": { "scrolled": false },
"outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"def get_percentiles(arr, color=None, ax=None, label=None):\n",
"    # Median and 5th-95th percentile band of posterior draws, computed per row.\n",
"    #\n",
"    # arr:   2-D array, one row per time step and one column per draw.\n",
"    # color: when given, draw the median line and the shaded band on ax\n",
"    #        (the current axes if ax is None).\n",
"    # label: legend label for the median line.\n",
"    # Returns (median, 5th percentile, 95th percentile), one value per row.\n",
"    arr_5, arr_50, arr_95 = np.percentile(arr, [5, 50, 95], axis=1)\n",
"    if color is not None:\n",
"        ax = ax if ax is not None else plt.gca()\n",
"        # Only time steps that have an observed date can be placed on the x-axis.\n",
"        dates = data_df['date'][:len(arr)]\n",
"        n = len(dates)\n",
"        ax.plot(dates, arr_50[:n], c=color, label=label)\n",
"        ax.fill_between(dates, arr_5[:n], arr_95[:n], color=color, alpha=0.2)\n",
"    return arr_50, arr_5, arr_95\n",
"\n",
"\n",
"seaborn.set()\n",
"fig, ax = plt.subplots(2, 1, figsize=(20, 10), sharex=True, gridspec_kw={\"height_ratios\": (4, 1)})\n",
"\n",
"# Observations; total deaths are split between hospitals and MRS with frac_dh().\n",
"ax[0].plot(data_df['date'], data_df['n_hospitalized'], 'x', color='b', label='Hospitalized')\n",
"ax[0].plot(data_df['date'], data_df['n_icu'], 'x', color='g', label='ICUs')\n",
"ax[0].plot(data_df['date'], data_df['n_deaths'] * frac_dh().item(), 'x', color='r', label='Deaths in hospital')\n",
"ax[0].plot(data_df['date'], data_df['n_deaths'] * (1 - frac_dh().item()), 'x', color='orange', label='Deaths in MRS')\n",
"\n",
"# Model trajectories: posterior median and 5th-95th percentile band.\n",
"get_percentiles(h, 'b', ax[0])\n",
"get_percentiles(l, 'g', ax[0])\n",
"get_percentiles(m, 'r', ax[0])\n",
"if mrs:\n",
"    # m_mrs / r0_mrs are only produced when mrs is True.\n",
"    get_percentiles(m_mrs, 'orange', ax[0])\n",
"    get_percentiles(r0_mrs, 'yellow', ax[1], \"R0 in mrs\")\n",
"get_percentiles(r0, 'brown', ax[1], 'R0')\n",
"\n",
"# axvline's ymin/ymax are axes fractions, so the defaults (0, 1) span the full height.\n",
"for axis in ax:\n",
"    axis.axvline(data_df['date'][date_r0_switch], label='Lockdown', c='black', alpha=.8, linestyle='--')\n",
"    axis.axvline(data_df['date'][date_r0_switch_mrs], label='MRS forbidden', c='black', alpha=.8, linestyle=':')\n",
"    axis.legend()\n",
"ax[0].set_ylabel('Count')\n",
"ax[1].set_ylabel('R0')\n",
"\n",
"plt.savefig('pyro_SEIR.png')\n",
"plt.show()"
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python",
"name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 4 }