{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7f6145ce-083e-4213-af7a-9f51804b2d4a",
   "metadata": {},
   "source": [
    "# Certified Training using Zonotopes with DeepZ"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6248072b",
   "metadata": {},
   "source": [
    "In this notebook we look at how different training strategies affect certified performance, and show that certified training can reach very high levels of certification.\n",
    "\n",
    "The drawback is that certified training is computationally expensive, even compared to PGD adversarial training.\n",
    "\n",
    "The zonotope abstraction used here is defined by:\n",
    "\\begin{equation}\n",
    "    \\hat{x} = \\eta_0 + \\sum_{i=1}^{N} \\eta_i \\epsilon_i\n",
    "\\end{equation}\n",
    "\n",
    "where $\\eta_0$ is the central vector, $\\epsilon_i \\in [-1, 1]$ are noise symbols, and $\\eta_i$ are coefficients representing deviations around $\\eta_0$.\n",
    "\n",
    "By pushing zonotopes through a neural network we can obtain an over-approximation of the worst-case loss and optimise the weights for certifiable robustness."
   ]
  },
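  {
   "cell_type": "markdown",
   "id": "zonotope-propagation-sketch",
   "metadata": {},
   "source": [
    "To make the definition concrete, here is a minimal NumPy sketch of how a zonotope can be pushed through one affine layer and one ReLU. It stores the central vector as `center` and the coefficient vectors as the rows of `gens`. For neurons whose bounds straddle zero, the ReLU uses the DeepZ relaxation with slope `u / (u - l)` and one fresh noise symbol. The helper names and the toy network are our own illustration, not ART's API; `PytorchDeepZ` handles this propagation for whole PyTorch models.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# A zonotope is stored as (center, gens): center has shape (d,) and gens has\n",
    "# shape (k, d), one row per noise symbol eps_i in [-1, 1].\n",
    "\n",
    "\n",
    "def bounds(center, gens):\n",
    "    # Interval concretization: each eps_i independently picks the sign that\n",
    "    # moves a coordinate furthest from the center.\n",
    "    radius = np.abs(gens).sum(axis=0)\n",
    "    return center - radius, center + radius\n",
    "\n",
    "\n",
    "def affine(center, gens, W, b):\n",
    "    # An affine layer y = W x + b maps a zonotope exactly: the center goes\n",
    "    # through W and b, the generators only through W.\n",
    "    return W @ center + b, gens @ W.T\n",
    "\n",
    "\n",
    "def relu(center, gens):\n",
    "    # DeepZ-style ReLU relaxation (illustrative, not ART's implementation).\n",
    "    lo, hi = bounds(center, gens)\n",
    "    new_center, new_gens = center.copy(), gens.copy()\n",
    "    extra = []\n",
    "    for j in range(center.shape[0]):\n",
    "        if hi[j] <= 0:\n",
    "            # Always inactive: the output is exactly 0.\n",
    "            new_center[j] = 0.0\n",
    "            new_gens[:, j] = 0.0\n",
    "        elif lo[j] < 0:\n",
    "            # Crossing neuron: y = lam * x + mu + mu * eps_new.\n",
    "            lam = hi[j] / (hi[j] - lo[j])\n",
    "            mu = -lam * lo[j] / 2\n",
    "            new_center[j] = lam * center[j] + mu\n",
    "            new_gens[:, j] = lam * gens[:, j]\n",
    "            row = np.zeros(center.shape[0])\n",
    "            row[j] = mu\n",
    "            extra.append(row)\n",
    "        # Otherwise lo >= 0: always active, the ReLU is the identity.\n",
    "    if extra:\n",
    "        new_gens = np.vstack([new_gens] + extra)\n",
    "    return new_center, new_gens\n",
    "\n",
    "\n",
    "# Example: the input box [0.4, 0.6] x [0.1, 0.3] written as a zonotope.\n",
    "c = np.array([0.5, 0.2])\n",
    "G = np.array([[0.1, 0.0], [0.0, 0.1]])\n",
    "W = np.array([[1.0, -1.0], [1.0, 1.0]])\n",
    "b = np.array([0.0, -0.6])\n",
    "lo, hi = bounds(*relu(*affine(c, G, W, b)))\n",
    "```\n",
    "\n",
    "Any point sampled from the input box and run through the concrete network lands inside `[lo, hi]`. The price of this soundness is looseness: the lower bound of the second output is `-0.075`, although a ReLU output is never negative."
   ]
  },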
  {
   "attachments": {
    "ART_Cert_Training.png": {
     "image/png": ""
    }
   },
   "cell_type": "markdown",
   "id": "4dd4686c",
   "metadata": {},
   "source": [
    "![ART_Cert_Training.png](attachment:ART_Cert_Training.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8229684f-9136-4e08-8743-a76600a3880f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "\n",
    "from torch import nn\n",
    "from sklearn.utils import shuffle\n",
    "\n",
    "from art.estimators.certification import deep_z\n",
    "from art.utils import load_mnist, preprocess, to_categorical\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "os.makedirs('notebook_models/mnist/certified/', exist_ok=True)\n",
    "os.makedirs('notebook_models/mnist/pgd/', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f92b485c-7010-4315-a5d8-c9a59b08d7c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For all of the demonstrations we will use this example MNIST model.\n",
    "\n",
    "class MNISTModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MNISTModel, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels=1,\n",
    "                               out_channels=16,\n",
    "                               kernel_size=(4, 4),\n",
    "                               dilation=(1, 1),\n",
    "                               padding=(0, 0),\n",
    "                               stride=(2, 2))\n",
    "        self.conv2 = nn.Conv2d(in_channels=16,\n",
    "                               out_channels=32,\n",
    "                               dilation=(1, 1),\n",
    "                               padding=(0, 0),\n",
    "                               kernel_size=(4, 4),\n",
    "                               stride=(2, 2))\n",
    "        self.fc1 = nn.Linear(in_features=800,\n",
    "                             out_features=1000)\n",
    "        self.fc2 = nn.Linear(in_features=1000,\n",
    "                             out_features=10)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "        if isinstance(x, np.ndarray):\n",
    "            x = torch.from_numpy(x).float().to(device)\n",
    "        x = self.relu(self.conv1(x))\n",
    "        x = self.relu(self.conv2(x))\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c743dd9b-7800-4777-9bdd-f0298db2942f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MNISTModel()\n",
    "model = model.to(device)\n",
    "opt = optim.Adam(model.parameters(), lr=1e-4)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "(x_train, y_train), (x_test, y_test), min_, max_ = load_mnist()\n",
    "\n",
    "x_test = np.squeeze(x_test)\n",
    "x_test = np.expand_dims(x_test, axis=1)\n",
    "y_test = np.argmax(y_test, axis=1)\n",
    "\n",
    "x_train = np.squeeze(x_train)\n",
    "x_train = np.expand_dims(x_train, axis=1)\n",
    "y_train = np.argmax(y_train, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ab5146d2-8b04-4968-9d9f-409f2bf3cba9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "End of epoch 0 loss 0.4505019783973694\n",
      "End of epoch 1 loss 0.14715342223644257\n",
      "End of epoch 2 loss 0.08925174176692963\n",
      "End of epoch 3 loss 0.06519503891468048\n",
      "End of epoch 4 loss 0.052043616771698\n"
     ]
    }
   ],
   "source": [
    "# train the model normally\n",
    "\n",
    "def standard_train(model, opt, criterion, x, y, bsize=32, epochs=5):\n",
    "    num_of_batches = int(len(x) / bsize)\n",
    "    for epoch in range(epochs):\n",
    "        x, y = shuffle(x, y)\n",
    "        loss_list = []\n",
    "        for bnum in range(num_of_batches):\n",
    "            x_batch = np.copy(x[bnum * bsize:(bnum + 1) * bsize])\n",
    "            y_batch = np.copy(y[bnum * bsize:(bnum + 1) * bsize])\n",
    "\n",
    "            x_batch = torch.from_numpy(x_batch).float().to(device)\n",
    "            y_batch = torch.from_numpy(y_batch).type(torch.LongTensor).to(device)\n",
    "\n",
    "            # zero the parameter gradients\n",
    "            opt.zero_grad()\n",
    "            outputs = model(x_batch)\n",
    "            loss = criterion(outputs, y_batch)\n",
    "            loss_list.append(loss.data.cpu().detach().numpy())\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "        print('End of epoch {} loss {}'.format(epoch, np.mean(loss_list)))\n",
    "    return model\n",
    "\n",
    "model = standard_train(model=model,\n",
    "                       opt=opt,\n",
    "                       criterion=criterion,\n",
    "                       x=x_train,\n",
    "                       y=y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0ac5fd81-fb23-4d00-a550-caa38fd1ebbb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc:  98.36\n"
     ]
    }
   ],
   "source": [
    "# Let's now get the predictions for the MNIST test set and see how well our model is doing.\n",
    "with torch.no_grad():\n",
    "    test_preds = model(torch.from_numpy(x_test).float().to(device))\n",
    "\n",
    "test_preds = np.argmax(test_preds.cpu().detach().numpy(), axis=1)\n",
    "print('Test acc: ', np.mean(test_preds == y_test) * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2d4a4cc5-5ca0-41eb-982a-3e9feb055661",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "registered <class 'torch.nn.modules.conv.Conv2d'>\n",
      "registered <class 'torch.nn.modules.activation.ReLU'>\n",
      "registered <class 'torch.nn.modules.conv.Conv2d'>\n",
      "registered <class 'torch.nn.modules.activation.ReLU'>\n",
      "registered <class 'torch.nn.modules.linear.Linear'>\n",
      "registered <class 'torch.nn.modules.activation.ReLU'>\n",
      "registered <class 'torch.nn.modules.linear.Linear'>\n",
      "Inferred reshape on op num 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/giulio/Documents/Projects/certified_training_art/adversarial-robustness-toolbox/art/estimators/certification/deep_z/pytorch.py:239: UserWarning: \n",
      "This estimator does not support networks which have dense layers before convolutional. We currently infer a reshape when a neural network goes from convolutional layers to dense layers. If your use case does not fall into this pattern then consider directly building a certifier network with the custom layers found in art.estimators.certification.deepz.deep_z.py\n",
      "\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# But how robust are these predictions?\n",
    "# We can now examine this neural network's certified robustness by\n",
    "# passing it into PytorchDeepZ. A printout will show which neural\n",
    "# network layers have been registered. There will also be a warning\n",
    "# that PytorchDeepZ currently infers a reshape when a neural network\n",
    "# goes from convolutional to dense layers. This covers the majority\n",
    "# of use cases; if yours is not covered, the certification layers in\n",
    "# art.estimators.certification.deep_z.deep_z can be used to directly\n",
    "# build a certified model structure.\n",
    "\n",
    "zonotope_model = deep_z.PytorchDeepZ(model=model,\n",
    "                                     clip_values=(0, 1),\n",
    "                                     optimizer=optim.Adam(model.parameters(), lr=1e-4),\n",
    "                                     loss=nn.CrossEntropyLoss(),\n",
    "                                     input_shape=(1, 28, 28),\n",
    "                                     nb_classes=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8b7b0b1a-76e8-40ef-b919-7a686eef8ef5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's now see how robust our model is!\n",
    "# First we need to define the bound to check.\n",
    "# Here we check for L-infinity robustness with a small bound of 0.15.\n",
    "\n",
    "# We now loop over the data to check its certified robustness.\n",
    "# Samples must be processed one at a time: due to the memory and compute footprint, batching is not supported.\n",
    "# In this demo we look at the first 50 samples of the MNIST test data.\n",
    "\n",
    "original_x = np.copy(x_test)\n",
    "def certification_loop(model, x, y, preds, bound):\n",
    "    num_certified = 0\n",
    "    num_correct = 0\n",
    "    for i, (sample, pred, label) in enumerate(zip(x[:50], preds[:50], y[:50])):\n",
    "\n",
    "        # we make the matrix representing the allowable perturbations. \n",
    "        # we have 28*28 features and each one can be manipulated independently requiring a different row.\n",
    "        # hence a 784*784 matrix.\n",
    "        eps_bound = np.eye(784) * bound\n",
    "\n",
    "        # we then need to adjust the raw data with the eps bounds to take into account\n",
    "        # the allowable range of 0 - 1 for pixel data.\n",
    "        # We provide a simple function to do this preprocessing for image data.\n",
    "        # However if your use case is not supported then a custom pre-processor function will need to be written.\n",
    "        sample, eps_bound = model.pre_process(cent=sample, \n",
    "                                              eps=eps_bound)\n",
    "        sample = np.expand_dims(sample, axis=0)\n",
    "\n",
    "        # We pass the data sample and the eps bound to the certifier, along with the prediction that was made\n",
    "        # for the datapoint.\n",
    "        # A boolean is returned signifying whether the prediction is guaranteed not to change under the given bound.\n",
    "        is_certified = model.certify(cent=sample,\n",
    "                                     eps=eps_bound,\n",
    "                                     prediction=pred)\n",
    "\n",
    "        if pred == label:\n",
    "            num_correct +=1\n",
    "            if is_certified:\n",
    "                num_certified +=1 \n",
    "\n",
    "        print('Classified Correct {}/{} and also certified {}/{}'.format(num_correct, i+1, num_certified, i+1))"
   ]
  },
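  {
   "cell_type": "markdown",
   "id": "certification-check-sketch",
   "metadata": {},
   "source": [
    "The two steps inside `certification_loop` (building a clipped input region, then checking the output) can be sketched in plain NumPy. This is our own simplified illustration, assuming that the preprocessing clips the L-infinity box to the pixel range [0, 1] and that certification compares the predicted logit against every other logit; it is not a copy of ART's `pre_process` or `certify` code. The names `clipped_input_zonotope` and `is_certified` are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def clipped_input_zonotope(x, bound, lo=0.0, hi=1.0):\n",
    "    # Hypothetical helper: intersect the box [x - bound, x + bound] with the\n",
    "    # valid pixel range, then write the result as center + diagonal generators.\n",
    "    x = np.asarray(x, dtype=float).reshape(-1)\n",
    "    lower = np.clip(x - bound, lo, hi)\n",
    "    upper = np.clip(x + bound, lo, hi)\n",
    "    center = (lower + upper) / 2\n",
    "    gens = np.diag((upper - lower) / 2)  # one noise symbol per pixel\n",
    "    return center, gens\n",
    "\n",
    "\n",
    "def is_certified(center, gens, pred):\n",
    "    # Hypothetical helper: the prediction is certified if, for every other\n",
    "    # class j, the lowest value of y[pred] - y[j] over the zonotope is > 0.\n",
    "    for j in range(center.shape[0]):\n",
    "        if j == pred:\n",
    "            continue\n",
    "        diff_center = center[pred] - center[j]\n",
    "        diff_gens = gens[:, pred] - gens[:, j]\n",
    "        if diff_center - np.abs(diff_gens).sum() <= 0:\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "\n",
    "# Output zonotope with 3 logits and 2 noise symbols. The intervals of\n",
    "# y[0] ([0.5, 2.5]) and y[1] ([0.0, 2.0]) overlap, but the shared noise\n",
    "# symbol cancels in y[0] - y[1], so class 0 is still certified.\n",
    "c_out = np.array([1.5, 1.0, 0.0])\n",
    "G_out = np.array([[1.0, 1.0, 0.0], [0.0, 0.0, 0.1]])\n",
    "print(is_certified(c_out, G_out, pred=0))  # -> True\n",
    "```\n",
    "\n",
    "Checking the differences `y[pred] - y[j]` directly, rather than comparing independent per-logit intervals, is what lets the zonotope exploit correlations between outputs: the example above would fail an interval-only comparison."
   ]
  },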
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "26d0ea51",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classified Correct 1/1 and also certified 0/1\n",
      "Classified Correct 2/2 and also certified 0/2\n",
      "Classified Correct 3/3 and also certified 0/3\n",
      "Classified Correct 4/4 and also certified 0/4\n",
      "Classified Correct 5/5 and also certified 0/5\n",
      "Classified Correct 6/6 and also certified 0/6\n",
      "Classified Correct 7/7 and also certified 0/7\n",
      "Classified Correct 8/8 and also certified 0/8\n",
      "Classified Correct 9/9 and also certified 0/9\n",
      "Classified Correct 10/10 and also certified 0/10\n",
      "Classified Correct 11/11 and also certified 0/11\n",
      "Classified Correct 12/12 and also certified 0/12\n",
      "Classified Correct 13/13 and also certified 0/13\n",
      "Classified Correct 14/14 and also certified 0/14\n",
      "Classified Correct 15/15 and also certified 0/15\n",
      "Classified Correct 16/16 and also certified 0/16\n",
      "Classified Correct 17/17 and also certified 0/17\n",
      "Classified Correct 18/18 and also certified 0/18\n",
      "Classified Correct 19/19 and also certified 0/19\n",
      "Classified Correct 20/20 and also certified 0/20\n",
      "Classified Correct 21/21 and also certified 0/21\n",
      "Classified Correct 22/22 and also certified 0/22\n",
      "Classified Correct 23/23 and also certified 0/23\n",
      "Classified Correct 24/24 and also certified 0/24\n",
      "Classified Correct 25/25 and also certified 0/25\n",
      "Classified Correct 26/26 and also certified 0/26\n",
      "Classified Correct 27/27 and also certified 0/27\n",
      "Classified Correct 28/28 and also certified 0/28\n",
      "Classified Correct 29/29 and also certified 0/29\n",
      "Classified Correct 30/30 and also certified 0/30\n",
      "Classified Correct 31/31 and also certified 0/31\n",
      "Classified Correct 32/32 and also certified 0/32\n",
      "Classified Correct 33/33 and also certified 0/33\n",
      "Classified Correct 34/34 and also certified 0/34\n",
      "Classified Correct 35/35 and also certified 0/35\n",
      "Classified Correct 36/36 and also certified 0/36\n",
      "Classified Correct 37/37 and also certified 0/37\n",
      "Classified Correct 38/38 and also certified 0/38\n",
      "Classified Correct 39/39 and also certified 0/39\n",
      "Classified Correct 40/40 and also certified 0/40\n",
      "Classified Correct 41/41 and also certified 0/41\n",
      "Classified Correct 42/42 and also certified 0/42\n",
      "Classified Correct 43/43 and also certified 0/43\n",
      "Classified Correct 44/44 and also certified 0/44\n",
      "Classified Correct 45/45 and also certified 0/45\n",
      "Classified Correct 46/46 and also certified 0/46\n",
      "Classified Correct 47/47 and also certified 0/47\n",
      "Classified Correct 48/48 and also certified 0/48\n",
      "Classified Correct 49/49 and also certified 0/49\n",
      "Classified Correct 50/50 and also certified 0/50\n"
     ]
    }
   ],
   "source": [
    "# We can toggle how zonotope_model processes data through the method set_forward_mode.\n",
    "# With 'abstract' it uses abstract operations on the input (which is expected to be a zonotope).\n",
    "# With 'concrete' the neural network is run normally and expects regular data as input.\n",
    "\n",
    "# As we want to do certification analysis now, let's set the model to run in abstract mode.\n",
    "\n",
    "zonotope_model.model.set_forward_mode('abstract')\n",
    "certification_loop(model=zonotope_model,\n",
    "                   x=np.copy(x_test),\n",
    "                   y=y_test,\n",
    "                   preds=test_preds,\n",
    "                   bound=0.15)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92f68aea",
   "metadata": {},
   "source": [
    "We can see that our NN has very low certified robustness: all of the first 50 test samples are classified correctly, but none of them is certified at a bound of 0.15. You can change the `bound` argument passed to `certification_loop` to see how the certified robustness varies. We can now try to improve this through different robust training strategies. We will look at:\n",
    "\n",
    "+ Using certified adversarial training.\n",
    "+ Using PGD adversarial training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9e25ac62-b575-468d-99d6-55d069e16731",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from art.defences.trainer import AdversarialTrainerCertifiedPytorch\n",
    "\n",
    "# We will now train the model to improve its certified accuracy.\n",
    "# Regular PGD training will boost certified performance; however, even higher certification scores can\n",
    "# be obtained by training the neural network with certified performance as the objective. Let's compare\n",
    "# the two methods.\n",
    "\n",
    "# NB! Certified adversarial training takes about 9 hours on an NVIDIA V100 with the following parameters.\n",
    "\n",
    "pgd_params = {\"eps\": 0.3,\n",
    "              \"eps_step\": 0.05,\n",
    "              \"max_iter\": 20,\n",
    "              \"num_random_init\": 1,\n",
    "              \"batch_size\": 32,}\n",
    "\n",
    "trainer = AdversarialTrainerCertifiedPytorch(zonotope_model,\n",
    "                                             pgd_params=pgd_params,\n",
    "                                             batch_size=10,\n",
    "                                             bound=0.15)\n",
    "\n",
    "# Uncomment if you wish to train your own model, but it will take several hours!\n",
    "\n",
    "#trainer.fit(x_train,\n",
    "#            y_train,\n",
    "#            nb_epochs=30)\n",
    "# torch.save(trainer._classifier.model.state_dict(), 'notebook_models/mnist/certified/model.pt')\n",
    "\n",
    "# Here we will load a model which was trained using the above function.\n",
    "trainer._classifier.model.load_state_dict(torch.load('notebook_models/mnist/certified/model.pt', map_location=torch.device(device)))"
   ]
  },
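  {
   "cell_type": "markdown",
   "id": "certified-loss-sketch",
   "metadata": {},
   "source": [
    "Training with certified performance as the objective means minimising a loss on an over-approximation of the network's outputs rather than on a single point. One common formulation, used in interval-bound style certified training, evaluates the cross-entropy on worst-case logits derived from the output bounds. The NumPy sketch below shows that idea for a single output zonotope. It is an assumption for illustration: we have not checked it against the exact loss inside `AdversarialTrainerCertifiedPytorch`, which is also configured with PGD parameters. The function names are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def worst_case_logits(lower, upper, label):\n",
    "    # Lowest possible logit for the true class, highest possible logit for\n",
    "    # every other class.\n",
    "    z = upper.copy()\n",
    "    z[label] = lower[label]\n",
    "    return z\n",
    "\n",
    "\n",
    "def cross_entropy(logits, label):\n",
    "    # Numerically stable softmax cross-entropy for a single sample.\n",
    "    shifted = logits - logits.max()\n",
    "    return np.log(np.exp(shifted).sum()) - shifted[label]\n",
    "\n",
    "\n",
    "def certified_loss(center, gens, label):\n",
    "    # Hypothetical certified objective: concretize the output zonotope to\n",
    "    # per-logit bounds and evaluate the cross-entropy of the worst case.\n",
    "    # No concrete output inside the bounds has a larger cross-entropy.\n",
    "    radius = np.abs(gens).sum(axis=0)\n",
    "    z = worst_case_logits(center - radius, center + radius, label)\n",
    "    return cross_entropy(z, label)\n",
    "\n",
    "\n",
    "c_out = np.array([2.0, 0.0, -1.0])\n",
    "G_out = np.array([[0.3, 0.2, 0.1]])\n",
    "print(cross_entropy(c_out, 0), certified_loss(c_out, G_out, 0))\n",
    "```\n",
    "\n",
    "With zero generators, `certified_loss` reduces to the ordinary cross-entropy; any non-zero width in the output bounds can only increase it, so minimising it pushes the lower bound of the true logit above the upper bounds of the others."
   ]
  },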
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e7855a8c-0ab0-4ee6-a7b9-380d6c2360cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test acc:  98.7\n"
     ]
    }
   ],
   "source": [
    "# As before, we obtain the model's test-time accuracy.\n",
    "# Make sure to set the model to concrete forward mode.\n",
    "\n",
    "with torch.no_grad():\n",
    "    trainer._classifier.model.set_forward_mode('concrete')\n",
    "    test_preds = trainer._classifier.model(torch.from_numpy(x_test).float().to(device))\n",
    "\n",
    "test_preds = np.argmax(test_preds.cpu().detach().numpy(), axis=1)\n",
    "print('Test acc: ', np.mean(test_preds == y_test) * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "65c3e186",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classified Correct 1/1 and also certified 1/1\n",
      "Classified Correct 2/2 and also certified 2/2\n",
      "Classified Correct 3/3 and also certified 3/3\n",
      "Classified Correct 4/4 and also certified 4/4\n",
      "Classified Correct 5/5 and also certified 5/5\n",
      "Classified Correct 6/6 and also certified 6/6\n",
      "Classified Correct 7/7 and also certified 7/7\n",
      "Classified Correct 8/8 and also certified 8/8\n",
      "Classified Correct 9/9 and also certified 8/9\n",
      "Classified Correct 10/10 and also certified 9/10\n",
      "Classified Correct 11/11 and also certified 10/11\n",
      "Classified Correct 12/12 and also certified 11/12\n",
      "Classified Correct 13/13 and also certified 12/13\n",
      "Classified Correct 14/14 and also certified 13/14\n",
      "Classified Correct 15/15 and also certified 14/15\n",
      "Classified Correct 16/16 and also certified 15/16\n",
      "Classified Correct 17/17 and also certified 16/17\n",
      "Classified Correct 18/18 and also certified 17/18\n",
      "Classified Correct 19/19 and also certified 17/19\n",
      "Classified Correct 20/20 and also certified 18/20\n",
      "Classified Correct 21/21 and also certified 19/21\n",
      "Classified Correct 22/22 and also certified 20/22\n",
      "Classified Correct 23/23 and also certified 21/23\n",
      "Classified Correct 24/24 and also certified 22/24\n",
      "Classified Correct 25/25 and also certified 22/25\n",
      "Classified Correct 26/26 and also certified 23/26\n",
      "Classified Correct 27/27 and also certified 24/27\n",
      "Classified Correct 28/28 and also certified 25/28\n",
      "Classified Correct 29/29 and also certified 26/29\n",
      "Classified Correct 30/30 and also certified 27/30\n",
      "Classified Correct 31/31 and also certified 28/31\n",
      "Classified Correct 32/32 and also certified 29/32\n",
      "Classified Correct 33/33 and also certified 30/33\n",
      "Classified Correct 34/34 and also certified 30/34\n",
      "Classified Correct 35/35 and also certified 31/35\n",
      "Classified Correct 36/36 and also certified 32/36\n",
      "Classified Correct 37/37 and also certified 33/37\n",
      "Classified Correct 38/38 and also certified 34/38\n",
      "Classified Correct 39/39 and also certified 35/39\n",
      "Classified Correct 40/40 and also certified 36/40\n",
      "Classified Correct 41/41 and also certified 37/41\n",
      "Classified Correct 42/42 and also certified 38/42\n",
      "Classified Correct 43/43 and also certified 39/43\n",
      "Classified Correct 44/44 and also certified 40/44\n",
      "Classified Correct 45/45 and also certified 41/45\n",
      "Classified Correct 46/46 and also certified 42/46\n",
      "Classified Correct 47/47 and also certified 43/47\n",
      "Classified Correct 48/48 and also certified 44/48\n",
      "Classified Correct 49/49 and also certified 45/49\n",
      "Classified Correct 50/50 and also certified 46/50\n"
     ]
    }
   ],
   "source": [
    "# Then let's compare the certified robustness of the certifiably trained model against the\n",
    "# regularly trained model. We will see that the model now has much higher certified robustness.\n",
    "\n",
    "trainer._classifier.model.set_forward_mode('abstract')\n",
    "certification_loop(model=trainer._classifier,\n",
    "                   x=np.copy(x_test),\n",
    "                   y=y_test,\n",
    "                   preds=test_preds,\n",
    "                   bound=0.15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b48fd5da",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For the final comparison, we look at how PGD training compares.\n",
    "# Let's make a new model and train it using ART's AdversarialTrainer.\n",
    "\n",
    "from art.attacks.evasion.projected_gradient_descent.projected_gradient_descent import ProjectedGradientDescent\n",
    "from art.estimators.classification import PyTorchClassifier\n",
    "from art.defences.trainer import AdversarialTrainer\n",
    "\n",
    "pgd_params = {\"eps\": 0.3,\n",
    "              \"eps_step\": 0.01,\n",
    "              \"max_iter\": 30,\n",
    "              \"batch_size\": 32,\n",
    "              \"num_random_init\": 1}\n",
    "\n",
    "model = MNISTModel().to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "classifier = PyTorchClassifier(\n",
    "    model=model,\n",
    "    clip_values=(0, 1),\n",
    "    loss=criterion,\n",
    "    optimizer=optimizer,\n",
    "    input_shape=(1, 28, 28),\n",
    "    nb_classes=10,\n",
    ")\n",
    "\n",
    "attack = ProjectedGradientDescent(\n",
    "    estimator=classifier,\n",
    "    eps=pgd_params[\"eps\"],\n",
    "    eps_step=pgd_params[\"eps_step\"],\n",
    "    max_iter=pgd_params[\"max_iter\"],\n",
    "    num_random_init=pgd_params[\"num_random_init\"],\n",
    ")\n",
    "\n",
    "trainer = AdversarialTrainer(classifier, attack, ratio=1.0)\n",
    "\n",
    "# Uncomment if you wish to train your own PGD model, but it may take 30 min to ~2 hours depending\n",
    "# on your hardware.\n",
    "\n",
    "# trainer.fit(x_train, y_train, nb_epochs=20, batch_size=32)\n",
    "# torch.save(trainer._classifier.model.state_dict(), 'notebook_models/mnist/pgd/pgd_model.pt')\n",
    "\n",
    "# Here we will load a model which was trained using the above function.\n",
    "trainer._classifier.model.load_state_dict(torch.load('notebook_models/mnist/pgd/pgd_model.pt', map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fddea1c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "registered <class 'torch.nn.modules.conv.Conv2d'>\n",
      "registered <class 'torch.nn.modules.activation.ReLU'>\n",
      "registered <class 'torch.nn.modules.conv.Conv2d'>\n",
      "registered <class 'torch.nn.modules.activation.ReLU'>\n",
      "registered <class 'torch.nn.modules.linear.Linear'>\n",
      "registered <class 'torch.nn.modules.activation.ReLU'>\n",
      "registered <class 'torch.nn.modules.linear.Linear'>\n",
      "Inferred reshape on op num 4\n",
      "Test acc:  98.67\n"
     ]
    }
   ],
   "source": [
    "# Let's get the model we trained with PGD and pass it into PytorchDeepZ so we can use \n",
    "# the certification methods on the model\n",
    "\n",
    "zonotope_model = deep_z.PytorchDeepZ(model=trainer._classifier.model,\n",
    "                                     clip_values=(0, 1),\n",
    "                                     optimizer=optim.Adam(model.parameters(), lr=1e-4),\n",
    "                                     loss=nn.CrossEntropyLoss(),\n",
    "                                     input_shape=(1, 28, 28),\n",
    "                                     nb_classes=10)\n",
    "\n",
    "# Get the test time predictions. As always, set the forward_mode to the correct type\n",
    "with torch.no_grad():\n",
    "    zonotope_model.model.set_forward_mode('concrete')\n",
    "    test_preds = zonotope_model.model.forward(torch.from_numpy(x_test).float().to(device))\n",
    "\n",
    "test_preds = np.argmax(test_preds.cpu().detach().numpy(), axis=1)\n",
    "print('Test acc: ', np.mean(test_preds == y_test) * 100)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f1edd03d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classified Correct 1/1 and also certified 1/1\n",
      "Classified Correct 2/2 and also certified 2/2\n",
      "Classified Correct 3/3 and also certified 3/3\n",
      "Classified Correct 4/4 and also certified 3/4\n",
      "Classified Correct 5/5 and also certified 3/5\n",
      "Classified Correct 6/6 and also certified 4/6\n",
      "Classified Correct 7/7 and also certified 4/7\n",
      "Classified Correct 8/8 and also certified 4/8\n",
      "Classified Correct 9/9 and also certified 4/9\n",
      "Classified Correct 10/10 and also certified 4/10\n",
      "Classified Correct 11/11 and also certified 5/11\n",
      "Classified Correct 12/12 and also certified 5/12\n",
      "Classified Correct 13/13 and also certified 5/13\n",
      "Classified Correct 14/14 and also certified 6/14\n",
      "Classified Correct 15/15 and also certified 7/15\n",
      "Classified Correct 16/16 and also certified 7/16\n",
      "Classified Correct 17/17 and also certified 7/17\n",
      "Classified Correct 18/18 and also certified 8/18\n",
      "Classified Correct 19/19 and also certified 8/19\n",
      "Classified Correct 20/20 and also certified 9/20\n",
      "Classified Correct 21/21 and also certified 9/21\n",
      "Classified Correct 22/22 and also certified 9/22\n",
      "Classified Correct 23/23 and also certified 9/23\n",
      "Classified Correct 24/24 and also certified 10/24\n",
      "Classified Correct 25/25 and also certified 10/25\n",
      "Classified Correct 26/26 and also certified 10/26\n",
      "Classified Correct 27/27 and also certified 11/27\n",
      "Classified Correct 28/28 and also certified 12/28\n",
      "Classified Correct 29/29 and also certified 13/29\n",
      "Classified Correct 30/30 and also certified 14/30\n",
      "Classified Correct 31/31 and also certified 15/31\n",
      "Classified Correct 32/32 and also certified 16/32\n",
      "Classified Correct 33/33 and also certified 17/33\n",
      "Classified Correct 34/34 and also certified 17/34\n",
      "Classified Correct 35/35 and also certified 18/35\n",
      "Classified Correct 36/36 and also certified 18/36\n",
      "Classified Correct 37/37 and also certified 18/37\n",
      "Classified Correct 38/38 and also certified 19/38\n",
      "Classified Correct 39/39 and also certified 19/39\n",
      "Classified Correct 40/40 and also certified 20/40\n",
      "Classified Correct 41/41 and also certified 21/41\n",
      "Classified Correct 42/42 and also certified 21/42\n",
      "Classified Correct 43/43 and also certified 22/43\n",
      "Classified Correct 44/44 and also certified 22/44\n",
      "Classified Correct 45/45 and also certified 23/45\n",
      "Classified Correct 46/46 and also certified 23/46\n",
      "Classified Correct 47/47 and also certified 24/47\n",
      "Classified Correct 48/48 and also certified 24/48\n",
      "Classified Correct 49/49 and also certified 25/49\n",
      "Classified Correct 50/50 and also certified 26/50\n"
     ]
    }
   ],
   "source": [
    "# And finally, let's see the PGD model's certified robustness.\n",
    "\n",
    "# We set the appropriate forward mode\n",
    "zonotope_model.model.set_forward_mode('abstract')\n",
    "\n",
    "# And then run the certification. Although PGD training boosts certified robustness\n",
    "# compared to normal training, its performance is much weaker than that of certified training.\n",
    "certification_loop(model=zonotope_model,\n",
    "                   x=np.copy(x_test),\n",
    "                   y=y_test,\n",
    "                   preds=test_preds,\n",
    "                   bound=0.15)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}