{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# uncomment and install dependencies before continuing\n", "# !pip install --upgrade inFairness requests tqdm" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/mayank/opt/anaconda3/envs/infairness/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "from torch.utils.data import Dataset\n", "from torch.utils.data import DataLoader\n", "from tqdm.auto import tqdm\n", "\n", "from inFairness.fairalgo import SenSeI\n", "from inFairness import distances\n", "from inFairness.auditor import SenSRAuditor, SenSeIAuditor\n", "\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import data\n", "import metrics" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class AdultDataset(Dataset):\n", " def __init__(self, data, labels):\n", " self.data = data\n", " self.labels = labels\n", "\n", " def __getitem__(self, idx):\n", " data = self.data[idx]\n", " label = self.labels[idx]\n", " return data, label\n", " \n", " def __len__(self):\n", " return len(self.labels)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agecapital-gaincapital-losseducation-numhours-per-weekmarital-status_Divorcedmarital-status_Married-AF-spousemarital-status_Married-civ-spousemarital-status_Married-spouse-absentmarital-status_Never-married...relationship_Own-childrelationship_Unmarriedrelationship_Wifeworkclass_Federal-govworkclass_Local-govworkclass_Privateworkclass_Self-emp-incworkclass_Self-emp-not-incworkclass_State-govworkclass_Without-pay
00.409331-0.14652-0.218253-1.613806-0.49677000001...0100010000
1-1.104187-0.14652-0.218253-0.050064-1.74176400001...1000010000
21.393118-0.14652-0.218253-0.4409992.57421400100...0000100000
3-0.423104-0.14652-0.218253-0.4409991.16322100100...0000010000
4-0.877159-0.14652-0.2182531.1227430.74822400100...0000000100
\n", "

5 rows × 39 columns

\n", "
" ], "text/plain": [ " age capital-gain capital-loss education-num hours-per-week \\\n", "0 0.409331 -0.14652 -0.218253 -1.613806 -0.496770 \n", "1 -1.104187 -0.14652 -0.218253 -0.050064 -1.741764 \n", "2 1.393118 -0.14652 -0.218253 -0.440999 2.574214 \n", "3 -0.423104 -0.14652 -0.218253 -0.440999 1.163221 \n", "4 -0.877159 -0.14652 -0.218253 1.122743 0.748224 \n", "\n", " marital-status_Divorced marital-status_Married-AF-spouse \\\n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " marital-status_Married-civ-spouse marital-status_Married-spouse-absent \\\n", "0 0 0 \n", "1 0 0 \n", "2 1 0 \n", "3 1 0 \n", "4 1 0 \n", "\n", " marital-status_Never-married ... relationship_Own-child \\\n", "0 1 ... 0 \n", "1 1 ... 1 \n", "2 0 ... 0 \n", "3 0 ... 0 \n", "4 0 ... 0 \n", "\n", " relationship_Unmarried relationship_Wife workclass_Federal-gov \\\n", "0 1 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "\n", " workclass_Local-gov workclass_Private workclass_Self-emp-inc \\\n", "0 0 1 0 \n", "1 0 1 0 \n", "2 1 0 0 \n", "3 0 1 0 \n", "4 0 0 0 \n", "\n", " workclass_Self-emp-not-inc workclass_State-gov workclass_Without-pay \n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 1 0 0 \n", "\n", "[5 rows x 39 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df, test_df = data.load_data()\n", "\n", "X_train_df, Y_train_df = train_df\n", "X_test_df, Y_test_df = test_df\n", "\n", "# Let's drop the protected attributes from the training and test data and store them in a\n", "# separate dataframe that we'll use later to train the individually fair metric.\n", "protected_vars = ['race_White', 'sex_Male']\n", "\n", "X_protected_df = X_train_df[protected_vars]\n", "X_train_df = X_train_df.drop(columns=protected_vars)\n", "X_test_df = X_test_df.drop(columns=protected_vars)\n", "\n", "# Create test data with spouse variable flipped\n", "X_test_df_spouse_flipped = X_test_df.copy()\n", "X_test_df_spouse_flipped.relationship_Wife = 1 - X_test_df_spouse_flipped.relationship_Wife\n", "\n", "X_train_df.head()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "device = torch.device('cpu')\n", "\n", "# Convert all pandas dataframes to PyTorch tensors\n", "X_train, y_train = data.convert_df_to_tensor(X_train_df, Y_train_df)\n", "X_test, y_test = data.convert_df_to_tensor(X_test_df, Y_test_df)\n", "X_test_flip, y_test_flip = data.convert_df_to_tensor(X_test_df_spouse_flipped, Y_test_df)\n", "X_protected = torch.tensor(X_protected_df.values).float()\n", "\n", "# Create the training and testing dataset\n", "train_ds = AdultDataset(X_train, y_train)\n", "test_ds = AdultDataset(X_test, y_test)\n", "test_ds_flip = AdultDataset(X_test_flip, y_test_flip)\n", "\n", "# Create train and test dataloaders\n", "train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)\n", "test_dl = DataLoader(test_ds, batch_size=1000, shuffle=False)\n", "test_dl_flip = DataLoader(test_ds_flip, batch_size=1000, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Create a fully connected neural network\n", "\n", "class Model(nn.Module):\n", "\n", " def __init__(self, input_size, output_size):\n", "\n", " super().__init__()\n", " self.fc1 = nn.Linear(input_size, 100)\n", " self.fc2 = nn.Linear(100, 100)\n", " self.fcout = nn.Linear(100, output_size)\n", "\n", " def forward(self, x):\n", "\n", " x = F.relu(self.fc1(x))\n", " x = 
F.relu(self.fc2(x))\n", " x = self.fcout(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Standard training" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "input_size = X_train.shape[1]\n", "output_size = 2\n", "\n", "network_standard = Model(input_size, output_size).to(device)\n", "optimizer = torch.optim.Adam(network_standard.parameters(), lr=1e-3)\n", "loss_fn = F.cross_entropy\n", "\n", "EPOCHS = 10" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 10/10 [00:03<00:00, 2.57it/s]\n" ] } ], "source": [ "network_standard.train()\n", "\n", "for epoch in tqdm(range(EPOCHS)):\n", "\n", " for x, y in train_dl:\n", "\n", " x, y = x.to(device), y.to(device)\n", " optimizer.zero_grad()\n", " y_pred = network_standard(x).squeeze()\n", " loss = loss_fn(y_pred, y)\n", " loss.backward()\n", " optimizer.step()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.855042040348053\n", "Balanced accuracy: 0.7806884556970295\n", "Spouse consistency: 0.9593100398053959\n" ] } ], "source": [ "accuracy = metrics.accuracy(network_standard, test_dl, device)\n", "balanced_acc = metrics.balanced_accuracy(network_standard, test_dl, device)\n", "spouse_consistency = metrics.spouse_consistency(network_standard, test_dl, test_dl_flip, device)\n", "\n", "print(f'Accuracy: {accuracy}')\n", "print(f'Balanced accuracy: {balanced_acc}')\n", "print(f'Spouse consistency: {spouse_consistency}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Individually fair training with LogReg fair metric" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "network_fair_LR = Model(input_size, output_size).to(device)\n", "optimizer = torch.optim.Adam(network_fair_LR.parameters(), lr=1e-3)\n", "lossfn = F.cross_entropy\n", "\n", "distance_x_LR = distances.LogisticRegSensitiveSubspace()\n", "distance_y = distances.SquaredEuclideanDistance()\n", "\n", "distance_x_LR.fit(X_train, data_SensitiveAttrs=X_protected)\n", "distance_y.fit(num_dims=output_size)\n", "\n", "distance_x_LR.to(device)\n", "distance_y.to(device)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "rho = 5.0\n", "eps = 0.1\n", "auditor_nsteps = 100\n", "auditor_lr = 1e-3\n", "\n", "fairalgo_LR = SenSeI(network_fair_LR, distance_x_LR, distance_y, lossfn, rho, eps, auditor_nsteps, auditor_lr)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 10/10 [03:02<00:00, 18.29s/it]\n" ] } ], "source": [ "fairalgo_LR.train()\n", "\n", "for epoch in tqdm(range(EPOCHS)):\n", " for x, y in train_dl:\n", " x, y = x.to(device), y.to(device)\n", " optimizer.zero_grad()\n", " result = fairalgo_LR(x, y)\n", " result.loss.backward()\n", " optimizer.step()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.8369084596633911\n", "Balanced accuracy: 0.7357549314737899\n", "Spouse consistency: 0.9998894294559929\n" ] } ], "source": [ "accuracy = metrics.accuracy(network_fair_LR, test_dl, device)\n", "balanced_acc = metrics.balanced_accuracy(network_fair_LR, test_dl, device)\n", "spouse_consistency = 
metrics.spouse_consistency(network_fair_LR, test_dl, test_dl_flip, device)\n", "\n", "print(f'Accuracy: {accuracy}')\n", "print(f'Balanced accuracy: {balanced_acc}')\n", "print(f'Spouse consistency: {spouse_consistency}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Individually fair training with EXPLORE metric" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/mayank/Documents/[Projects]/open-source/inFairness/examples/adult-income-prediction/../../inFairness/distances/explore_distance.py:76: RuntimeWarning: overflow encountered in exp\n", " sclVec = 2.0 / (np.exp(diag) - 1)\n" ] } ], "source": [ "Y_gender = X_protected[:, -1]\n", "X1, X2, Y_pairs = data.create_data_pairs(X_train, y_train, Y_gender)\n", "\n", "distance_x_explore = distances.EXPLOREDistance()\n", "distance_x_explore.fit(X1, X2, Y_pairs, iters=1000, batchsize=10000)\n", "distance_x_explore.to(device)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "network_fair_explore = Model(input_size, output_size).to(device)\n", "optimizer = torch.optim.Adam(network_fair_explore.parameters(), lr=1e-3)\n", "lossfn = F.cross_entropy\n", "\n", "rho = 25.0\n", "eps = 0.1\n", "auditor_nsteps = 10\n", "auditor_lr = 1e-2\n", "\n", "fairalgo_explore = SenSeI(network_fair_explore, distance_x_explore, distance_y, lossfn, rho, eps, auditor_nsteps, auditor_lr)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 10/10 [00:24<00:00, 2.42s/it]\n" ] } ], "source": [ "fairalgo_explore.train()\n", "\n", "for epoch in tqdm(range(EPOCHS)):\n", " for x, y in train_dl:\n", " x, y = x.to(device), y.to(device)\n", " optimizer.zero_grad()\n", " result = fairalgo_explore(x, y)\n", " result.loss.backward()\n", " optimizer.step()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.8224236965179443\n", "Balanced accuracy: 0.6999390313607438\n", "Spouse consistency: 0.9997788589119858\n" ] } ], "source": [ "accuracy = metrics.accuracy(network_fair_explore, test_dl, device)\n", "balanced_acc = metrics.balanced_accuracy(network_fair_explore, test_dl, device)\n", "spouse_consistency = metrics.spouse_consistency(network_fair_explore, test_dl, test_dl_flip, device)\n", "\n", "print(f'Accuracy: {accuracy}')\n", "print(f'Balanced accuracy: {balanced_acc}')\n", "print(f'Spouse consistency: {spouse_consistency}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Let's now audit the three models and check for their individual fairness compliance" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/mayank/Documents/[Projects]/open-source/inFairness/examples/adult-income-prediction/../../inFairness/auditor/auditor.py:54: RuntimeWarning: invalid value encountered in divide\n", " loss_ratio = np.divide(loss_vals_adversarial, loss_vals_original)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "====================================================================================================\n", "LR metric\n", "Loss ratio (Standard model) : 2.4822924476956905. Is model fair: False\n", "Loss ratio (fair model - LogReg metric) : 1.0421434064879227. 
Is model fair: True\n", "Loss ratio (fair model - EXPLORE metric) : 1.026998276114377. Is model fair: True\n", "----------------------------------------------------------------------------------------------------\n", "\t As signified by these numbers, the fair models are fairer than the standard model\n", "====================================================================================================\n" ] } ], "source": [ "# Auditing using the SenSR Auditor + LR metric\n", "\n", "audit_nsteps = 1000\n", "audit_lr = 0.1\n", "\n", "auditor_LR = SenSRAuditor(loss_fn=loss_fn, distance_x=distance_x_LR, num_steps=audit_nsteps, lr=audit_lr, max_noise=0.5, min_noise=-0.5)\n", "\n", "audit_result_stdmodel = auditor_LR.audit(network_standard, X_test, y_test, lambda_param=10.0, audit_threshold=1.15)\n", "audit_result_fairmodel_LR = auditor_LR.audit(network_fair_LR, X_test, y_test, lambda_param=10.0, audit_threshold=1.15)\n", "audit_result_fairmodel_explore = auditor_LR.audit(network_fair_explore, X_test, y_test, lambda_param=10.0, audit_threshold=1.15)\n", "\n", "print(\"=\"*100)\n", "print(\"LR metric\")\n", "print(f\"Loss ratio (Standard model) : {audit_result_stdmodel.lower_bound}. Is model fair: {audit_result_stdmodel.is_model_fair}\")\n", "print(f\"Loss ratio (fair model - LogReg metric) : {audit_result_fairmodel_LR.lower_bound}. Is model fair: {audit_result_fairmodel_LR.is_model_fair}\")\n", "print(f\"Loss ratio (fair model - EXPLORE metric) : {audit_result_fairmodel_explore.lower_bound}. Is model fair: {audit_result_fairmodel_explore.is_model_fair}\")\n", "print(\"-\"*100)\n", "print(\"\\t As signified by these numbers, the fair models are fairer than the standard model\")\n", "print(\"=\"*100)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "====================================================================================================\n", "EXPLORE metric\n", "Loss ratio (Standard model) : 3.2874276326186633. Is model fair: False\n", "Loss ratio (fair model - LogReg metric) : 1.0897408117340435. Is model fair: True\n", "Loss ratio (fair model - EXPLORE metric) : 1.063488311922447. Is model fair: True\n", "----------------------------------------------------------------------------------------------------\n", "\t As signified by these numbers, the fair models are fairer than the standard model\n", "====================================================================================================\n" ] } ], "source": [ "# Auditing using the SenSR Auditor + EXPLORE metric\n", "\n", "audit_nsteps = 1000\n", "audit_lr = 0.1\n", "\n", "auditor_explore = SenSRAuditor(loss_fn=loss_fn, distance_x=distance_x_explore, num_steps=audit_nsteps, lr=audit_lr, max_noise=0.5, min_noise=-0.5)\n", "\n", "audit_result_stdmodel = auditor_explore.audit(network_standard, X_test, y_test, lambda_param=10.0, audit_threshold=1.15)\n", "audit_result_fairmodel_LR = auditor_explore.audit(network_fair_LR, X_test, y_test, lambda_param=10.0, audit_threshold=1.15)\n", "audit_result_fairmodel_explore = auditor_explore.audit(network_fair_explore, X_test, y_test, lambda_param=10.0, audit_threshold=1.15)\n", "\n", "print(\"=\"*100)\n", "print(\"EXPLORE metric\")\n", "print(f\"Loss ratio (Standard model) : {audit_result_stdmodel.lower_bound}. Is model fair: {audit_result_stdmodel.is_model_fair}\")\n", "print(f\"Loss ratio (fair model - LogReg metric) : {audit_result_fairmodel_LR.lower_bound}. 
Is model fair: {audit_result_fairmodel_LR.is_model_fair}\")\n", "print(f\"Loss ratio (fair model - EXPLORE metric) : {audit_result_fairmodel_explore.lower_bound}. Is model fair: {audit_result_fairmodel_explore.is_model_fair}\")\n", "print(\"-\"*100)\n", "print(\"\\t As signified by these numbers, the fair models are fairer than the standard model\")\n", "print(\"=\"*100)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('infairness')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "vscode": { "interpreter": { "hash": "a2fcd21fd76dae422ddd233e5f13f95d1b9f33f06ee3b038abc60116b464585a" } } }, "nbformat": 4, "nbformat_minor": 4 }