{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Clean-Label Feature Collision Attacks on a PyTorch Classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we will learn how to use ART to run a clean-label feature collision poisoning attack on a neural network trained with PyTorch. We will be training our data on a subset of the CIFAR-10 dataset. The methods described are derived from [this paper](https://arxiv.org/abs/1804.00792) by Shafahi, Huang, et. al. 2018." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os, sys\n", "from os.path import abspath\n", "\n", "module_path = os.path.abspath(os.path.join('..'))\n", "if module_path not in sys.path:\n", " sys.path.append(module_path)\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "\n", "import torch\n", "import torch.nn as nn\n", "\n", "from art import config\n", "from art.utils import load_dataset, get_file\n", "from art.estimators.classification import PyTorchClassifier\n", "from art.attacks.poisoning import FeatureCollisionAttack\n", "\n", "import numpy as np\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "np.random.seed(301)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(50000, 32, 32, 3)\n", "shape of x_train (1000, 3, 32, 32)\n", "shape of y_train (1000, 10)\n" ] } ], "source": [ "(x_train, y_train), (x_test, y_test), min_, max_ = load_dataset('cifar10')\n", "print(x_train.shape)\n", "x_train = np.transpose(x_train, (0, 3, 1, 2)).astype(np.float32)\n", "x_test = np.transpose(x_test, (0, 3, 1, 2)).astype(np.float32)\n", "num_samples_train = 1000\n", "num_samples_test = 1000\n", "x_train = x_train[0:num_samples_train]\n", "y_train = y_train[0:num_samples_train]\n", "x_test = x_test[0:num_samples_test]\n", "y_test = y_test[0:num_samples_test]\n", "\n", 
"class_descr = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n", "print(\"shape of x_train\",x_train.shape)\n", "print(\"shape of y_train\",y_train.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wget -c https://www.dropbox.com/s/ljkld6opyruvn5u/resnet18.pt?dl=0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Model to be Attacked\n", "\n", "In this example, we using a RESNET18 model pretrained on the CIFAR dataset." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Model Definition and pretrained model pulled from: \n", "# https://github.com/huyvnphan/PyTorch_CIFAR10\n", "import torch\n", "import torch.nn as nn\n", "import os\n", "\n", "__all__ = [\n", " \"ResNet\",\n", " \"resnet18\",\n", " \"resnet34\",\n", " \"resnet50\",\n", "]\n", "\n", "\n", "def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n", " \"\"\"3x3 convolution with padding\"\"\"\n", " return nn.Conv2d(\n", " in_planes,\n", " out_planes,\n", " kernel_size=3,\n", " stride=stride,\n", " padding=dilation,\n", " groups=groups,\n", " bias=False,\n", " dilation=dilation,\n", " )\n", "\n", "\n", "def conv1x1(in_planes, out_planes, stride=1):\n", " \"\"\"1x1 convolution\"\"\"\n", " return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n", "\n", "\n", "class BasicBlock(nn.Module):\n", " expansion = 1\n", "\n", " def __init__(\n", " self,\n", " inplanes,\n", " planes,\n", " stride=1,\n", " downsample=None,\n", " groups=1,\n", " base_width=64,\n", " dilation=1,\n", " norm_layer=None,\n", " ):\n", " super(BasicBlock, self).__init__()\n", " if norm_layer is None:\n", " norm_layer = nn.BatchNorm2d\n", " if groups != 1 or base_width != 64:\n", " raise ValueError(\"BasicBlock only supports groups=1 and base_width=64\")\n", " if dilation > 1:\n", " raise NotImplementedError(\"Dilation > 1 not supported in 
BasicBlock\")\n", "        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n", "        self.conv1 = conv3x3(inplanes, planes, stride)\n", "        self.bn1 = norm_layer(planes)\n", "        self.relu = nn.ReLU(inplace=True)\n", "        self.conv2 = conv3x3(planes, planes)\n", "        self.bn2 = norm_layer(planes)\n", "        self.downsample = downsample\n", "        self.stride = stride\n", "\n", "    def forward(self, x):\n", "        identity = x\n", "\n", "        out = self.conv1(x)\n", "        out = self.bn1(out)\n", "        out = self.relu(out)\n", "\n", "        out = self.conv2(out)\n", "        out = self.bn2(out)\n", "\n", "        if self.downsample is not None:\n", "            identity = self.downsample(x)\n", "\n", "        out += identity\n", "        out = self.relu(out)\n", "\n", "        return out\n", "\n", "\n", "class Bottleneck(nn.Module):\n", "    expansion = 4\n", "\n", "    def __init__(\n", "        self,\n", "        inplanes,\n", "        planes,\n", "        stride=1,\n", "        downsample=None,\n", "        groups=1,\n", "        base_width=64,\n", "        dilation=1,\n", "        norm_layer=None,\n", "    ):\n", "        super(Bottleneck, self).__init__()\n", "        if norm_layer is None:\n", "            norm_layer = nn.BatchNorm2d\n", "        width = int(planes * (base_width / 64.0)) * groups\n", "        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n", "        self.conv1 = conv1x1(inplanes, width)\n", "        self.bn1 = norm_layer(width)\n", "        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n", "        self.bn2 = norm_layer(width)\n", "        self.conv3 = conv1x1(width, planes * self.expansion)\n", "        self.bn3 = norm_layer(planes * self.expansion)\n", "        self.relu = nn.ReLU(inplace=True)\n", "        self.downsample = downsample\n", "        self.stride = stride\n", "\n", "    def forward(self, x):\n", "        identity = x\n", "\n", "        out = self.conv1(x)\n", "        out = self.bn1(out)\n", "        out = self.relu(out)\n", "\n", "        out = self.conv2(out)\n", "        out = self.bn2(out)\n", "        out = self.relu(out)\n", "\n", "        out = self.conv3(out)\n", "        out = self.bn3(out)\n", "\n", "        if self.downsample is not None:\n", "            identity = 
self.downsample(x)\n", "\n", "        out += identity\n", "        out = self.relu(out)\n", "\n", "        return out\n", "\n", "\n", "class ResNet(nn.Module):\n", "    def __init__(\n", "        self,\n", "        block,\n", "        layers,\n", "        num_classes=10,\n", "        zero_init_residual=False,\n", "        groups=1,\n", "        width_per_group=64,\n", "        replace_stride_with_dilation=None,\n", "        norm_layer=None,\n", "    ):\n", "        super(ResNet, self).__init__()\n", "        if norm_layer is None:\n", "            norm_layer = nn.BatchNorm2d\n", "        self._norm_layer = norm_layer\n", "\n", "        self.inplanes = 64\n", "        self.dilation = 1\n", "        if replace_stride_with_dilation is None:\n", "            # each element in the tuple indicates if we should replace\n", "            # the 2x2 stride with a dilated convolution instead\n", "            replace_stride_with_dilation = [False, False, False]\n", "        if len(replace_stride_with_dilation) != 3:\n", "            raise ValueError(\n", "                \"replace_stride_with_dilation should be None \"\n", "                \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation)\n", "            )\n", "        self.groups = groups\n", "        self.base_width = width_per_group\n", "\n", "        # CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1\n", "        self.conv1 = nn.Conv2d(\n", "            3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False\n", "        )\n", "        # END\n", "\n", "        self.bn1 = norm_layer(self.inplanes)\n", "        self.relu = nn.ReLU(inplace=True)\n", "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n", "        self.layer1 = self._make_layer(block, 64, layers[0])\n", "        self.layer2 = self._make_layer(\n", "            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]\n", "        )\n", "        self.layer3 = self._make_layer(\n", "            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]\n", "        )\n", "        self.layer4 = self._make_layer(\n", "            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]\n", "        )\n", "        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n", "        self.fc = nn.Linear(512 * block.expansion, num_classes)\n", "\n", "        for m in 
self.modules():\n", "            if isinstance(m, nn.Conv2d):\n", "                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n", "            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n", "                nn.init.constant_(m.weight, 1)\n", "                nn.init.constant_(m.bias, 0)\n", "\n", "        # Zero-initialize the last BN in each residual branch,\n", "        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n", "        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n", "        if zero_init_residual:\n", "            for m in self.modules():\n", "                if isinstance(m, Bottleneck):\n", "                    nn.init.constant_(m.bn3.weight, 0)\n", "                elif isinstance(m, BasicBlock):\n", "                    nn.init.constant_(m.bn2.weight, 0)\n", "\n", "    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n", "        norm_layer = self._norm_layer\n", "        downsample = None\n", "        previous_dilation = self.dilation\n", "        if dilate:\n", "            self.dilation *= stride\n", "            stride = 1\n", "        if stride != 1 or self.inplanes != planes * block.expansion:\n", "            downsample = nn.Sequential(\n", "                conv1x1(self.inplanes, planes * block.expansion, stride),\n", "                norm_layer(planes * block.expansion),\n", "            )\n", "\n", "        layers = []\n", "        layers.append(\n", "            block(\n", "                self.inplanes,\n", "                planes,\n", "                stride,\n", "                downsample,\n", "                self.groups,\n", "                self.base_width,\n", "                previous_dilation,\n", "                norm_layer,\n", "            )\n", "        )\n", "        self.inplanes = planes * block.expansion\n", "        for _ in range(1, blocks):\n", "            layers.append(\n", "                block(\n", "                    self.inplanes,\n", "                    planes,\n", "                    groups=self.groups,\n", "                    base_width=self.base_width,\n", "                    dilation=self.dilation,\n", "                    norm_layer=norm_layer,\n", "                )\n", "            )\n", "\n", "        return nn.Sequential(*layers)\n", "\n", "    def forward(self, x):\n", "        x = self.conv1(x)\n", "        x = self.bn1(x)\n", "        x = self.relu(x)\n", "        x = self.maxpool(x)\n", "\n", "        x = self.layer1(x)\n", "        x = self.layer2(x)\n", "        x = self.layer3(x)\n", "        x = self.layer4(x)\n", "\n", 
" x = self.avgpool(x)\n", " x = x.reshape(x.size(0), -1)\n", " x = self.fc(x)\n", "\n", " return x\n", "\n", "\n", "def _resnet(arch, block, layers, pretrained, progress, device, **kwargs):\n", " model = ResNet(block, layers, **kwargs)\n", " if pretrained:\n", " # Download the model state_dict from the link: and run your code\n", " state_dict = torch.load(\n", " 'resnet18.pt?dl=0', map_location=device\n", " )\n", " model.load_state_dict(state_dict)\n", " return model\n", "\n", "\n", "def resnet18(pretrained=False, progress=True, device=\"cpu\", **kwargs):\n", " \"\"\"Constructs a ResNet-18 model.\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " return _resnet(\n", " \"resnet18\", BasicBlock, [2, 2, 2, 2], pretrained, progress, device, **kwargs\n", " )\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import torch.optim as optim\n", "# Pretrained model\n", "classifier_model = resnet18(pretrained=True)\n", "classifier_model.eval() # for evaluation\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(classifier_model.parameters(), lr=0.0001)\n", "classifier = PyTorchClassifier(clip_values=(min_, max_), model=classifier_model, \n", " preprocessing=((0.4914, 0.4822, 0.4465),(0.2471, 0.2435, 0.2616)),nb_classes=10,input_shape=(3,32,32),loss=criterion,\n", " optimizer=optimizer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Choose Target Image from Test Set" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "shape of target_instance (1, 3, 32, 32)\n", "true_class: bird\n", "predicted_class: bird\n", "avgpool\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAb2ElEQVR4nO2da2yc55Xf/2eGM+JVvOlCWpQsWZYiyzdZZrTOxamTtImTLmoHbYJN0cAfgtV+2AANsP1gpECTfkuLJot8WARVNsY6RZrE2CSN0brpukYT57Jrm0qsiy1fJOpGiiIlUbyIFC8zc/qBI0D2Pv+H1HA4ZPz8fwDB4fOf877PvPOe9x0+Z8455u4QQrz3yaz2BIQQtUHOLkQiyNmFSAQ5uxCJIGcXIhHk7EIkQt1yjM3sUQDfApAF8Nfu/vXY8/PrO7xx41ai8hCgGbGIRA0zzGiRfUUjkezSGLOJaNk6fvgto+vwzbyXI8T0pVXwmqcunsHM+OXgyV+xs5tZFsBfAfhnAAYAvGJmz7r768ymceNWPPz158j2inRf2Wz4xPdigdrkczmqxb5bUCryeVg+PI9SiduATxGt7Z1UyzU2Ua0EfiGLXeKqTfRcrMg7+QUu9p5V8l0Rd36kKrx2V3inAErEzEslvjkLb+9/H+ytYAaLcwDASXfvd/c5AD8E8NgytieEWEGW4+xbAJy/6e+B8pgQYg2yHGcPfQ76Rx9IzOygmfWZWd/cxJVl7E4IsRyW4+wDAG5ebesBcOHdT3L3Q+7e6+69+fX8f1QhxMqyHGd/BcAuM9thZnkAfwLg2epMSwhRbSpejXf3gpl9CcD/wULo7Sl3fy1mk83l0N51G9lgZBXcwsuVsStVNPRWWVQOsPDqaJbMDwCykQXVUmQeReeGC4GQ1afSw8iNIlZrJPZmsahA7FXH3k8SGaj2K15WnN3dnwMQjqUJIdYU+uaGEIkgZxciEeTsQiSCnF2IRJCzC5EIy1qNrwTPkjCD83BShsR4SpHwWixSE02IixIOn2SiqW2xEGAsvBZLdqldukulBUkzFWTtRd+zyGsusUwSVBoCjMwj8rosYliKxVnp9iKwaUSMdGcXIhHk7EIkgpxdiESQswuRCHJ2IRKh5qvxkTXViFV4ldMiK/iIlB2qGFIKKG4T0TKRVeTKFvGrTiwqsFaIJj2RJf7Y4rjH3rToriIRlGofxgq2pzu7EIkgZxciEeTsQiSCnF2IRJCzC5EIcnYhEqGmoTdDpINSJAuCfuc/Uvst2k4q1lElFtIgYqVXzNg8IlG5eDSPvO5YtC6WUJStsBCak/igR2aSiXVNKXGtGAuHEbNIAyIsUoiQKtGIaOwkIYaVnKfxc0MIkQRydiESQc4uRCLI2YVIBDm7EIkgZxciEZYVejOzMwAmARQBFNydd4IvQ0Nv4BlsLARhsWBHJCwX7QwVEVlYK1qDLkIsey02j1jGVoYckzoU+PYiAZtCrIZeiZ8+LIwWu7t4lqvrjM9/vfPuwO35seB4o/Ptjcx3UO1CcSPVMizOBx6KBIBCJnwcM7HwYAVZb9WIs3/U3S9XYTtCiBVEH+OFSITlOrsD+DszO2xmB6sxISHEyrDcj/EfcvcLZrYJwPNm9oa7v3jzE8oXgYMA0LR52zJ3J4SolGXd2d39Qvn3CICfAjgQeM4hd+919976Nr64IYRYWSp2djNrMrOWG48BfALA8WpNTAhRXZbzMX4zgJ+WQ0R1AP67u/+8KrNas1SYAlYBuUhoKJZfNV2XC45b5LqeK/EYT13kNa/DNNfmZ8PbK85Rm6Lz7bUUL1Ftd9MQ1e6sD4fl5ub48XhlkIeBZ/J7qXa15X6qzWMd1TIk49OjLcBunYqd3d37AfBXJ4RYUyj0JkQiyNmFSAQ5uxCJIGcXIhHk7EIkQs17vTEqqPNY+b7iaW9cusVxIF5IM2YZ6WK3yDbD12/L8H01R8Jh7RPn+a4m36bS7KXXguONPkVtGiJnY0O
ehxu9hc9/uq4+OD5TCo8DQP7KKaq9r2OUasdy/EtjE/U7qJbxcOhzziKZoBVUnNSdXYhEkLMLkQhydiESQc4uRCLI2YVIhDWzGh9fjmfDVV6mX2QaXI3Vu6tsjtctnNCysFH+tnWXwhXCumf4qvr8+aNUy0zzlemWuutUuzx0Ojh+dXiS2ow2NlKttbOJavkuvrI+XAiv1G/bxG0evo/XoLs0m6fa9SF+HE/neLLRRHu4zoNlIjGZCk4r3dmFSAQ5uxCJIGcXIhHk7EIkgpxdiESQswuRCDUPvbEgVTTyRoxiyR1Vz55BhYkwkXlkI+2fGos84WJT6U2qdU2FE1BaZi9Sm2sFXvst28BDgHV1/PTJ77g9ON7aSU1wbWaeaoXSDNWyGR6WO9sfDh2eOT5BbXbvf5BqwzkeDptcx0OHkbcaGSdqpIVZJbE33dmFSAQ5uxCJIGcXIhHk7EIkgpxdiESQswuRCIuG3szsKQB/DGDE3e8pj3UA+BGA7QDOAPicu19dbFtuQJFELuoKPMyQZaGJSPihmIlcxyIRjVi7ozxJXJrLxuYRboMEAN3XB6i2N8PDa5sbeSZaXS6cVTYzy7Ou2jp47bT5CW43cpm3Xbo2Ew7njU3wrLehtwepNlfH68w1tt9NtWkLh8OuzfDTdbj/GtUudW2lWnbLw1zL8iw7R/gYs7ZQQGWR5aXc2f8GwKPvGnsSwAvuvgvAC+W/hRBrmEWdvdxv/d3f8HgMwNPlx08DeLy60xJCVJtK/2ff7O5DAFD+val6UxJCrAQrvkBnZgfNrM/M+mbHeNtdIcTKUqmzD5tZNwCUf4+wJ7r7IXfvdffedW18IUgIsbJU6uzPAnii/PgJAD+rznSEECvFUkJvPwDwCIANZjYA4KsAvg7gGTP7IoBzAD67lJ0ZgBxN8Ln1DLa6UiQ0EckzimWiZUo886ouF9YaizyDqjHSSmjzzBGqravjWW+TV3nRw8tXwnajV/m/UNev83BYochDb7dtCWe2AcC2rXcGx7s28felOVJwMtsSiZfmeSbawPnw6x6bWEdt6rvuo9q6SHitWB/Jeivy+ZdK4fnHQm+8CCtnUWd3988T6eOL2Qoh1g76Bp0QiSBnFyIR5OxCJIKcXYhEkLMLkQg1LThpDmRJNKEQsWPBsFi4LhPR6rM8nNS6jmdXFYd+HxwvnXiJ2tjFC1Trn+RFIF8Y5hlgl0fC/dwAoLkl/JY+0Lud2ux7cC/Vtu+5h2qbNvdQzUiW1+w8f6d39N5BtVOneBbgz5/9NdX+4bfh49+455PU5p/0/muqTeTbqZaNhoI5NBgZCRGvVNabEOI9gJxdiESQswuRCHJ2IRJBzi5EIsjZhUiEmobeFgpOhsNeORpgA5rIJSlX4MUcbWqMarnp01RrL52nWmkkrJVyvFdaw308VOOTPEsqf5z3NvvggR1Uu29fOHy19fZWatPQyDPAYLzX22CkGMkceTtn5nnW28VL41Q7/FtenLPvMM86HLi6ITj+/vs+QW2mmvh7VoxkAdY517zE76tmJGtvFQpOCiHeA8jZhUgEObsQiSBnFyIR5OxCJEJtE2FQQq4UTjSZPcXrsfW/8mJwfOz0CWozHVkpnpnj9d0acnyF/+OP3B8c39gTXvEFgDf6eULLNVIfDQDu3s1X3O9+/y6qDY+EEz8u/f4KtanL8dX4bG6Marl8M9VaO8KtBEbHeOTiN7/i7+fhX71OtVPn+Dbbdn8gON6z//3UpkDaMQFAXaS2Yeze6RHNyNJ6pAJdRejOLkQiyNmFSAQ5uxCJIGcXIhHk7EIkgpxdiERYSvunpwD8MYARd7+nPPY1AH8K4Ebs6Cvu/tyieysU4GPhENCR5/4XNTt7+O+D462beMhrfde9VOva8T6qdWxuo9pANhx2Gbx2jdrMTgxRbWyc12Mb7TtDtV/85g2qDQyEk3W23dZFbe7cwcN8qOcJOYVIMsbU+OHg+PDQGLU5NzBMtYkZHojacv8
fUW3/v/w3wfHGjliyC99XxnmrqWIk2SUWRqskqaUSlnJn/xsAjwbG/9Ld95V/Fnd0IcSqsqizu/uLAPi3UIQQfxAs53/2L5nZUTN7ysz4ZyIhxJqgUmf/NoCdAPYBGALwDfZEMztoZn1m1jczoQ8IQqwWFTm7uw+7e9HdSwC+A+BA5LmH3L3X3Xvr13dUOk8hxDKpyNnNrPumPz8D4Hh1piOEWCmWEnr7AYBHAGwwswEAXwXwiJntw0JE4QyAP1va7hxFUmtu1wc/Tq32furzwfH6ns3UJt/IM7Jaivxlb58Jt3gCAMuE67Gda+Shq2ukDRIAzE/xmmvXBk9SbeIs15obXw6OXz7FX1f2Cr9WX57jWV7TkVZOhYlwXbjpeV7Tztt5KHXLQ7xm3H3/4gtUa+4J1+SbqzAWVozUhfNMJCMu0o4sqlWRRZ3d3UOe9t0VmIsQYgXRN+iESAQ5uxCJIGcXIhHk7EIkgpxdiESoacFJ1OWQ6QwXItzc2UPNCtlwuGaG1wXE1Wy4sCUAzJCilwCw9U0eovqjznBG2XgTDxuObX6Iavn1nVRr27CRah37wkUUAeCOfeH9Xf3lT6jNdrtMtVdf/Q3VLl5uodrc1nBRzMae8PsPAD37H6Ha1gc/RjXv4CG7KXKOxE58i4TXEIuSxeyihrVBd3YhEkHOLkQiyNmFSAQ5uxCJIGcXIhHk7EIkQm1Db8jAPJwFNlnioYl1hXB21W2zkT5qeR7imWxpoNqlCX79K1wJ9xvbs4XbjDfwzLwrHXuolp8OZwcCAIwXPcxsCmd5tXwynDkIAM3rpqj20R28r9zZIg95jXeHC342da6nNtkmXu9ghoRfgXgmWp2FM9EsmmlW/TBZNGJX9b2F0Z1diESQswuRCHJ2IRJBzi5EIsjZhUiEmq7GmwO5+fBKciEyk+aZcD2zvZP/QG2mi7updrTpAaqdW8/ryf38F+EEmsc736Y2g6OvUO1qa3jlHAByGb4aX8zy1fhprAuO123YRm2Gp3hiUMb7qXbfR+6m2vHs1uD47BSf+7zzGm4GruVLPCMq4+FITgl8db8EPseKWQPL8bqzC5EIcnYhEkHOLkQiyNmFSAQ5uxCJIGcXIhGW0v5pK4DvAegCUAJwyN2/ZWYdAH4EYDsWWkB9zt2vLra9Irm8FLI8fFI3dyU43lU8R20a6i5R7eQVnpxy/a57qPbGr8KJHxdH+L7u2niWasPXeFhrdP2dVIvVOsuQEI8V+PEt1UfCUC1NVGt+5YdU69354eD4G/XvozbjdTxJZr7ET1WP3LMKyBMlvfvcUl5xAcBfuPtdAB4C8OdmthfAkwBecPddAF4o/y2EWKMs6uzuPuTuvys/ngRwAsAWAI8BeLr8tKcBPL5CcxRCVIFb+ixjZtsBPADgJQCb3X0IWLggAOAJ5EKIVWfJzm5mzQB+DODL7h7+/mrY7qCZ9ZlZ3/Ux/r+tEGJlWZKzm1kOC47+fXe/0W1g2My6y3o3gJGQrbsfcvded+9taOOND4QQK8uizm5mhoV+7Cfc/Zs3Sc8CeKL8+AkAP6v+9IQQ1WIpWW8fAvAFAMfM7NXy2FcAfB3AM2b2RQDnAHx2sQ2VDJghl5ccT2pCcT6cuTQ4HPwwAQB4ZDevM3fb8EtUm97zGJ/HBx8Njr9+5HvU5p/v5uHB7RMvU20qz7PvPM9DZfNZki0XybpqMV7Lb1/nJNXed5qHFedJLb/WxjFqc/paN9ca7qLaaK6RavXhUwdzmUjdusi56BXXrrv1unbRLlQVlMlb1Nnd/dfgM+VNzoQQa4r0vlkgRKLI2YVIBDm7EIkgZxciEeTsQiRCjds/LYTfQuRZOhyAuVw4G+rEIC/K+OEeHrfY429S7e2Zcaq17nkkbPP7/0ttLg1eoNodu05TbWD6FNWGG3nbqCw5wPORUE3P6DDVdo6/QbWGpsh7NhvOVGw+9Rtqc3emh2rjbeFCmgAw1bmXanWl8Bx5DuBi4bUKqVW
Ppwi6swuRCHJ2IRJBzi5EIsjZhUgEObsQiSBnFyIRah56Y9k685HLTnZ9Z3B8opH3c7t66VWq3dHNd7Z+nBeBLGx4MDg+du/HqM2Rk39NtU/uCoenAODOuWNUu1q4nWpFCxeILBlP5Zq/ep1qkxd4Hzu0TlOp/81woZLrV2aozYP7+a7aLv+Oaleaw33lAOB6rjU4Hkl6i3SV+8NGd3YhEkHOLkQiyNmFSAQ5uxCJIGcXIhFqvxrv4bXOUmR1dL4u3MLHesKr4wDw5lu/pNpH7+RJFXcN8SSZwx33B8frH/gotek/zecxevEy1e7dxFfBB8dPUu1c293B8bzzt/pMhte0az0/QLXd63lrqInx8Ar/6bd5h7CdO3jS0JYNvM7c2eu8zl+hPtzOy6KZMBVJ8aJxkRp0UbMqoju7EIkgZxciEeTsQiSCnF2IRJCzC5EIcnYhEmHR0JuZbQXwPQBdWMgROOTu3zKzrwH4UwA3Mh6+4u7PRbcFIEtqgqHE0w9oDkfPLmpz4o12qn3gWpZqu/wM1V6fCrc7mu/YSW2mdvKmOcf6n6HaJzfyJJnu6aNUG24LJwdNG3+rJ9bfRrXxfDgJCQCy4IkwbS0twfFLm3hn7+PDU1Tbt2mMahsmeehtojXcNspL/HjEStBZJE4WD8vFxNqwlDh7AcBfuPvvzKwFwGEze76s/aW7/5eVm54QolospdfbEICh8uNJMzsBYMtKT0wIUV1u6X92M9sO4AEAN9qgfsnMjprZU2bGPzcLIVadJTu7mTUD+DGAL7v7BIBvA9gJYB8W7vzfIHYHzazPzPquXw0XNBBCrDxLcnYzy2HB0b/v7j8BAHcfdveiu5cAfAfAgZCtux9y9153721o31iteQshbpFFnd3MDMB3AZxw92/eNN5909M+A+B49acnhKgWS1mN/xCALwA4Zmavlse+AuDzZrYPC0GFMwD+bNEtOWAkBBELaWSKYa3UHK4vBgCDnbxF0uDbr1Ft404e8mq7Gs42G2/jobfMnoeodr7/F1QbnebZYVuaeCbaa+dfD44Xtu6jNo35LqpN5TZQbV2e/1uWz4fjpfm9PFPx4jBvNdVeKFBtx8L6cZCBqXA7r1IkpDib5ecia18G8HMbACwiMsVj6XCsmGOEpazG/xrh/LxoTF0IsbbQN+iESAQ5uxCJIGcXIhHk7EIkgpxdiESobcFJ4xlFpUgogSW9eYZnr+W2PEy1U6+9RLVNO3iIZ+atI8HxUtcHqI1t4mkE13vCBSwB4Gz//6Daztb1VFv3wn8Njk/e/hi1aXz4U1Qr5OqpVsyEM9sAoLElnMHWmm+jNmNjO6g2Ms4LcG7dykOR7dfOB8cvbeChtwx4Ncr5DL8/Zor8HK6LNJVyEpaLhfng4XnEkut0ZxciEeTsQiSCnF2IRJCzC5EIcnYhEkHOLkQi1DT05uCht4q6ZEX6dXV29lBtZo7HNCYvj3Ht1Ing+NyeQWrTsD1c8BAAZm+/l2pvvvQ81e70Gar1towEx3/9t39FbeZ3bKPattZuqmVsgmptDeH+cQ3g4dLJbfuodmaQF9ns3TFJtS1T/cHxEQv3xAOALAlrAUBjgYfQipGz2CK93nKkCGukBit1ili0Tnd2IRJBzi5EIsjZhUgEObsQiSBnFyIR5OxCJEJts96AinpescJ72UhTrvHLvHDkWyd4ltRdm3kmV3vpcnD87OsvU5uWnkjoLaKNvs4z4iYu88KMu+8NZ6n1HuH90I69/DOq1d/Oi1GWijwEWIf54Pi66zxcV9jxYaqdO99BtQPTvOfcFj8VHD8+HQ5RAsBUI89UrI+E3hApVDkbyerMsExQvqdK6k3qzi5EKsjZhUgEObsQiSBnFyIR5OxCJMKiq/FmVg/gRQDrys//W3f/qpl1APgRgO1YaP/0OXfnPYsAwAHW0aaSVjc8pQIYHeNTKQyGWwIBwNBFvgZ617ZwY8pj/a9Qm9LMZ7jWwttXzW3eT7Xh4beo1vH
+8FF5/F/xxI/to3w1+/pIOJEEADId/Fg1NIajAs1XeKumYnMj1Ybb9lJtbPS3VOvqCkdQOq+epTaTzXw1vpDjy+AWWT6Pray7kYyubGTJndTCW24NulkAH3P3+7HQnvlRM3sIwJMAXnD3XQBeKP8thFijLOrsvsC18p+58o8DeAzA0+XxpwE8vhITFEJUh6X2Z8+WO7iOAHje3V8CsNndhwCg/HvTis1SCLFsluTs7l50930AegAcMLN7lroDMztoZn1m1jczxlv8CiFWlltajXf3MQC/APAogGEz6waA8u/g9w/d/ZC797p7b31beIFLCLHyLOrsZrbRzNrKjxsA/FMAbwB4FsAT5ac9AYB/wVoIseosJRGmG8DTZpbFwsXhGXf/n2b29wCeMbMvAjgH4LMrOM8gsXDGbKRAndU1UW1iepZqB/aGw0l3DI5Sm/OXwokYANDR/ADV5rbdSbWTJ8P13QBgz5Xw625qDSemAMAH2vmxeqnvAtWKzltlNXWE2yu1DoxRm3rwxJrpTbxeX/+ZX1Jtf1c4rLh59HVqM9i+h2ozGf6a18/ze2cBvI2WYy48Hkn0KrHeUEX+Xi7q7O5+FMA/Oivd/QqAjy9mL4RYG+gbdEIkgpxdiESQswuRCHJ2IRJBzi5EIlg026zaOzO7BOBGutEGAOGUpNqiebwTzeOd/KHN43Z3D357rabO/o4dm/W5e++q7Fzz0DwSnIc+xguRCHJ2IRJhNZ390Cru+2Y0j3eiebyT98w8Vu1/diFEbdHHeCESYVWc3cweNbM3zeykma1a7TozO2Nmx8zsVTPrq+F+nzKzETM7ftNYh5k9b2Zvl3+3r9I8vmZmg+Vj8qqZfboG89hqZv/PzE6Y2Wtm9m/L4zU9JpF51PSYmFm9mb1sZkfK8/iP5fHlHQ93r+kPForCngJwB4A8gCMA9tZ6HuW5nAGwYRX2+xEA+wEcv2nsPwN4svz4SQD/aZXm8TUA/67Gx6MbwP7y4xYAbwHYW+tjEplHTY8JAAPQXH6cA/ASgIeWezxW485+AMBJd+939zkAP8RC8cpkcPcXAbw7Cb7mBTzJPGqOuw+5++/KjycBnACwBTU+JpF51BRfoOpFXlfD2bcAOH/T3wNYhQNaxgH8nZkdNrODqzSHG6ylAp5fMrOj5Y/5K/7vxM2Y2XYs1E9Y1aKm75oHUONjshJFXlfD2UMlNlYrJPAhd98P4FMA/tzMPrJK81hLfBvATiz0CBgC8I1a7djMmgH8GMCX3Z33dq79PGp+THwZRV4Zq+HsAwC23vR3DwBe+2gFcfcL5d8jAH6KhX8xVoslFfBcadx9uHyilQB8BzU6JmaWw4KDfd/df1IervkxCc1jtY5Jed9juMUir4zVcPZXAOwysx1mlgfwJ1goXllTzKzJzFpuPAbwCQDH41Yrypoo4HnjZCrzGdTgmJiZAfgugBPu/s2bpJoeEzaPWh+TFSvyWqsVxnetNn4aCyudpwD8+1Wawx1YiAQcAfBaLecB4AdY+Dg4j4VPOl8E0ImFNlpvl393rNI8/huAYwCOlk+u7hrM48NY+FfuKIBXyz+frvUxicyjpscEwH0Afl/e33EA/6E8vqzjoW/QCZEI+gadEIkgZxciEeTsQiSCnF2IRJCzC5EIcnYhEkHOLkQiyNmFSIT/D/yo5jyoIRevAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "target_class = \"bird\" # one of ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n", "target_label = np.zeros(len(class_descr))\n", "target_label[class_descr.index(target_class)] = 1\n", "target_instance = np.expand_dims(x_test[np.argmax(y_test, axis=1) == class_descr.index(target_class)][3], axis=0)\n", "img_plot = np.transpose(target_instance[0],(1,2,0))\n", "fig = plt.imshow(img_plot)\n", "print(\"shape of target_instance\",target_instance.shape)\n", "print('true_class: ' + target_class)\n", "print('predicted_class: ' + class_descr[np.argmax(classifier.predict(target_instance), axis=1)[0]])\n", "\n", "feature_layer = classifier.layer_names[-2]\n", "print(feature_layer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Poison Training Images to Misclassify Test\n", "\n", "The attacker wants to make it such that whenever a prediction is made on this particular cat the output will be a horse." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "New test data to be poisoned (10 images):\n", "Correctly classified: 9\n", "Incorrectly classified: 1\n" ] } ], "source": [ "base_class = \"frog\" # one of ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n", "base_idxs = np.argmax(y_test, axis=1) == class_descr.index(base_class)\n", "base_instances = np.copy(x_test[base_idxs][:10])\n", "base_labels = y_test[base_idxs][:10]\n", "x_test_pred = np.argmax(classifier.predict(base_instances), axis=1)\n", "nb_correct_pred = np.sum(x_test_pred == np.argmax(base_labels, axis=1))\n", "\n", "print(\"New test data to be poisoned (10 images):\")\n", "print(\"Correctly classified: {}\".format(nb_correct_pred))\n", "print(\"Incorrectly classified: {}\".format(10-nb_correct_pred))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10,10))\n", "for i in range(0, 9):\n", " pred_label, true_label = class_descr[x_test_pred[i]], class_descr[np.argmax(base_labels[i])]\n", " plt.subplot(330 + 1 + i)\n", " fig=plt.imshow(np.transpose(base_instances[i],(1,2,0)))\n", " fig.axes.get_xaxis().set_visible(False)\n", " fig.axes.get_yaxis().set_visible(False)\n", " fig.axes.text(0.5, -0.1, pred_label + \" (\" + true_label + \")\", fontsize=12, transform=fig.axes.transAxes, \n", " horizontalalignment='center')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The captions on the images can be read: `predicted label (true label)`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating Poison Frogs" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4dccde13287d43e080eee5249bc8bba4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Feature collision: 0%| | 0/10 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "poison_pred = np.argmax(classifier.predict(poison), axis=1)\n", "plt.figure(figsize=(10,10))\n", "for i in range(0, 9):\n", " pred_label, true_label = class_descr[poison_pred[i]], class_descr[np.argmax(poison_labels[i])]\n", " plt.subplot(330 + 1 + i)\n", " fig=plt.imshow(np.transpose(poison[i],(1,2,0)))\n", " fig.axes.get_xaxis().set_visible(False)\n", " fig.axes.get_yaxis().set_visible(False)\n", " fig.axes.text(0.5, -0.1, pred_label + \" (\" + true_label + \")\", fontsize=12, transform=fig.axes.transAxes, \n", " horizontalalignment='center')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice how the network classifies most of theses poison examples as frogs, and it's not incorrect to do so. The examples look mostly froggy. 
A slight watermark of the target instance is also added to push the poisons closer to the target class in feature space." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training with Poison Images" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": true }, "outputs": [], "source": [ "import torch.optim as optim\n", "adv_train = np.vstack([x_train, poison])\n", "adv_labels = np.vstack([y_train, poison_labels])\n", "classifier_model.train()\n", "classifier.fit(adv_train, adv_labels, nb_epochs=20, batch_size=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fooled Network Misclassifies Bird" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "true_class: bird\n", "predicted_class: frog\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig = plt.imshow(np.transpose(target_instance[0],(1,2,0)))\n", "\n", "print('true_class: ' + target_class)\n", "print('predicted_class: ' + class_descr[np.argmax(classifier.predict(target_instance), axis=1)[0]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These attacks allow adversaries who can poison your dataset the ability to mislabel any particular target instance of their choosing without manipulating labels." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.9" } }, "nbformat": 4, "nbformat_minor": 2 }