{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cac5c830",
   "metadata": {},
   "source": [
    "# Certification of Robustness using Zonotopes with DeepZ"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e03d16f6",
   "metadata": {},
   "source": [
    "In this notebook we demonstrate how to certify robustness with zonotopes in ART. With deterministic certification methods such as DeepZ we can obtain a guarantee that a datapoint cannot have its class changed by any perturbation within a given bound. This method was originally proposed in: https://papers.nips.cc/paper/2018/file/f2f446980d8e971ef3da97af089481c3-Paper.pdf\n",
    "\n",
    "The zonotope abstraction used here is defined by:\n",
    "\\begin{equation}\n",
    "    \\hat{x} = \\eta_0 + \\sum_{i=1}^{N} \\eta_i \\epsilon_i\n",
    "\\end{equation}\n",
    "\n",
    "where $\\eta_0$ is the central vector, $\\epsilon_i \\in [-1, 1]$ are the noise symbols, and $\\eta_i$ are the coefficients representing deviations around $\\eta_0$.\n",
    "\n",
    "Below we illustrate a 2D toy example in which the initial datapoint has two features, with a central vector of [0.25, 0.25], and each feature has its own noise term with coefficient 0.25. We push this zonotope through the neural network and show its intermediate shapes:\n"
   ]
  },
  {
   "attachments": {
    "zonotope_picture.png.png": {
     "image/png": ""
    }
   },
   "cell_type": "markdown",
   "id": "57e95744",
   "metadata": {},
   "source": [
    "![zonotope_picture.png.png](attachment:zonotope_picture.png.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "748cfc6a",
   "metadata": {},
   "source": [
    "We can see that the zonotope changes shape as it is passed through the neural network. When passing through a ReLU it gains another noise term (going from 2 sets of parallel lines to 3). We can then check whether the final zonotope crosses any decision boundary: if it does not, the point is certified.\n",
    "\n",
    "Let's see how to use this method in ART!"
   ]
  },
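  {
   "cell_type": "markdown",
   "id": "9b1e4c2a",
   "metadata": {},
   "source": [
    "Before doing so, the toy example above can be reproduced with a minimal NumPy sketch. This is our own illustration, not ART's implementation: the helpers `concretize`, `affine`, `relu` and `is_certified`, and the weights `W1`, `b1`, `W2`, `b2` of a tiny 2-2-2 network, are hypothetical (the network in the figure may differ). A zonotope is stored as the center $\\eta_0$ plus a matrix whose rows are the $\\eta_i$. Because every $\\epsilon_i \\in [-1, 1]$, coordinate $k$ lies in $\\eta_{0,k} \\pm \\sum_i |\\eta_{i,k}|$. Affine layers map zonotopes exactly. For a ReLU whose input range $[l, u]$ contains 0, the sketch uses what we understand to be the DeepZ ReLU transformer, $\\lambda x + \\mu + \\mu \\epsilon_{new}$ with $\\lambda = u / (u - l)$ and $\\mu = -\\lambda l / 2$, which adds one new noise term. Finally, a point is treated as certified if, for every other class $j$, the maximum of $y_j - y_{pred}$ over the output zonotope is negative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f7d2e81",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical sketch (not ART's implementation). A zonotope is a center eta_0 (shape [d])\n",
    "# and a generator matrix (shape [n_terms, d]) whose rows are the coefficients eta_i.\n",
    "\n",
    "def concretize(center, gens):\n",
    "    # Every epsilon_i is in [-1, 1], so coordinate k lies in center[k] +/- sum_i |gens[i, k]|.\n",
    "    radius = np.abs(gens).sum(axis=0)\n",
    "    return center - radius, center + radius\n",
    "\n",
    "def affine(center, gens, W, b):\n",
    "    # Affine layers are exact on zonotopes: x -> W x + b.\n",
    "    return W @ center + b, gens @ W.T\n",
    "\n",
    "def relu(center, gens):\n",
    "    # DeepZ-style ReLU transformer, applied neuron by neuron.\n",
    "    lower, upper = concretize(center, gens)\n",
    "    new_center, new_gens = center.copy(), gens.copy()\n",
    "    extra_rows = []\n",
    "    for k in range(center.shape[0]):\n",
    "        l, u = lower[k], upper[k]\n",
    "        if u <= 0:\n",
    "            # Always inactive: the output is exactly 0.\n",
    "            new_center[k] = 0.0\n",
    "            new_gens[:, k] = 0.0\n",
    "        elif l < 0:\n",
    "            # Crosses 0: over-approximate with lam * x + mu + mu * eps_new.\n",
    "            lam = u / (u - l)\n",
    "            mu = -lam * l / 2.0\n",
    "            new_center[k] = lam * center[k] + mu\n",
    "            new_gens[:, k] = lam * gens[:, k]\n",
    "            row = np.zeros_like(center)\n",
    "            row[k] = mu\n",
    "            extra_rows.append(row)\n",
    "        # Otherwise l >= 0: always active, the neuron is left unchanged.\n",
    "    if extra_rows:\n",
    "        new_gens = np.vstack([new_gens] + extra_rows)\n",
    "    return new_center, new_gens\n",
    "\n",
    "def is_certified(center, gens, label):\n",
    "    # The maximum of y_j - y_label over the zonotope is the center difference\n",
    "    # plus the sum of absolute generator differences; all must be negative.\n",
    "    for j in range(center.shape[0]):\n",
    "        if j == label:\n",
    "            continue\n",
    "        diff_center = center[j] - center[label]\n",
    "        diff_gens = gens[:, j] - gens[:, label]\n",
    "        if diff_center + np.abs(diff_gens).sum() >= 0:\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "# Toy input: center [0.25, 0.25] with one independent noise term of 0.25 per feature.\n",
    "x_center = np.array([0.25, 0.25])\n",
    "x_gens = 0.25 * np.eye(2)\n",
    "\n",
    "# Hypothetical weights for a tiny 2-2-2 network.\n",
    "W1, b1 = np.array([[1.0, -1.0], [1.0, 1.0]]), np.array([0.0, 0.5])\n",
    "W2, b2 = np.array([[0.0, 1.0], [1.0, 0.0]]), np.zeros(2)\n",
    "\n",
    "h_center, h_gens = relu(*affine(x_center, x_gens, W1, b1))\n",
    "out_center, out_gens = affine(h_center, h_gens, W2, b2)\n",
    "pred = int(np.argmax(out_center))\n",
    "certified = is_certified(out_center, out_gens, pred)\n",
    "\n",
    "print('noise terms before/after ReLU:', x_gens.shape[0], h_gens.shape[0])  # expected: 2 3\n",
    "print('prediction:', pred, 'certified:', certified)  # expected: 0 True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4a81d07",
   "metadata": {},
   "source": [
    "With these hypothetical weights only the first hidden neuron can take both signs (its input range is [-0.5, 0.5]), so the ReLU adds exactly one noise term, analogous to the step from 2 to 3 sets of parallel lines in the figure. The output zonotope stays on one side of the decision boundary, so the toy point is certified."
   ]
  },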
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "785902b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "\n",
    "from torch import nn\n",
    "from sklearn.utils import shuffle\n",
    "\n",
    "from art.estimators.certification import deep_z\n",
    "from art.utils import load_mnist, preprocess, to_categorical\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5df9e108",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We define an example PyTorch classifier\n",
    "\n",
    "class MNISTModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MNISTModel, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels=1,\n",
    "                               out_channels=32,\n",
    "                               kernel_size=(4, 4),\n",
    "                               stride=(2, 2),\n",
    "                               dilation=(1, 1),\n",
    "                               padding=(0, 0))\n",
    "        self.conv2 = nn.Conv2d(in_channels=32,\n",
    "                               out_channels=32,\n",
    "                               kernel_size=(4, 4),\n",
    "                               stride=(2, 2),\n",
    "                               dilation=(1, 1),\n",
    "                               padding=(0, 0))\n",
    "        self.fc1 = nn.Linear(in_features=800,\n",
    "                             out_features=10)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        if isinstance(x, np.ndarray):\n",
    "            x = torch.from_numpy(x).float().to(device)\n",
    "        x = self.relu(self.conv1(x))\n",
    "        x = self.relu(self.conv2(x))\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "148604f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MNISTModel().to(device)\n",
    "opt = optim.Adam(model.parameters(), lr=1e-4)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "(x_train, y_train), (x_test, y_test), min_, max_ = load_mnist()\n",
    "\n",
    "x_test = np.squeeze(x_test)\n",
    "x_test = np.expand_dims(x_test, axis=1)\n",
    "y_test = np.argmax(y_test, axis=1)\n",
    "\n",
    "x_train = np.squeeze(x_train)\n",
    "x_train = np.expand_dims(x_train, axis=1)\n",
    "y_train = np.argmax(y_train, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1e240be5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "End of epoch 0 loss 0.5815373063087463\n",
      "End of epoch 1 loss 0.2648811340332031\n",
      "End of epoch 2 loss 0.18593080341815948\n",
      "End of epoch 3 loss 0.1360677033662796\n",
      "End of epoch 4 loss 0.10795646160840988\n"
     ]
    }
   ],
   "source": [
    "# train the model normally\n",
    "\n",
    "def standard_train(model, opt, criterion, x, y, bsize=32, epochs=5):\n",
    "    num_of_batches = int(len(x) / bsize)\n",
    "    for epoch in range(epochs):\n",
    "        x, y = shuffle(x, y)\n",
    "        loss_list = []\n",
    "        for bnum in range(num_of_batches):\n",
    "            x_batch = np.copy(x[bnum * bsize:(bnum + 1) * bsize])\n",
    "            y_batch = np.copy(y[bnum * bsize:(bnum + 1) * bsize])\n",
    "\n",
    "            x_batch = torch.from_numpy(x_batch).float().to(device)\n",
    "            y_batch = torch.from_numpy(y_batch).type(torch.LongTensor).to(device)\n",
    "\n",
    "            # zero the parameter gradients\n",
    "            opt.zero_grad()\n",
    "            outputs = model(x_batch)\n",
    "            loss = criterion(outputs, y_batch)\n",
    "            loss_list.append(loss.item())\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        print('End of epoch {} loss {}'.format(epoch, np.mean(loss_list)))\n",
    "    return model\n",
    "\n",
    "model = standard_train(model=model,\n",
    "                       opt=opt,\n",
    "                       criterion=criterion,\n",
    "                       x=x_train,\n",
    "                       y=y_train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f358bef6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc:  97.46000000000001\n"
     ]
    }
   ],
   "source": [
    "# Let's now get the predictions for the MNIST test set and see how well our model is doing.\n",
    "with torch.no_grad():\n",
    "    test_preds = model(torch.from_numpy(x_test).float().to(device))\n",
    "\n",
    "test_preds = np.argmax(test_preds.cpu().detach().numpy(), axis=1)\n",
    "print('Test acc: ', np.mean(test_preds == y_test) * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "430edde4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "registered <class 'torch.nn.modules.conv.Conv2d'>\n",
      "registered <class 'torch.nn.modules.activation.ReLU'>\n",
      "registered <class 'torch.nn.modules.conv.Conv2d'>\n",
      "registered <class 'torch.nn.modules.activation.ReLU'>\n",
      "registered <class 'torch.nn.modules.linear.Linear'>\n",
      "Inferred reshape on op num 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/giulio/Documents/Projects/AI2_for_ART/adversarial-robustness-toolbox/art/estimators/certification/deep_z/pytorch.py:90: UserWarning: \n",
      "This estimator does not support networks which have dense layers before convolutional. We currently infer a reshape when a neural network goes from convolutional layers to dense layers. If your use case does not fall into this pattern then consider directly building a certifier network with the custom layers found in art.estimators.certification.deepz.deep_z.py\n",
      "\n",
      "  \"\\nThis estimator does not support networks which have dense layers before convolutional. \"\n"
     ]
    }
   ],
   "source": [
    "# But how robust are these predictions?\n",
    "# We can now examine this neural network's certified robustness\n",
    "# by passing it into PytorchDeepZ. We will get a printout showing which\n",
    "# neural network layers have been registered. There will also be a\n",
    "# warning telling us that PytorchDeepZ currently infers a reshape when\n",
    "# a neural network goes from convolutional to dense layers.\n",
    "# This covers the majority of use cases; if yours differs, the\n",
    "# certification layers in the art.estimators.certification.deep_z package\n",
    "# can be used to directly build a certified model structure.\n",
    "\n",
    "zonotope_model = deep_z.PytorchDeepZ(model=model, \n",
    "                                     clip_values=(0, 1), \n",
    "                                     loss=nn.CrossEntropyLoss(), \n",
    "                                     input_shape=(1, 28, 28), \n",
    "                                     nb_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f5749a48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classified Correct 1/1 and also certified 1/1\n",
      "Classified Correct 2/2 and also certified 2/2\n",
      "Classified Correct 3/3 and also certified 2/3\n",
      "Classified Correct 4/4 and also certified 3/4\n",
      "Classified Correct 5/5 and also certified 4/5\n",
      "Classified Correct 6/6 and also certified 4/6\n",
      "Classified Correct 7/7 and also certified 4/7\n",
      "Classified Correct 8/8 and also certified 4/8\n",
      "Classified Correct 9/9 and also certified 4/9\n",
      "Classified Correct 10/10 and also certified 4/10\n",
      "Classified Correct 11/11 and also certified 5/11\n",
      "Classified Correct 12/12 and also certified 6/12\n",
      "Classified Correct 13/13 and also certified 7/13\n",
      "Classified Correct 14/14 and also certified 8/14\n",
      "Classified Correct 15/15 and also certified 9/15\n",
      "Classified Correct 16/16 and also certified 10/16\n",
      "Classified Correct 17/17 and also certified 11/17\n",
      "Classified Correct 18/18 and also certified 12/18\n",
      "Classified Correct 19/19 and also certified 12/19\n",
      "Classified Correct 20/20 and also certified 13/20\n",
      "Classified Correct 21/21 and also certified 14/21\n",
      "Classified Correct 22/22 and also certified 14/22\n",
      "Classified Correct 23/23 and also certified 15/23\n",
      "Classified Correct 24/24 and also certified 16/24\n",
      "Classified Correct 25/25 and also certified 16/25\n",
      "Classified Correct 26/26 and also certified 17/26\n",
      "Classified Correct 27/27 and also certified 18/27\n",
      "Classified Correct 28/28 and also certified 19/28\n",
      "Classified Correct 29/29 and also certified 20/29\n",
      "Classified Correct 30/30 and also certified 20/30\n",
      "Classified Correct 31/31 and also certified 21/31\n",
      "Classified Correct 32/32 and also certified 21/32\n",
      "Classified Correct 33/33 and also certified 22/33\n",
      "Classified Correct 34/34 and also certified 22/34\n",
      "Classified Correct 35/35 and also certified 23/35\n",
      "Classified Correct 36/36 and also certified 24/36\n",
      "Classified Correct 37/37 and also certified 25/37\n",
      "Classified Correct 38/38 and also certified 25/38\n",
      "Classified Correct 39/39 and also certified 26/39\n",
      "Classified Correct 40/40 and also certified 26/40\n",
      "Classified Correct 41/41 and also certified 26/41\n",
      "Classified Correct 42/42 and also certified 26/42\n",
      "Classified Correct 43/43 and also certified 27/43\n",
      "Classified Correct 44/44 and also certified 27/44\n",
      "Classified Correct 45/45 and also certified 28/45\n",
      "Classified Correct 46/46 and also certified 28/46\n",
      "Classified Correct 47/47 and also certified 28/47\n",
      "Classified Correct 48/48 and also certified 29/48\n",
      "Classified Correct 49/49 and also certified 30/49\n",
      "Classified Correct 50/50 and also certified 31/50\n"
     ]
    }
   ],
   "source": [
    "# Let's now see how robust our model is!\n",
    "# First we need to define the bound to check.\n",
    "# Here let's check for L-infinity robustness with a small bound of 0.05.\n",
    "\n",
    "bound = 0.05\n",
    "num_certified = 0\n",
    "num_correct = 0\n",
    "\n",
    "# Let's now loop over the data to check its certified robustness.\n",
    "# We consider a single sample at a time, as batching is not supported due to its memory and compute footprint.\n",
    "# In this demo we will look at the first 50 samples of the MNIST test data.\n",
    "original_x = np.copy(x_test)\n",
    "for i, (sample, pred, label) in enumerate(zip(x_test[:50], test_preds[:50], y_test[:50])):\n",
    "    \n",
    "    # We make the matrix representing the allowable perturbations.\n",
    "    # We have 28*28 = 784 features and each one can be manipulated independently, requiring its own row (noise term).\n",
    "    # Hence a 784*784 matrix.\n",
    "    eps_bound = np.eye(784) * bound\n",
    "    \n",
    "    # We then need to adjust the raw data and the eps bounds to take into account\n",
    "    # the allowable range of 0 - 1 for pixel data.\n",
    "    # We provide a simple function to do this preprocessing for image data.\n",
    "    # However, if your use case is not supported then a custom pre-processor function will need to be written.\n",
    "    sample, eps_bound = zonotope_model.pre_process(cent=sample, \n",
    "                                                   eps=eps_bound)\n",
    "    sample = np.expand_dims(sample, axis=0)\n",
    "\n",
    "    # We pass the data sample and the eps bound to the certifier along with the prediction that was made\n",
    "    # for the datapoint.\n",
    "    # A boolean is returned: True if the prediction is certified not to change under any perturbation within the bound.\n",
    "    is_certified = zonotope_model.certify(cent=sample,\n",
    "                                          eps=eps_bound,\n",
    "                                          prediction=pred)\n",
    "    \n",
    "    if pred == label:\n",
    "        num_correct += 1\n",
    "        if is_certified:\n",
    "            num_certified += 1\n",
    "    \n",
    "    print('Classified Correct {}/{} and also certified {}/{}'.format(num_correct, i+1, num_certified, i+1))\n"
   ]
  },
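  {
   "cell_type": "markdown",
   "id": "e62b9f15",
   "metadata": {},
   "source": [
    "The `pre_process` step above keeps the perturbed region inside the valid pixel range [0, 1]. We have not reproduced ART's implementation here; the following is a hypothetical sketch of one simple way to do this when `eps` is diagonal (as `np.eye(784) * bound` is): per pixel, take the interval [x - eps, x + eps], clip it to [0, 1], and re-express the result as a new center and a new diagonal radius matrix. The helper name `clip_box_to_range` and the three example pixel values are our own assumptions, and the sketch works on a flat vector rather than the (1, 28, 28) sample used above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d0c5a93",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def clip_box_to_range(cent, eps, lo=0.0, hi=1.0):\n",
    "    # Hypothetical helper, not ART's pre_process.\n",
    "    # cent: flat array of d pixel values; eps: (d, d) diagonal matrix of per-pixel radii.\n",
    "    radius = np.diag(eps)\n",
    "    lower = np.clip(cent - radius, lo, hi)\n",
    "    upper = np.clip(cent + radius, lo, hi)\n",
    "    # The clipped interval [lower, upper] becomes a new center and a new (possibly smaller) radius.\n",
    "    new_cent = (lower + upper) / 2.0\n",
    "    new_eps = np.diag((upper - lower) / 2.0)\n",
    "    return new_cent, new_eps\n",
    "\n",
    "# Three hypothetical pixels: one at 0, one in the middle and one close to 1.\n",
    "cent = np.array([0.0, 0.5, 0.98])\n",
    "eps = np.eye(3) * 0.05\n",
    "new_cent, new_eps = clip_box_to_range(cent, eps)\n",
    "print(new_cent, np.diag(new_eps))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a58f3e2c",
   "metadata": {},
   "source": [
    "Only the pixels near the boundary change: the pixel at 0 gets center 0.025 and radius 0.025, and the pixel at 0.98 gets center 0.965 and radius 0.035, while the middle pixel keeps its center of 0.5 and radius of 0.05."
   ]
  },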
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b858fc2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc:  92.0\n"
     ]
    }
   ],
   "source": [
    "# We can then compare this to the empirical robustness against a PGD attack with the same L-infinity bound.\n",
    "\n",
    "from art.estimators.classification import PyTorchClassifier\n",
    "from art.attacks.evasion import ProjectedGradientDescent\n",
    "\n",
    "classifier = PyTorchClassifier(\n",
    "    model=model,\n",
    "    clip_values=(0.0, 1.0),\n",
    "    loss=criterion,\n",
    "    optimizer=opt,\n",
    "    input_shape=(1, 28, 28),\n",
    "    nb_classes=10,\n",
    ")\n",
    "\n",
    "attack = ProjectedGradientDescent(classifier, norm=np.inf, eps=0.05, eps_step=0.01, verbose=False)\n",
    "x_test_adv = attack.generate(x_test[:50].astype('float32'))\n",
    "y_adv_pred = classifier.predict(x_test_adv)\n",
    "y_adv_pred = np.argmax(y_adv_pred, axis=1)\n",
    "print('Test acc: ', np.mean(y_adv_pred == y_test[:50]) * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f60d178",
   "metadata": {},
   "source": [
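    "We can see that the empirical accuracy under the PGD attack (92%) is much higher than the certified accuracy (31/50, i.e. 62%). This is because certification provides a lower bound on the robust performance: there may well be datapoints that the certifier cannot certify, but which in fact cannot have their class changed. Conversely, an attack such as PGD gives an upper bound, since failing to find an adversarial example does not prove that none exists."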
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}