{ "cells": [ { "cell_type": "markdown", "id": "several-belly", "metadata": {}, "source": [ "# Stochastic Variational Inference\n", "\n", "This notebook implements Stochastic Variational Inference (SVI) in [PyTorch](https://pytorch.org/) for uncertainty quantification.\n", "\n", "We'll consider the [OLS Regression Challenge](https://data.world/nrippner/ols-regression-challenge), which aims to predict cancer mortality rates for US counties.\n", "\n", "## Data Dictionary\n", "\n", "* **TARGET_deathRate**: Dependent variable. Mean per capita (100,000) cancer mortalities (a)\n", "* **avgAnnCount**: Mean number of reported cases of cancer diagnosed annually (a)\n", "* **avgDeathsPerYear**: Mean number of reported mortalities due to cancer (a)\n", "* **incidenceRate**: Mean per capita (100,000) cancer diagnoses (a)\n", "* **medIncome**: Median income per county (b)\n", "* **popEst2015**: Population of county (b)\n", "* **povertyPercent**: Percent of populace in poverty (b)\n", "* **studyPerCap**: Per capita number of cancer-related clinical trials per county (a)\n", "* **binnedInc**: Median income per capita binned by decile (b)\n", "* **MedianAge**: Median age of county residents (b)\n", "* **MedianAgeMale**: Median age of male county residents (b)\n", "* **MedianAgeFemale**: Median age of female county residents (b)\n", "* **Geography**: County name (b)\n", "* **AvgHouseholdSize**: Mean household size of county (b)\n", "* **PercentMarried**: Percent of county residents who are married (b)\n", "* **PctNoHS18_24**: Percent of county residents ages 18-24 highest education attained: less than high school (b)\n", "* **PctHS18_24**: Percent of county residents ages 18-24 highest education attained: high school diploma (b)\n", "* **PctSomeCol18_24**: Percent of county residents ages 18-24 highest education attained: some college (b)\n", "* **PctBachDeg18_24**: Percent of county residents ages 18-24 highest education attained: bachelor's degree (b)\n", "* 
**PctHS25_Over**: Percent of county residents ages 25 and over highest education attained: high school diploma (b)\n", "* **PctBachDeg25_Over**: Percent of county residents ages 25 and over highest education attained: bachelor's degree (b)\n", "* **PctEmployed16_Over**: Percent of county residents ages 16 and over employed (b)\n", "* **PctUnemployed16_Over**: Percent of county residents ages 16 and over unemployed (b)\n", "* **PctPrivateCoverage**: Percent of county residents with private health coverage (b)\n", "* **PctPrivateCoverageAlone**: Percent of county residents with private health coverage alone (no public assistance) (b)\n", "* **PctEmpPrivCoverage**: Percent of county residents with employee-provided private health coverage (b)\n", "* **PctPublicCoverage**: Percent of county residents with government-provided health coverage (b)\n", "* **PctPublicCoverageAlone**: Percent of county residents with government-provided health coverage alone (b)\n", "* **PctWhite**: Percent of county residents who identify as White (b)\n", "* **PctBlack**: Percent of county residents who identify as Black (b)\n", "* **PctAsian**: Percent of county residents who identify as Asian (b)\n", "* **PctOtherRace**: Percent of county residents who identify in a category which is not White, Black, or Asian (b)\n", "* **PctMarriedHouseholds**: Percent of married households (b)\n", "* **BirthRate**: Number of live births relative to number of women in county (b)\n", "\n", "Notes:\n", "* (a): years 2010-2016\n", "* (b): 2013 Census Estimates" ] }, { "cell_type": "code", "execution_count": 1, "id": "plastic-monte", "metadata": {}, "outputs": [], "source": [ "import os\n", "from os.path import join\n", "\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "from sklearn.metrics import r2_score\n", "from sklearn.model_selection import 
train_test_split\n", "from sklearn.linear_model import LinearRegression\n", "from scipy.stats import binned_statistic\n", "\n", "cwd = os.getcwd()\n", "if cwd.endswith('notebook'):\n", " os.chdir('..')\n", " cwd = os.getcwd()" ] }, { "cell_type": "code", "execution_count": 2, "id": "silent-chair", "metadata": {}, "outputs": [], "source": [ "sns.set(palette='colorblind', font_scale=1.3)\n", "palette = sns.color_palette()" ] }, { "cell_type": "code", "execution_count": 3, "id": "enclosed-strain", "metadata": {}, "outputs": [], "source": [ "seed = 444\n", "np.random.seed(seed);\n", "torch.manual_seed(seed);\n", "torch.set_default_dtype(torch.float64)" ] }, { "cell_type": "markdown", "id": "attempted-general", "metadata": {}, "source": [ "## Dataset\n", "\n", "### Load" ] }, { "cell_type": "code", "execution_count": 4, "id": "eleven-sound", "metadata": {}, "outputs": [ { "data": { "text/html": [ "

| | avgAnnCount | avgDeathsPerYear | TARGET_deathRate | incidenceRate | medIncome | popEst2015 | povertyPercent | studyPerCap | binnedInc | MedianAge | ... | PctPrivateCoverageAlone | PctEmpPrivCoverage | PctPublicCoverage | PctPublicCoverageAlone | PctWhite | PctBlack | PctAsian | PctOtherRace | PctMarriedHouseholds | BirthRate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1397.0 | 469 | 164.9 | 489.8 | 61898 | 260131 | 11.2 | 499.748204 | (61494.5, 125635] | 39.3 | ... | NaN | 41.6 | 32.9 | 14.0 | 81.780529 | 2.594728 | 4.821857 | 1.843479 | 52.856076 | 6.118831 |
| 1 | 173.0 | 70 | 161.3 | 411.6 | 48127 | 43269 | 18.6 | 23.111234 | (48021.6, 51046.4] | 33.0 | ... | 53.8 | 43.6 | 31.1 | 15.3 | 89.228509 | 0.969102 | 2.246233 | 3.741352 | 45.372500 | 4.333096 |
| 2 | 102.0 | 50 | 174.7 | 349.7 | 49348 | 21026 | 14.6 | 47.560164 | (48021.6, 51046.4] | 45.0 | ... | 43.5 | 34.9 | 42.1 | 21.1 | 90.922190 | 0.739673 | 0.465898 | 2.747358 | 54.444868 | 3.729488 |
| 3 | 427.0 | 202 | 194.8 | 430.4 | 44243 | 75882 | 17.1 | 342.637253 | (42724.4, 45201] | 42.8 | ... | 40.3 | 35.0 | 45.3 | 25.0 | 91.744686 | 0.782626 | 1.161359 | 1.362643 | 51.021514 | 4.603841 |
| 4 | 57.0 | 26 | 144.4 | 350.1 | 49955 | 10321 | 12.5 | 0.000000 | (48021.6, 51046.4] | 48.3 | ... | 43.9 | 35.1 | 44.0 | 22.7 | 94.104024 | 0.270192 | 0.665830 | 0.492135 | 54.027460 | 6.796657 |

5 rows × 34 columns

| | avgAnnCount | avgDeathsPerYear | TARGET_deathRate | incidenceRate | medIncome | popEst2015 | povertyPercent | studyPerCap | MedianAge | MedianAgeMale | ... | PctPrivateCoverageAlone | PctEmpPrivCoverage | PctPublicCoverage | PctPublicCoverageAlone | PctWhite | PctBlack | PctAsian | PctOtherRace | PctMarriedHouseholds | BirthRate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3.047000e+03 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | ... | 2438.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 | 3047.000000 |
| mean | 606.338544 | 185.965868 | 178.664063 | 448.268586 | 47063.281917 | 1.026374e+05 | 16.878175 | 155.399415 | 45.272333 | 39.570725 | ... | 48.453774 | 41.196324 | 36.252642 | 19.240072 | 83.645286 | 9.107978 | 1.253965 | 1.983523 | 51.243872 | 5.640306 |
| std | 1416.356223 | 504.134286 | 27.751511 | 54.560733 | 12040.090836 | 3.290592e+05 | 6.409087 | 529.628366 | 45.304480 | 5.226017 | ... | 10.083006 | 9.447687 | 7.841741 | 6.113041 | 16.380025 | 14.534538 | 2.610276 | 3.517710 | 6.572814 | 1.985816 |
| min | 6.000000 | 3.000000 | 59.700000 | 201.300000 | 22640.000000 | 8.270000e+02 | 3.200000 | 0.000000 | 22.300000 | 22.400000 | ... | 15.700000 | 13.500000 | 11.200000 | 2.600000 | 10.199155 | 0.000000 | 0.000000 | 0.000000 | 22.992490 | 0.000000 |
| 25% | 76.000000 | 28.000000 | 161.200000 | 420.300000 | 38882.500000 | 1.168400e+04 | 12.150000 | 0.000000 | 37.700000 | 36.350000 | ... | 41.000000 | 34.500000 | 30.900000 | 14.850000 | 77.296180 | 0.620675 | 0.254199 | 0.295172 | 47.763063 | 4.521419 |
| 50% | 171.000000 | 61.000000 | 178.100000 | 453.549422 | 45207.000000 | 2.664300e+04 | 15.900000 | 0.000000 | 41.000000 | 39.600000 | ... | 48.700000 | 41.100000 | 36.300000 | 18.800000 | 90.059774 | 2.247576 | 0.549812 | 0.826185 | 51.669941 | 5.381478 |
| 75% | 518.000000 | 149.000000 | 195.200000 | 480.850000 | 52492.000000 | 6.867100e+04 | 20.400000 | 83.650776 | 44.000000 | 42.500000 | ... | 55.600000 | 47.700000 | 41.550000 | 23.100000 | 95.451693 | 10.509732 | 1.221037 | 2.177960 | 55.395132 | 6.493677 |
| max | 38150.000000 | 14010.000000 | 362.800000 | 1206.900000 | 125635.000000 | 1.017029e+07 | 47.400000 | 9762.308998 | 624.000000 | 64.700000 | ... | 78.900000 | 70.700000 | 65.100000 | 46.600000 | 100.000000 | 85.947799 | 42.619425 | 41.930251 | 78.075397 | 21.326165 |

8 rows × 32 columns
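
The summary statistics point to two data-quality issues: `MedianAge` has a maximum of 624, which is not a plausible age, and `PctPrivateCoverageAlone` has only 2438 non-null values out of 3047 rows. The sketch below shows one possible way to handle both before modelling. It is not necessarily what this notebook does: the `clean_features` helper, the 120-year cutoff, and the choice of median imputation are all assumptions made for illustration, and the toy rows merely stand in for the real data.

```python
import numpy as np
import pandas as pd

# Toy rows standing in for the real dataset (illustrative values only).
toy = pd.DataFrame({
    "MedianAge": [39.3, 33.0, 624.0, 42.8],
    "PctPrivateCoverageAlone": [np.nan, 53.8, 43.5, 40.3],
})


def clean_features(df, max_age=120.0):
    """Hypothetical helper: mask implausible ages, then impute with medians.

    `max_age` is an assumed plausibility cutoff, not a value from the source.
    """
    out = df.copy()
    # Treat implausibly large median ages as missing values.
    out.loc[out["MedianAge"] > max_age, "MedianAge"] = np.nan
    # Fill every missing numeric entry with its column median.
    return out.fillna(out.median(numeric_only=True))


cleaned = clean_features(toy)
print(cleaned)
```

When a train/test split is used, the medians would normally be computed on the training split only and reused on the test split, so that no test information leaks into the features.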