{ "cells": [ { "cell_type": "markdown", "id": "58063edd", "metadata": {}, "source": [ "# Certification of Vision Transformers" ] }, { "cell_type": "markdown", "id": "0438abb9", "metadata": {}, "source": [ "In this notebook we will go over how to use the PyTorchSmoothedViT tool and be able to certify vision transformers against patch attacks!\n", "\n", "### Overview\n", "\n", "This method was introduced in Certified Patch Robustness via Smoothed Vision Transformers (https://arxiv.org/abs/2110.07719). The core technique is one of *image ablations*, where the image is blanked out except for certain regions. By ablating the input in different ways every time we can obtain many predicitons for a single input. Now, as we are ablating large parts of the image the attacker's patch attack is also getting removed in many predictions. Based on factors like the size of the adversarial patch and the size of the retained part of the image the attacker will only be able to influence a limited number of predictions. In fact, if the attacker has a $m x m$ patch attack and the retained part of the image is a column of width $s$ then the maximum number of predictions $\\Delta$ that could be affected are: \n", "\n", "
$\\Delta = m + s - 1$
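\n", "\n", "For example, the patch overlaps an $s$-wide retained column for at most $m + s - 1$ of the possible column positions, so a $5 \times 5$ patch ($m = 5$) with an ablation size of $s = 4$ can affect at most $\Delta = 5 + 4 - 1 = 8$ predictions.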
\n", "\n", "Based on this relationship we can derive a simple but effective criterion that if we are making many predictions for an image and the highest predicted class $c_t$ has been predicted $k_t$ times and the second most predicted class $c_{t-1}$ has been predicted $k_{t-1}$ times then we have a certified prediction for $c_t$ if: \n", "\n", "\n", "$k_t - k_{t-1} > 2\\Delta$
\n", "\n", "Intuitivly we are saying that even if $k$ predictions were adversarially influenced and those predictions were to change, then the model will *still* have predicted class $c_t$.\n", "\n", "### What's special about Vision Transformers?\n", "\n", "The formulation above is very generic and it can be applied to any nerual network model, in fact the original paper which proposed it (https://arxiv.org/abs/2110.07719) considered the case with convolutional nerual networks. \n", "\n", "However, Vision Transformers (ViTs) are well siuted to this task of predicting with vision ablations for two key reasons: \n", "\n", "+ ViTs first tokenize the input into a series of image regions which get embedded and then processed through the neural network. Thus, by considering the input as a set of tokens we can drop tokens which correspond to fully masked (i.e ablated)regions significantly saving on the compute costs. \n", "\n", "+ Secondly, the ViT's self attention layer enables sharing of information globally at every layer. In contrast convolutional neural networks build up the receptive field over a series of layers. Hence, ViTs can be more effective at classifying an image based on its small unablated regions.\n", "\n", "Let's see how to use these tools!" ] }, { "cell_type": "code", "execution_count": 1, "id": "aeb27667", "metadata": {}, "outputs": [], "source": [ "import sys\n", "import numpy as np\n", "import torch\n", "\n", "sys.path.append(\"..\")\n", "from torchvision import datasets\n", "from matplotlib import pyplot as plt\n", "\n", "# The core tool is PyTorchSmoothedViT which can be imported as follows:\n", "from art.estimators.certification.derandomized_smoothing import PyTorchDeRandomizedSmoothing\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": 2, "id": "80541a3a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Files already downloaded and verified\n" ] } ], "source": [ "# Function to fetch the cifar-10 data\n", "def get_cifar_data():\n", " \"\"\"\n", " Get CIFAR-10 data.\n", " :return: cifar train/test data.\n", " \"\"\"\n", " train_set = datasets.CIFAR10('./data', train=True, download=True)\n", " test_set = datasets.CIFAR10('./data', train=False, download=True)\n", "\n", " x_train = train_set.data.astype(np.float32)\n", " y_train = np.asarray(train_set.targets)\n", "\n", " x_test = test_set.data.astype(np.float32)\n", " y_test = np.asarray(test_set.targets)\n", "\n", " x_train = np.moveaxis(x_train, [3], [1])\n", " x_test = np.moveaxis(x_test, [3], [1])\n", "\n", " x_train = x_train / 255.0\n", " x_test = x_test / 255.0\n", "\n", " return (x_train, y_train), (x_test, y_test)\n", "\n", "\n", "(x_train, y_train), (x_test, y_test) = get_cifar_data()" ] }, { "cell_type": "code", "execution_count": 3, "id": "2ac0c5b3", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "['vit_base_patch8_224',\n", " 'vit_base_patch16_18x2_224',\n", " 'vit_base_patch16_224',\n", " 'vit_base_patch16_224_miil',\n", " 'vit_base_patch16_384',\n", " 'vit_base_patch16_clip_224',\n", " 'vit_base_patch16_clip_384',\n", " 'vit_base_patch16_gap_224',\n", " 'vit_base_patch16_plus_240',\n", " 'vit_base_patch16_rpn_224',\n", " 'vit_base_patch16_xp_224',\n", " 'vit_base_patch32_224',\n", " 'vit_base_patch32_384',\n", " 'vit_base_patch32_clip_224',\n", " 'vit_base_patch32_clip_384',\n", " 'vit_base_patch32_clip_448',\n", " 
'vit_base_patch32_plus_256',\n", " 'vit_giant_patch14_224',\n", " 'vit_giant_patch14_clip_224',\n", " 'vit_gigantic_patch14_224',\n", " 'vit_gigantic_patch14_clip_224',\n", " 'vit_huge_patch14_224',\n", " 'vit_huge_patch14_clip_224',\n", " 'vit_huge_patch14_clip_336',\n", " 'vit_huge_patch14_xp_224',\n", " 'vit_large_patch14_224',\n", " 'vit_large_patch14_clip_224',\n", " 'vit_large_patch14_clip_336',\n", " 'vit_large_patch14_xp_224',\n", " 'vit_large_patch16_224',\n", " 'vit_large_patch16_384',\n", " 'vit_large_patch32_224',\n", " 'vit_large_patch32_384',\n", " 'vit_medium_patch16_gap_240',\n", " 'vit_medium_patch16_gap_256',\n", " 'vit_medium_patch16_gap_384',\n", " 'vit_small_patch16_18x2_224',\n", " 'vit_small_patch16_36x1_224',\n", " 'vit_small_patch16_224',\n", " 'vit_small_patch16_384',\n", " 'vit_small_patch32_224',\n", " 'vit_small_patch32_384',\n", " 'vit_tiny_patch16_224',\n", " 'vit_tiny_patch16_384']" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# There are a few ways we can interface with PyTorchSmoothedViT. \n", "# The most direct way to get set up is by specifying the name of a supported transformer.\n", "# Behind the scenes we are using the timm library (link: https://github.com/huggingface/pytorch-image-models).\n", "\n", "\n", "# We currently support ViTs generated via: \n", "# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py\n", "# Support for other architectures can be added. Consider raising a feature request or pull request to have \n", "# additional models supported.\n", "\n", "# We can see all the supported models by using the .get_models() method:\n", "PyTorchDeRandomizedSmoothing.get_models()" ] }, { "cell_type": "code", "execution_count": 4, "id": "e8bac618", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:Running algorithm: salman2021\n", "INFO:root:Converting Adam Optimiser\n", "WARNING:art.estimators.certification.derandomized_smoothing.pytorch: ViT expects input shape of: (3, 224, 224) but (3, 32, 32) specified as the input shape. 
The input will be rescaled to (3, 224, 224)\n", "INFO:art.estimators.classification.pytorch:Inferred 9 hidden layers on PyTorch classifier.\n", "INFO:art.estimators.certification.derandomized_smoothing.pytorch:PyTorchViT(\n", " (patch_embed): PatchEmbed(\n", " (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))\n", " (norm): Identity()\n", " )\n", " (pos_drop): Dropout(p=0.0, inplace=False)\n", " (patch_drop): Identity()\n", " (norm_pre): Identity()\n", " (blocks): Sequential(\n", " (0): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (1): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (2): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (3): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, 
bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (4): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (5): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (6): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (7): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, 
inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (8): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (9): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (10): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (11): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): 
Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " )\n", " (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (fc_norm): Identity()\n", " (head_drop): Dropout(p=0.0, inplace=False)\n", " (head): Linear(in_features=384, out_features=10, bias=True)\n", ")\n" ] } ], "source": [ "import timm\n", "\n", "# We can setup the PyTorchSmoothedViT if we start with a ViT model directly.\n", "\n", "vit_model = timm.create_model('vit_small_patch16_224')\n", "optimizer = torch.optim.Adam(vit_model.parameters(), lr=1e-4)\n", "\n", "art_model = PyTorchDeRandomizedSmoothing(model=vit_model, # Name of the model acitecture to load\n", " loss=torch.nn.CrossEntropyLoss(), # loss function to use\n", " optimizer=optimizer, # the optimizer to use: note! this is not initialised here we just supply the class!\n", " input_shape=(3, 32, 32), # the input shape of the data: Note! that if this is a different shape to what the ViT expects it will be re-scaled\n", " nb_classes=10,\n", " ablation_size=4, # Size of the retained column\n", " replace_last_layer=True, # Replace the last layer with a new set of weights to fine tune on new data\n", " load_pretrained=True) # if to load pre-trained weights for the ViT" ] }, { "cell_type": "code", "execution_count": 5, "id": "353ef5a6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:Running algorithm: salman2021\n", "INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/vit_small_patch16_224.augreg_in21k_ft_in1k)\n", "INFO:timm.models._hub:[timm/vit_small_patch16_224.augreg_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.\n", "WARNING:art.estimators.certification.derandomized_smoothing.pytorch: ViT expects input shape of: (3, 224, 224) but (3, 32, 32) specified as the input shape. 
The input will be rescaled to (3, 224, 224)\n", "INFO:art.estimators.classification.pytorch:Inferred 9 hidden layers on PyTorch classifier.\n", "INFO:art.estimators.certification.derandomized_smoothing.pytorch:PyTorchViT(\n", " (patch_embed): PatchEmbed(\n", " (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))\n", " (norm): Identity()\n", " )\n", " (pos_drop): Dropout(p=0.0, inplace=False)\n", " (patch_drop): Identity()\n", " (norm_pre): Identity()\n", " (blocks): Sequential(\n", " (0): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (1): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (2): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (3): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, 
bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (4): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (5): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (6): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (7): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, 
inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (8): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (9): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (10): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " (11): Block(\n", " (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (attn): Attention(\n", " (qkv): Linear(in_features=384, out_features=1152, bias=True)\n", " (q_norm): 
Identity()\n", " (k_norm): Identity()\n", " (attn_drop): Dropout(p=0.0, inplace=False)\n", " (proj): Linear(in_features=384, out_features=384, bias=True)\n", " (proj_drop): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls1): Identity()\n", " (drop_path1): Identity()\n", " (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (mlp): Mlp(\n", " (fc1): Linear(in_features=384, out_features=1536, bias=True)\n", " (act): GELU(approximate='none')\n", " (drop1): Dropout(p=0.0, inplace=False)\n", " (norm): Identity()\n", " (fc2): Linear(in_features=1536, out_features=384, bias=True)\n", " (drop2): Dropout(p=0.0, inplace=False)\n", " )\n", " (ls2): Identity()\n", " (drop_path2): Identity()\n", " )\n", " )\n", " (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n", " (fc_norm): Identity()\n", " (head_drop): Dropout(p=0.0, inplace=False)\n", " (head): Linear(in_features=384, out_features=10, bias=True)\n", ")\n" ] } ], "source": [ "# Or we can just feed in the model name and ART will internally create the ViT.\n", "\n", "art_model = PyTorchDeRandomizedSmoothing(model='vit_small_patch16_224', # Name of the model acitecture to load\n", " loss=torch.nn.CrossEntropyLoss(), # loss function to use\n", " optimizer=torch.optim.SGD, # the optimizer to use: note! this is not initialised here we just supply the class!\n", " optimizer_params={\"lr\": 0.01}, # the parameters to use\n", " input_shape=(3, 32, 32), # the input shape of the data: Note! that if this is a different shape to what the ViT expects it will be re-scaled\n", " nb_classes=10,\n", " ablation_size=4, # Size of the retained column\n", " replace_last_layer=True, # Replace the last layer with a new set of weights to fine tune on new data\n", " load_pretrained=True) # if to load pre-trained weights for the ViT" ] }, { "cell_type": "markdown", "id": "c7a4255f", "metadata": {}, "source": [ "Creating a PyTorchSmoothedViT instance with the above code follows many of the general ART patterns with two caveats: \n", "+ The optimizer would (normally) be supplied initialised into the estimator along with a pytorch model. However, here we have not yet created the model, we are just supplying the model architecture name. Hence, here we pass the class into PyTorchDeRandomizedSmoothing with the keyword arguments in optimizer_params which you would normally use to initialise it.\n", "+ The input shape will primiarily determine if the input requires upsampling. The ViT model such as the one loaded is for images of 224 x 224 resolution, thus in our case of using CIFAR data, we will be upsampling it." ] }, { "cell_type": "code", "execution_count": 6, "id": "44975815", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The shape of the ablated image is (10, 4, 224, 224)\n" ] }, { "data": { "text/plain": [ "