{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Running attribute inference attacks on the Nursery data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial we will show how to run both black-box and white-box inference attacks. This will be demonstrated on the Nursery dataset (original dataset can be found here: https://archive.ics.uci.edu/ml/datasets/nursery). "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preliminaries\n",
    "\n",
    "In the case of the nursery dataset, the sensitive feature we want to infer is the 'social' feature. In the original dataset this is a categorical feature with 3 possible values. To make the attack more successful, we reduced this to two possible feature values by assigning the original value 'problematic' the new value 1, and the other original values were assigned the new value 0.\n",
    "\n",
    "We have also already preprocessed the dataset such that all categorical features are one-hot encoded, and the data was scaled using sklearn's StandardScaler."
   ]
  },
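  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "load_nursery (used below) already performs all of this preprocessing for us. As a rough illustration only, an equivalent manual pipeline could look like the following sketch (the DataFrame df and the raw column handling are hypothetical). Note that standardizing the binary 'social' column is what produces the two scaled feature values (roughly -0.707 and 1.414, since about a third of the samples are 'problematic') that the attacks below refer to."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only -- ART's load_nursery performs equivalent preprocessing internally.\n",
    "# Assumes a hypothetical pandas DataFrame `df` holding the raw Nursery feature columns.\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "def preprocess(df):\n",
    "    df = df.copy()\n",
    "    # binarize the sensitive feature: 'problematic' -> 1, all other values -> 0\n",
    "    df['social'] = (df['social'] == 'problematic').astype(int)\n",
    "    # one-hot encode the remaining categorical features\n",
    "    categorical = [c for c in df.columns if c != 'social']\n",
    "    df = pd.get_dummies(df, columns=categorical)\n",
    "    # scale all features (including the binary 'social') to zero mean and unit variance\n",
    "    return StandardScaler().fit_transform(df.astype(float))"
   ]
  },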
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.insert(0, os.path.abspath('..'))\n",
    "\n",
    "from art.utils import load_nursery\n",
    "\n",
    "(x_train, y_train), (x_test, y_test), _, _ = load_nursery(test_set=0.5, transform_social=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train decision tree model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Base model accuracy:  0.9705155912318617\n"
     ]
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from art.estimators.classification.scikitlearn import ScikitlearnDecisionTreeClassifier\n",
    "\n",
    "model = DecisionTreeClassifier()\n",
    "model.fit(x_train, y_train)\n",
    "art_classifier = ScikitlearnDecisionTreeClassifier(model)\n",
    "\n",
    "print('Base model accuracy: ', model.score(x_test, y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attack\n",
    "### Black-box attack\n",
    "The black-box attack basically trains an additional classifier (called the attack model) to predict the attacked feature's value from the remaining n-1 features as well as the original (attacked) model's predictions.\n",
    "#### Train attack model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from art.attacks.inference.attribute_inference import AttributeInferenceBlackBox\n",
    "\n",
    "attack_train_ratio = 0.5\n",
    "attack_train_size = int(len(x_train) * attack_train_ratio)\n",
    "attack_test_size = int(len(x_train) * attack_train_ratio)\n",
    "attack_x_train = x_train[:attack_train_size]\n",
    "attack_y_train = y_train[:attack_train_size]\n",
    "attack_x_test = x_train[attack_train_size:]\n",
    "attack_y_test = y_train[attack_train_size:]\n",
    "\n",
    "attack_feature = 1  # social\n",
    "\n",
    "# get original model's predictions\n",
    "attack_x_test_predictions = np.array([np.argmax(arr) for arr in art_classifier.predict(attack_x_test)]).reshape(-1,1)\n",
    "# only attacked feature\n",
    "attack_x_test_feature = attack_x_test[:, attack_feature].copy().reshape(-1, 1)\n",
    "# training data without attacked feature\n",
    "attack_x_test = np.delete(attack_x_test, attack_feature, 1)\n",
    "\n",
    "bb_attack = AttributeInferenceBlackBox(art_classifier, attack_feature=attack_feature)\n",
    "\n",
    "# train attack model\n",
    "bb_attack.fit(attack_x_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Infer sensitive feature and check accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5998765050941649\n"
     ]
    }
   ],
   "source": [
    "# get inferred values\n",
    "values = [-0.70718864, 1.41404987]\n",
    "inferred_train_bb = bb_attack.infer(attack_x_test, pred=attack_x_test_predictions, values=values)\n",
    "# check accuracy\n",
    "train_acc = np.sum(inferred_train_bb == np.around(attack_x_test_feature, decimals=8).reshape(1,-1)) / len(inferred_train_bb)\n",
    "print(train_acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This means that for 59% of the training set, the attacked feature is inferred correctly using this attack."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Whitebox attacks\n",
    "These two attacks do not train any additional model, they simply use additional information coded within the attacked decision tree model to compute the probability of each value of the attacked feature and outputs the value with the highest probability.\n",
    "### First attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6288978079654214\n"
     ]
    }
   ],
   "source": [
    "from art.attacks.inference.attribute_inference import AttributeInferenceWhiteBoxLifestyleDecisionTree\n",
    "\n",
    "wb_attack = AttributeInferenceWhiteBoxLifestyleDecisionTree(art_classifier, attack_feature=attack_feature)\n",
    "\n",
    "priors = [3465 / 5183, 1718 / 5183]\n",
    "\n",
    "# get inferred values\n",
    "inferred_train_wb1 = wb_attack.infer(attack_x_test, attack_x_test_predictions, values=values, priors=priors)\n",
    "\n",
    "# check accuracy\n",
    "train_acc = np.sum(inferred_train_wb1 == np.around(attack_x_test_feature, decimals=8).reshape(1,-1)) / len(inferred_train_wb1)\n",
    "print(train_acc)"
   ]
  },
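  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The values and priors passed to the attack are hardcoded above. As a sanity check, both can be recovered from the training data itself; a minimal sketch, assuming the attacked feature takes exactly two distinct (scaled) values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sketch: recover the candidate feature values and their empirical priors from x_train\n",
    "feature_column = np.around(x_train[:, attack_feature], decimals=8)\n",
    "unique_values, counts = np.unique(feature_column, return_counts=True)\n",
    "print('values:', unique_values)           # expected to match [-0.70718864, 1.41404987]\n",
    "print('priors:', counts / counts.sum())   # should be close to the hardcoded priors"
   ]
  },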
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Second attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7005248533497993\n"
     ]
    }
   ],
   "source": [
    "from art.attacks.inference.attribute_inference import AttributeInferenceWhiteBoxDecisionTree\n",
    "\n",
    "wb2_attack = AttributeInferenceWhiteBoxDecisionTree(art_classifier, attack_feature=attack_feature)\n",
    "\n",
    "# get inferred values\n",
    "inferred_train_wb2 = wb2_attack.infer(attack_x_test, attack_x_test_predictions, values=values, priors=priors)\n",
    "\n",
    "# check accuracy\n",
    "train_acc = np.sum(inferred_train_wb2 == np.around(attack_x_test_feature, decimals=8).reshape(1,-1)) / len(inferred_train_wb2)\n",
    "print(train_acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The white-box attacks are able to correctly infer the attacked feature value in 62% and 70% of the training set respectively. \n",
    "\n",
    "Now let's check the precision and recall:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0.34232954545454547, 0.22439478584729983)\n",
      "(0.32320441988950277, 0.10893854748603352)\n",
      "(0.652046783625731, 0.20763500931098697)\n"
     ]
    }
   ],
   "source": [
    "def calc_precision_recall(predicted, actual, positive_value=1):\n",
    "    score = 0  # both predicted and actual are positive\n",
    "    num_positive_predicted = 0  # predicted positive\n",
    "    num_positive_actual = 0  # actual positive\n",
    "    for i in range(len(predicted)):\n",
    "        if predicted[i] == positive_value:\n",
    "            num_positive_predicted += 1\n",
    "        if actual[i] == positive_value:\n",
    "            num_positive_actual += 1\n",
    "        if predicted[i] == actual[i]:\n",
    "            if predicted[i] == positive_value:\n",
    "                score += 1\n",
    "    \n",
    "    if num_positive_predicted == 0:\n",
    "        precision = 1\n",
    "    else:\n",
    "        precision = score / num_positive_predicted  # the fraction of predicted “Yes” responses that are correct\n",
    "    if num_positive_actual == 0:\n",
    "        recall = 1\n",
    "    else:\n",
    "        recall = score / num_positive_actual  # the fraction of “Yes” responses that are predicted correctly\n",
    "\n",
    "    return precision, recall\n",
    "    \n",
    "# black-box\n",
    "print(calc_precision_recall(inferred_train_bb, np.around(attack_x_test_feature, decimals=8), positive_value=1.41404987))\n",
    "# white-box 1\n",
    "print(calc_precision_recall(inferred_train_wb1, np.around(attack_x_test_feature, decimals=8), positive_value=1.41404987))\n",
    "# white-box 2\n",
    "print(calc_precision_recall(inferred_train_wb2, np.around(attack_x_test_feature, decimals=8), positive_value=1.41404987))"
   ]
  },
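  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Equivalently, these numbers can be cross-checked against sklearn's standard metrics by mapping both arrays to a 0/1 encoding of the positive value (a sketch; note that sklearn handles the edge case of zero positive predictions differently than calc_precision_recall above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_score, recall_score\n",
    "\n",
    "# 1 where the true (scaled) feature value is the positive value, else 0\n",
    "y_true = (np.around(attack_x_test_feature, decimals=8).reshape(-1) == values[1]).astype(int)\n",
    "for inferred in (inferred_train_bb, inferred_train_wb1, inferred_train_wb2):\n",
    "    y_pred = (inferred.reshape(-1) == values[1]).astype(int)\n",
    "    print(precision_score(y_true, y_pred), recall_score(y_true, y_pred))"
   ]
  },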
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To verify the significance of these results, we now run a baseline attack that uses only the remaining features to try to predict the value of the attacked feature, with no use of the model itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5433775856745909\n"
     ]
    }
   ],
   "source": [
    "from art.attacks.inference.attribute_inference import AttributeInferenceBaseline\n",
    "\n",
    "baseline_attack = AttributeInferenceBaseline(attack_feature=attack_feature)\n",
    "\n",
    "# train attack model\n",
    "baseline_attack.fit(attack_x_train)\n",
    "# infer values\n",
    "inferred_train_baseline = baseline_attack.infer(attack_x_test, values=values)\n",
    "# check accuracy\n",
    "baseline_train_acc = np.sum(inferred_train_baseline == np.around(attack_x_test_feature, decimals=8).reshape(1,-1)) / len(inferred_train_baseline)\n",
    "print(baseline_train_acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that both the black-box and white-box attacks do better than the baseline."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Membership based attack\n",
    "In this attack the idea is to find the target feature value that maximizes the membership attack confidence, indicating that this is the most probable value for member samples. It can be based on any membership attack (either black-box or white-box) as long as it supports the given model.\n",
    "\n",
    "### Train membership attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from art.attacks.inference.membership_inference import MembershipInferenceBlackBox\n",
    "\n",
    "mem_attack = MembershipInferenceBlackBox(art_classifier)\n",
    "\n",
    "mem_attack.fit(x_train[:attack_train_size], y_train[:attack_train_size], x_test[:attack_test_size], y_test[:attack_test_size])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Apply attribute attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6335288669342389\n"
     ]
    }
   ],
   "source": [
    "from art.attacks.inference.attribute_inference import AttributeInferenceMembership\n",
    "\n",
    "attack = AttributeInferenceMembership(art_classifier, mem_attack, attack_feature=attack_feature)\n",
    "\n",
    "# infer values\n",
    "inferred_train = attack.infer(attack_x_test, attack_y_test, values=values)\n",
    "\n",
    "# check accuracy\n",
    "train_acc = np.sum(inferred_train == np.around(attack_x_test_feature, decimals=8).reshape(1,-1)) / len(inferred_train)\n",
    "print(train_acc)"
   ]
  },
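  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the mechanism more concrete, here is a minimal sketch of the idea behind AttributeInferenceMembership (not ART's exact implementation): for each candidate value of the attacked feature, rebuild the full feature vector and keep the value for which the membership attack most strongly predicts 'member'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sketch of the underlying idea, not ART's exact implementation\n",
    "candidate_votes = []\n",
    "for value in values:\n",
    "    # fill in the candidate value for the (previously removed) attacked feature\n",
    "    filled = np.insert(attack_x_test, attack_feature, value, axis=1)\n",
    "    # membership predictions (1 = predicted member) under this candidate value\n",
    "    candidate_votes.append(mem_attack.infer(filled, attack_y_test).reshape(-1))\n",
    "\n",
    "# per sample, pick the candidate value with the strongest membership signal\n",
    "# (ties resolve to the first candidate here; the real attack is more careful)\n",
    "manual_inferred = np.array(values)[np.argmax(np.stack(candidate_votes, axis=1), axis=1)]\n",
    "manual_acc = np.sum(manual_inferred == np.around(attack_x_test_feature, decimals=8).reshape(-1)) / len(manual_inferred)\n",
    "print(manual_acc)"
   ]
  },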
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that this attack does slightly better than the regular black-box attack, even though it still assumes only black-box access to the model (employs a black-box membership attack). But it is not as good as the white-box attacks."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}