{ "cells": [ { "cell_type": "markdown", "source": [ "# Quickstart\n", "\n", "The partial and full confound tests characterize whether a predictive model\n", "(e.g. a machine learning model) is biased by a confounder variable.\n", "\n", "In research using predictive modelling techniques, confounder bias is often investigated\n", "(if investigated at all) by testing the association between the confounder variable and\n", "the predicted values.\n", "However, the significance of this association does not necessarily imply a significant\n", "confounder bias of the model, especially if the confounder is also associated with the\n", "true target variable. In that case, the model might still not be directly\n", "driven by the confounder, i.e. the dependence of the predictions on the confounder\n", "can be explained solely by the confounder-target association.\n", "Put simply, this is what the proposed partial confound test examines.\n", "\n", "Here we will apply the `partial confound test` to two simulated datasets:\n", "- H0: null-hypothesis dataset with no confounder bias,\n", " i.e. conditional independence between the predicted values\n", " and the confounder variable, given the observed target variable.\n", " Note that the (unconditional) association between the predictions and the confounder is significant\n", " but, according to the H0 of the partial confound test, can be fully explained by the association\n", " between the target and the confounder.\n", "- H1: alternative hypothesis with an explicit confounder bias. 
Here, the association between\n", " the predictions and the confounder is stronger than what could follow from the association\n", " between the target and the confounder.\n", "\n", "\n", "##### Import the necessary packages" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "from mlconfound.stats import partial_confound_test\n", "from mlconfound.simulate import simulate_y_c_yhat\n", "from mlconfound.plot import plot_null_dist, plot_graph\n", "\n", "import pandas as pd\n", "import seaborn as sns\n", "sns.set_style(\"whitegrid\")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "##### H0 simulations\n", "\n", "Next, we simulate some data from the null hypothesis.\n", "\n", "The simulation samples `y` and `c` (`n` datapoints each) from a multivariate normal distribution,\n", "so that their correlation is `cov_y_c`.\n", "Next, a 'prediction' of `y` is simulated like this:\n", "```\n", "yhat = y_ratio_yhat * y + c_ratio_yhat * c + e\n", "```\n", "where `e` is random standard-Gaussian noise.\n", "\n", "For the H0 simulation, `c_ratio_yhat` is set to zero, so the confounder has no direct effect on the predictions." 
], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "H0_y, H0_c, H0_yhat = simulate_y_c_yhat(cov_y_c=0.3,\n", " y_ratio_yhat=0.5, c_ratio_yhat=0,\n", " n=500, random_state=42)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "##### Partial confound test\n", "\n", "Now let's perform the partial confound test on these data.\n", "The number of permutations and the number of steps in the Markov-chain Monte Carlo process are left at their default values\n", "(1000 and 50, respectively).\n", "Increase the number of permutations for more accurate p-value estimates.\n", "\n", "The random seed is set for reproducible results.\n", "The flag `return_null_dist` is set so that the full permutation-based null distribution is returned,\n", "e.g. for plotting purposes.\n", "\n", "The pandas DataFrame is created solely for \"pretty-printing\" the results." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 3, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Permuting: 100%|██████████| 1000/1000 [00:02<00:00, 457.42it/s]\n" ] }, { "data": { "text/plain": "       p  ci lower  ci upper   R2(y,c)  R2(y^,c)  R2(y,y^)\n0  0.798  0.771754  0.822478  0.094132  0.031789  0.478693", "text/html": "
<div>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>p</th>\n      <th>ci lower</th>\n      <th>ci upper</th>\n      <th>R2(y,c)</th>\n      <th>R2(y^,c)</th>\n      <th>R2(y,y^)</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>0.798</td>\n      <td>0.771754</td>\n      <td>0.822478</td>\n      <td>0.094132</td>\n      <td>0.031789</td>\n      <td>0.478693</td>\n    </tr>\n  </tbody>\n</table>\n</div>
" }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ret = partial_confound_test(H0_y, H0_yhat, H0_c, return_null_dist=True,\n", " random_state=42)\n", "# pretty-print the results\n", "pd.DataFrame({\n", " 'p' : [ret.p],\n", " 'ci lower' : [ret.p_ci[0]],\n", " 'ci upper' : [ret.p_ci[1]],\n", " 'R2(y,c)' : [ret.r2_y_c],\n", " 'R2(y^,c)' : [ret.r2_yhat_c],\n", " 'R2(y,y^)' : [ret.r2_y_yhat],\n", "})\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "Let's use the built-in plotting functions of the `mlconfound` package for a graphical representation of the results." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "data": { "text/plain": "" }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": "
", "image/png": "\n" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_null_dist(ret)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "The histogram shows the $R^2$ values between the predictions and the permuted confounder variable\n", "(conditional permutations). The red line marks the unpermuted $R^2$, which is not \"extreme\" relative to this null distribution,\n", "i.e. we have no evidence against the null ($p=0.8$)." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "data": { "text/plain": "", "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\nc\n\nc\n\n\n\ny\n\ny\n\n\n\nc--y\n\n0.094\n\n\n\nyhat\n\n\n\n\n\nc--yhat\n\n0.032 (p=0.8)\n\n\n\ny--yhat\n\n0.479\n\n\n\n" }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "plot_graph(ret)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "The graph shows the unconditional $R^2$ values between the target $y$, the confounder $c$ and the predictions $\\hat{y}$." ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "##### H1 simulations and test\n", "\n", "Now let's apply the partial confound test to data simulated under H1, that is, to a confounded model." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Permuting: 100%|██████████| 1000/1000 [00:01<00:00, 759.92it/s]\n" ] }, { "data": { "text/plain": "       p  ci lower  ci upper   R2(y,c)  R2(y^,c)  R2(y,y^)\n0  0.015  0.008419   0.02462  0.094132  0.079504  0.390694", "text/html": "
<div>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>p</th>\n      <th>ci lower</th>\n      <th>ci upper</th>\n      <th>R2(y,c)</th>\n      <th>R2(y^,c)</th>\n      <th>R2(y,y^)</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>0.015</td>\n      <td>0.008419</td>\n      <td>0.02462</td>\n      <td>0.094132</td>\n      <td>0.079504</td>\n      <td>0.390694</td>\n    </tr>\n  </tbody>\n</table>\n</div>
" }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "H1_y, H1_c, H1_yhat = simulate_y_c_yhat(cov_y_c=0.3,\n", " y_ratio_yhat=0.4, c_ratio_yhat=0.1,\n", " n=500, random_state=42)\n", "ret = partial_confound_test(H1_y, H1_yhat, H1_c, num_perms=1000, return_null_dist=True,\n", " random_state=42, n_jobs=-1)\n", "\n", "# pretty-print the results\n", "pd.DataFrame({\n", " 'p' : [ret.p],\n", " 'ci lower' : [ret.p_ci[0]],\n", " 'ci upper' : [ret.p_ci[1]],\n", " 'R2(y,c)' : [ret.r2_y_c],\n", " 'R2(y^,c)' : [ret.r2_yhat_c],\n", " 'R2(y,y^)' : [ret.r2_y_yhat],\n", "})" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 7, "outputs": [ { "data": { "text/plain": "", "image/svg+xml": "\n\n\n\n\n\n%3\n\n\n\nc\n\nage\n\n\n\ny\n\nIQ\n\n\n\nc--y\n\n0.094\n\n\n\nyhat\n\nprediction\n\n\n\nc--yhat\n\n0.08 (p=0.02*)\n\n\n\ny--yhat\n\n0.391\n\n\n\n" }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": "
", "image/png": "\n" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_null_dist(ret)\n", "# The labels on the graph plot can be customized:\n", "plot_graph(ret, y_name='IQ', yhat_name='prediction', c_name='age', outfile_base='example')" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "The low p-value provides evidence against the null hypothesis of\n", "$\\hat{y}$ being conditionally independent of $c$ given $y$, and indicates that the model predictions are biased by the confounder.\n", "\n", "-----------------------------------------------------------------------------\n", "Note\n", "\n", "For parametric corrections for multiple comparisons\n", "(e.g. false discovery rate correction when testing many confounders),\n", "permutation-based p-values must be adjusted if they are zero.\n", "In this case, a reasonable option is to use the upper binomial confidence limit (`p_ci[1]`) instead.\n", "\n", "-----------------------------------------------------------------------------\n", "\n", "### References\n", "*Tamas Spisak, A conditional permutation-based approach to test confounder effect and center-bias in\n", "machine learning models, in prep, 2021.*" ], "metadata": { "collapsed": false } } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }