{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"## Running a simulator using existing data\n",
"Consider the case where input data already exists and already has a causal structure.\n",
"We would like to simulate treatment assignment and outcomes on top of this data.\n",
"\n",
"### Initialize the data\n",
"First we load the desired data into a pandas DataFrame:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from causallib.datasets import load_nhefs\n",
"from causallib.datasets import CausalSimulator\n",
"from causallib.datasets import generate_random_topology"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"data = load_nhefs()\n",
"X_given = data.X"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Say we want to create three more variables: a covariate, a treatment and an outcome.\n",
"Hardwiring a graph with this many variables by hand would be tedious, so let's use the random topology generator:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"topology, var_types = generate_random_topology(n_covariates=1, p=0.4,\n",
" n_treatments=1, n_outcomes=1,\n",
" given_vars=X_given.columns)"
]
},
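{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building the simulator it can help to look at what the generator returned.\n",
"The following is a minimal sketch; it assumes, as the later cells imply, that `topology` is a pandas DataFrame (the adjacency matrix of the graph) and `var_types` is a pandas Series labelling each variable:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Assumption: topology is a square DataFrame and var_types a Series\n",
"# (the simulator cell uses `.values` and `.index` on them)\n",
"print(topology.shape)\n",
"# Count how many variables of each type the graph holds\n",
"print(var_types.value_counts())"
]
},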
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we create the simulator based on the variable topology:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"outcome_types = \"categorical\"\n",
"# Use a linear link function for every variable in the graph\n",
"link_types = ['linear'] * len(var_types)\n",
"# Treatment and outcome get two equally likely categories; all other variables get None\n",
"prob_categories = pd.Series(data=[[0.5, 0.5] if typ in [\"treatment\", \"outcome\"] else None for typ in var_types],\n",
" index=var_types.index)\n",
"treatment_methods = \"gaussian\"\n",
"snr = 0.9  # signal-to-noise ratio\n",
"treatment_importance = 0.8\n",
"effect_sizes = None\n",
"sim = CausalSimulator(topology=topology.values, prob_categories=prob_categories,\n",
" link_types=link_types, snr=snr, var_types=var_types,\n",
" treatment_importances=treatment_importance,\n",
" outcome_types=outcome_types,\n",
" treatment_methods=treatment_methods,\n",
" effect_sizes=effect_sizes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To generate data that builds on the given data, we pass it to the simulator through the `X_given` argument:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"X, prop, y = sim.generate_data(X_given=X_given)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Format the data for training and save it\n",
"\n",
"Now that we have generated some data, we can format it so that it is easier to train and validate on:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"observed_set, validation_set = sim.format_for_training(X, prop, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`observed_set` is the observed dataset (excluding hidden variables).\n",
"`validation_set` is for validation purposes: it has the counterfactuals, the treatment assignment and the propensity for every sample.\n",
"Below we inspect both datasets and then save them to CSV:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1566, 19)\n"
]
},
{
"data": {
"text/plain": [
" x_18 x_active_1 x_active_2 x_age x_age^2 x_education_2 \\\n",
"0 153.760252 0 0 42 1764 0 \n",
"1 94.762203 0 0 36 1296 1 \n",
"2 669.486191 0 0 56 3136 1 \n",
"3 -865.113582 1 0 68 4624 0 \n",
"4 634.638630 1 0 40 1600 1 \n",
"\n",
" x_education_3 x_education_4 x_education_5 x_exercise_1 x_exercise_2 \\\n",
"0 0 0 0 0 1 \n",
"1 0 0 0 0 0 \n",
"2 0 0 0 0 1 \n",
"3 0 0 0 0 1 \n",
"4 0 0 0 1 0 \n",
"\n",
" x_race x_sex x_smokeintensity x_smokeintensity^2 x_smokeyrs \\\n",
"0 1 0 30 900 29 \n",
"1 0 0 20 400 24 \n",
"2 1 1 20 400 26 \n",
"3 1 0 3 9 53 \n",
"4 0 0 20 400 19 \n",
"\n",
" x_smokeyrs^2 x_wt71 x_wt71^2 \n",
"0 841 79.04 6247.3216 \n",
"1 576 58.63 3437.4769 \n",
"2 676 56.81 3227.3761 \n",
"3 2809 59.42 3530.7364 \n",
"4 361 87.09 7584.6681 "
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"covariates = observed_set.loc[:, observed_set.columns.str.startswith(\"x_\")]\n",
"print(covariates.shape)\n",
"covariates.head()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1566, 2)\n"
]
},
{
"data": {
"text/plain": [
" t_19 y_20\n",
"0 0 0\n",
"1 0 1\n",
"2 0 1\n",
"3 1 1\n",
"4 1 0"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"treatment_outcome = observed_set.loc[:, (observed_set.columns.str.startswith(\"t_\") |\n",
" observed_set.columns.str.startswith(\"y_\"))]\n",
"print(treatment_outcome.shape)\n",
"treatment_outcome.head()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1566, 5)\n"
]
},
{
"data": {
"text/plain": [
" t_19 p_19_0 p_19_1 cf_20_0 cf_20_1\n",
"0 0 1.0 0.0 0 0\n",
"1 0 1.0 0.0 1 1\n",
"2 0 1.0 0.0 1 1\n",
"3 1 1.0 0.0 1 1\n",
"4 1 1.0 0.0 0 0"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(validation_set.shape)\n",
"validation_set.head()"
]
},
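{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `validation_set` holds both potential outcomes for every sample, the ground-truth average treatment effect can be read off directly.\n",
"A minimal sketch, assuming a single binary treatment so that the two `cf_` columns are the outcomes under treatment 0 and treatment 1, in that order (as in the output above; the numeric suffixes depend on the randomly generated topology):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Select the counterfactual columns by their `cf_` prefix rather than by exact name\n",
"cf = validation_set.loc[:, validation_set.columns.str.startswith(\"cf_\")]\n",
"# Assumption: column 0 is the outcome under treatment 0, column 1 under treatment 1\n",
"true_ate = cf.iloc[:, 1].mean() - cf.iloc[:, 0].mean()\n",
"print(true_ate)"
]
},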
{
"cell_type": "raw",
"metadata": {},
"source": [
"sim.to_csv(observed_set, 'training_set.csv')\n",
"sim.to_csv(validation_set, 'validation_set.csv')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 1
}