Diving deep: convolution kernel minimizer

Date: May 26, 2020

Today, we will do another experiment. We will set up a good-enough CNN to classify the CIFAR-10 dataset. CIFAR-10 is like MNIST, but its images have 3 color channels, are a little bit larger, and are arguably harder to classify. Then we will take an image from the dataset, but this time we will track gradients for the image itself rather than for the network's weights. That lets us slowly modify the image so that the outputs of an entire convolution layer somewhere deep inside the network are minimized.
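The core trick is to optimize the input instead of the weights, and it fits in a few lines. This is a minimal sketch, not the post's code. A single randomly initialized convolution + ReLU (with arbitrary sizes) stands in for a deep part of the trained CNN, and a random tensor stands in for the image.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
# Hypothetical stand-in for "some deep part of the network": one conv + ReLU
layer = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
for p in layer.parameters():
    p.requires_grad_(False)  # the weights stay fixed; only the image is optimized

img = torch.randn(1, 3, 32, 32, requires_grad=True)  # the image is the parameter now
optimizer = optim.Adam([img], lr=0.003)

def layer_loss():
    # ReLU outputs are non-negative, so the sum is zero only if every output is zero
    return layer(img).sum()

initial = layer_loss().item()
for _ in range(200):
    optimizer.zero_grad()
    loss = layer_loss()
    loss.backward()   # gradients flow into img, not into the frozen weights
    optimizer.step()  # Adam nudges the pixels
final = layer_loss().item()
print(f"{initial:.1f} -> {final:.1f}")
```

The only difference from ordinary training is the list handed to `optim.Adam`: the image tensor instead of `model.parameters()`.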

Setting up

Let's get started by importing a bunch of stuff:

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import time
from PIL import Image

Then the CIFAR-10 dataset:

In [2]:
means = torch.Tensor([0.4914, 0.4822, 0.4465])
stds = torch.Tensor([0.2470, 0.2435, 0.2616])
# Named `transform` so it doesn't shadow the torchvision.transforms module
transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                transforms.RandomRotation(10),
                                transforms.ToTensor(),
                                transforms.Normalize(means, stds)])
dl = torch.utils.data.DataLoader(datasets.CIFAR10("datasets/", download=True,
                                                  train=True, transform=transform),
                                 batch_size=1000,
                                 shuffle=True)
# Note: the test loader applies the same random augmentations as the training loader
dl_test = torch.utils.data.DataLoader(datasets.CIFAR10("datasets/", download=True,
                                                       train=False, transform=transform),
                                      batch_size=100,
                                      shuffle=True)
# Expand to (3, 32, 32) so normalized images can be denormalized for display
stds = stds.unsqueeze(-1).unsqueeze(-1).expand(-1, 32, 32)
means = means.unsqueeze(-1).unsqueeze(-1).expand(-1, 32, 32)
Files already downloaded and verified
Files already downloaded and verified
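Denormalizing for display simply inverts `transforms.Normalize`, which computes (x - mean) / std per channel. The cell above expands `means`/`stds` to (3, 32, 32) to do this. The sketch below gets the same result by broadcasting a (3, 1, 1) view instead, which also works for other image sizes and for (N, 3, H, W) batches. The names `mean_c`, `std_c`, `normalize` and `denormalize` are illustrative and do not appear in the original notebook.

```python
import torch

# Per-channel CIFAR-10 statistics, shaped (3, 1, 1) so they broadcast over H and W
mean_c = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
std_c = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)

def normalize(x):
    # same arithmetic as transforms.Normalize(mean, std) on a (3, H, W) tensor
    return (x - mean_c) / std_c

def denormalize(x):
    # inverse of normalize; clamp to [0, 1] so imshow has nothing to clip
    return (x * std_c + mean_c).clamp(0, 1)

x = torch.rand(3, 32, 32)
restored = denormalize(normalize(x))
print(torch.allclose(restored, x, atol=1e-6))  # True
```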

Let's look at a few of them:

In [3]:
categories = ["plane", "auto", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
imgs, labels = next(iter(dl)); plt.figure(num=None, figsize=(10, 3), dpi=350)
for i in range(20):
    plt.subplot(2, 10, i+1); plt.imshow((imgs[i] * stds + means).permute(1, 2, 0))
    plt.title(categories[labels[i]]); plt.axis("off")

Let's check the shape of a batch:

In [22]:
imgs.shape, labels.shape
Out[22]:
(torch.Size([1000, 3, 32, 32]), torch.Size([1000]))

And let's now define a model. This model was originally written in a previous post, so we're just gonna copy it here:

In [20]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1) # 32x32 -> 32x32
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1) # 32x32 -> 16x16
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1) # 16x16 -> 16x16
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1) # 16x16 -> 8x8
        self.conv5 = nn.Conv2d(64, 128, 3, padding=1) # 8x8 -> 4x4
        self.pool = nn.MaxPool2d(2); self.relu = nn.ReLU(); self.logSoftmax = nn.LogSoftmax(1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(128 * 4 * 4, 300); self.fc2 = nn.Linear(300, 10)
        self.dropout = nn.Dropout(0.5)
        self.times, self.losses, self.accuracies = [], [], []
    def forward(self, x):
        x = self.batchnorm1(self.pool(self.relu(self.conv2(self.relu(self.conv1(x))))))
        x = self.batchnorm2(self.pool(self.relu(self.conv4(self.relu(self.conv3(x))))))
        x = self.pool(self.relu(self.conv5(x)))
        x = self.dropout(x.contiguous().view(-1, 128 * 4 * 4))
        x = self.dropout(self.relu(self.fc1(x)))
        return self.logSoftmax(self.fc2(x))
    def fit(self, epochs):
        optimizer = optim.Adam(self.parameters(), lr=0.003, weight_decay=1e-5)
        count, lossFunction = 0, nn.NLLLoss()
        lastTime, initialTime = (self.times[-1] if len(self.times) > 0 else 0), time.time()
        for epoch in range(epochs):
            for imgs, labels in dl:
                count += 1; optimizer.zero_grad(); imgs, labels = imgs.cuda(), labels.cuda()
                loss = lossFunction(self(imgs), labels); loss.backward(); optimizer.step()
                if count % 30 == 0:
                    self.eval()
                    test_imgs, test_labels = next(iter(dl_test));self.losses.append(loss.item())
                    self.accuracies.append((torch.argmax(self(test_imgs.cuda()), dim=1) == test_labels.cuda()).sum())
                    self.times.append(lastTime + time.time()-initialTime); self.train()
                    print(f"\rProgress: {np.round(100*count/(epochs*len(dl)))}%, loss: {self.losses[-1]}, accuracy: {self.accuracies[-1]}/100    ", end="")
        return torch.Tensor(self.losses), torch.Tensor(self.accuracies), torch.Tensor(self.times)
net = Net().cuda()
net.load_state_dict(torch.load("models/cnn-standard.pth"))
#losses, accuracies, times = net.fit(30); torch.save(net.state_dict(), "models/cnn-standard.pth")
Out[20]:
<All keys matched successfully>

Also note that we're loading a model that I had already trained for several minutes, which makes the results more repeatable across experiments. You can always train it yourself with the commented-out line I included. Let's check that the model is still doing fine:

In [527]:
totalAccuracy = 0; net.eval()
for test_imgs, test_labels in dl_test:
    totalAccuracy += (torch.argmax(net(test_imgs.cuda()), dim=1) == test_labels.cuda()).sum()
totalAccuracy/len(dl_test)
Out[527]:
tensor(81, device='cuda:0')

About 81% accuracy, which is pretty good. For all intents and purposes, I think that's enough to gain insights. We could do better by really messing around with the architecture, of course, but then we couldn't iterate as fast.
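The accuracy cell above divides the number of correct predictions by the number of batches. That equals a percentage only because every test batch holds exactly 100 images. Below is a sketch of a batch-size-independent version. It runs under `torch.no_grad()` so no autograd graph is built. `accuracy` is a hypothetical helper name, and any iterable of (images, labels) batches works as the loader.

```python
import torch
import torch.nn as nn

def accuracy(model, loader, device="cpu"):
    """Percentage of correctly classified samples over the whole loader."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # evaluation only, no gradients needed
        for imgs, labels in loader:
            preds = model(imgs.to(device)).argmax(dim=1)
            correct += (preds == labels.to(device)).sum().item()
            total += labels.numel()
    return 100.0 * correct / total

# Usage with a toy model and a fake "loader" whose batches have different sizes
torch.manual_seed(0)
toy = nn.Linear(4, 3)
fake_loader = [(torch.randn(5, 4), torch.randint(0, 3, (5,))),
               (torch.randn(3, 4), torch.randint(0, 3, (3,)))]
acc = accuracy(toy, fake_loader)
print(f"{acc:.1f}%")
```

With the notebook's objects, the call would be `accuracy(net, dl_test, "cuda")`.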

Below are some helpers that plot the output of every kernel in each convolution layer. These are also taken directly from that previous post. The idea is that if we can see what's going on inside, we may gain some insights:

In [518]:
# expects x of shape (#samples, 3, 32, 32)
def pickOut(self, x, convLayer=3):
    x = self.relu(self.conv1(x))
    if convLayer == 1: return x
    x = self.relu(self.conv2(x))
    if convLayer == 2: return x
    x = self.batchnorm1(self.pool(x))
    x = self.relu(self.conv3(x))
    if convLayer == 3: return x
    x = self.relu(self.conv4(x))
    if convLayer == 4: return x
    x = self.batchnorm2(self.pool(x))
    x = self.relu(self.conv5(x))
    if convLayer == 5: return x
    x = self.pool(x)
    x = self.dropout(x.contiguous().view(-1, 128 * 4 * 4))
    x = self.dropout(self.relu(self.fc1(x)))
    return self.logSoftmax(self.fc2(x))
Net.pickOut = pickOut

# expects image of shape (1, 3, 32, 32)
def displayConvOutputs(self, imgs):
    for convLayer in range(5):
        output = self.pickOut(imgs.cuda(), convLayer + 1).cpu().detach()
        print(f"Conv layer: {convLayer + 1}")
        dim = output.shape[1]; plt.figure(num=None, figsize=(10, 4/dim*2.5*16), dpi=350)
        for i in range(dim):
            plt.subplot(4, dim // 4, i+1); plt.axis("off"); plt.imshow(output[0][i])  # subplot needs integer grid sizes
        plt.show()
Net.displayConvOutputs = displayConvOutputs

# expects image of shape (1, 3, 32, 32)
def graphPredictions(self, img, orig, means=0, stds=1):
    plt.figure(num=None, figsize=(10, 3), dpi=350)
    plt.subplot(1, 3, 1); plt.bar(categories, torch.exp(self(img.cuda())[0]).detach().cpu())
    plt.xticks(rotation='vertical')
    plt.subplot(1, 3, 2); plt.imshow((img[0].cpu() * stds + means).permute(1, 2, 0).detach())
    plt.subplot(1, 3, 3)  # difference between the current and the original image
    plt.imshow((torch.abs(img[0].cpu()-orig[0]) * stds + means).permute(1, 2, 0).detach())
    plt.show()
Net.graphPredictions = graphPredictions

# expects image of shape (1, 3, 32, 32)
def predict(self, img):
    return categories[torch.argmax(self(img.cuda()), dim=1)[0]]
Net.predict = predict

There's also a graphPredictions() function. It displays the predicted probability of each class for an image, along with the image itself and its difference from the original image. The optional means and stds arguments denormalize the images so that they can be displayed.
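`pickOut` duplicates the body of `forward` so that it can stop at a given layer. An alternative is PyTorch's `register_forward_hook`, which records a module's output during a normal forward pass. The sketch below assumes you want the post-ReLU activations. Because this `Net` reuses a single `self.relu` module, the sketch hooks the conv layers and applies ReLU itself. The `capture` helper and the toy model are illustrative and are not part of the original code.

```python
import torch
import torch.nn as nn

def capture(model, layers, x):
    """Run model(x) once; return ReLU'd outputs of the given modules, keyed by name."""
    outputs, handles = {}, []
    for name, module in layers.items():
        # Forward hooks are called as hook(module, inputs, output);
        # returning None leaves the module's output unchanged.
        def hook(mod, inputs, output, name=name):
            outputs[name] = torch.relu(output).detach()
        handles.append(module.register_forward_hook(hook))
    try:
        model(x)
    finally:
        for h in handles:
            h.remove()  # don't leave hooks attached to the model
    return outputs

# Toy model standing in for Net (hypothetical sizes)
torch.manual_seed(0)
conv1, conv2 = nn.Conv2d(3, 4, 3, padding=1), nn.Conv2d(4, 8, 3, padding=1)
toy = nn.Sequential(conv1, nn.ReLU(), nn.MaxPool2d(2), conv2, nn.ReLU())
x = torch.randn(1, 3, 32, 32)
acts = capture(toy, {"conv1": conv1, "conv2": conv2}, x)
print({k: tuple(v.shape) for k, v in acts.items()})
# {'conv1': (1, 4, 32, 32), 'conv2': (1, 8, 16, 16)}
```

The `.detach()` is there for display. Drop it if you need to backpropagate through the captured activations, the way `maximizeForStride` does with `pickOut`.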

Below are the guts of today's post. Despite its name, maximizeForLayer minimizes. It takes a real image from the dataset, tracks its gradients and hands its pixels over to Adam. Then it gradually tries to null out all outputs of an entire convolution layer. Every output has passed through a ReLU and is therefore non-negative, so minimizing their sum pushes each of them toward zero. It does this in stages with decreasing strides over the channels: 32, 16, 8, 4, 2 and then 1. This means that if a convolution layer has 128 channels (only layer 5), Adam will first try to null out $\frac{128}{32}=4$ kernel outputs at once, then 8, then 16, 32, 64 and finally all 128. In the end, the original image should be deformed just enough to null out everything in that convolution layer, and that should be interesting.
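The channel counts above follow from Python slicing: `[::stride]` on 128 channels keeps channels 0, stride, 2·stride and so on. The quick check below confirms the counts. It also shows that each stage's channels are a superset of the previous stage's, so nothing nulled in an earlier stage is dropped from the objective later.

```python
channels = 128
strides = [32 // 2**i for i in range(6)]  # 32, 16, 8, 4, 2, 1
selected = [set(range(channels)[::s]) for s in strides]
print([len(s) for s in selected])  # [4, 8, 16, 32, 64, 128]
# every stage includes all channels targeted by the previous stage
print(all(prev <= cur for prev, cur in zip(selected, selected[1:])))  # True
```

For the 32-channel layers 1 and 2, the same schedule starts from a single channel (`range(32)[::32]` is just channel 0).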

To make the code look cleaner, we first define a helper, maximizeForStride(), which runs 1000 Adam steps for a single stride.

In [508]:
def maximizeForStride(self, img, originalImage, convLayer, stride, optimizer):
    # Despite the name, this minimizes the summed (post-ReLU) outputs of every
    # `stride`-th channel of `convLayer`; the optimizer only updates the image
    losses = []
    for i in range(1000):
        optimizer.zero_grad(); loss = self.pickOut(img, convLayer)[0][::stride].sum()
        loss.backward(); losses.append(loss.item()); optimizer.step()
    self.graphPredictions(img, originalImage, means, stds); return losses
Net.maximizeForStride = maximizeForStride

def maximizeForLayer(self, convLayer):
    self.eval()  # keep the network in evaluation mode, as after the accuracy check
    # Take the first image of a fresh training batch and make its pixels trainable
    imgs, labels = next(iter(dl)); orig = imgs[0:1]; img = orig.cuda().requires_grad_(True)
    print(f"Category: {categories[labels[0]]}")
    optimizer = optim.Adam([img], lr=0.003); losses = []; print("Original:")
    self.graphPredictions(img, orig, means, stds)
    for i in range(6):
        stride = 32 // 2**i  # 32, 16, 8, 4, 2, 1
        print(f"Layer: {convLayer}, stride: {stride}")
        losses.extend(self.maximizeForStride(img, orig, convLayer, stride, optimizer))
    plt.figure(num=None, figsize=(10, 3), dpi=350); plt.plot(losses); plt.grid(True); plt.show()
    self.displayConvOutputs(img)
Net.maximizeForLayer = maximizeForLayer

Before experimenting, just so you know, I have another post that includes extra examples of the experiments below. I encourage you to take a look at that, so that you can compare them side-by-side, and follow the predictions later on.

Experiment

Let's null out the last convolution layer (layer 5) first:

In [519]:
net.maximizeForLayer(5)
Category: auto
Original:
Layer: 5, stride: 32
Layer: 5, stride: 16
Layer: 5, stride: 8
Layer: 5, stride: 4
Layer: 5, stride: 2
Layer: 5, stride: 1
Conv layer: 1
Conv layer: 2
Conv layer: 3
Conv layer: 4
Conv layer: 5

The layer 5 kernel outputs look very, very uniform throughout, so it seems to have succeeded. The interesting thing to notice is that layers 1 through 4 still seem to be functioning perfectly fine and still extract good information. Yet for some weird reason, we managed to tweak the image just enough that the network is entirely confused. Also notice that the difference from the original image is very small, and that it's fairly easy and fast for Adam to train the incoming image. The loss curve is made of really sharp spikes: it jumps each time the stride is halved, then quickly drops back down.
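"Very small" can be put into numbers. The sketch below assumes `img` and `orig` are normalized (1, 3, 32, 32) tensors, as in `maximizeForLayer`. It converts both back to [0, 1] pixel space and reports the largest and the mean absolute per-pixel change. The helper name `pixel_change` is hypothetical.

```python
import torch

CIFAR_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(1, 3, 1, 1)
CIFAR_STD = torch.tensor([0.2470, 0.2435, 0.2616]).view(1, 3, 1, 1)

def pixel_change(img, orig):
    """Max and mean absolute difference between two normalized images, in [0, 1] pixel units."""
    a = img.detach().cpu() * CIFAR_STD + CIFAR_MEAN
    b = orig.detach().cpu() * CIFAR_STD + CIFAR_MEAN
    diff = (a - b).abs()
    return diff.max().item(), diff.mean().item()

base = torch.zeros(1, 3, 32, 32)
changed = base.clone()
changed[0, 0, 0, 0] = 1.0  # change one normalized value in the red channel
mx, avg = pixel_change(changed, base)
print(mx, avg)  # about 0.247 (one red-channel std) and 0.247 / 3072
```

The means cancel in the difference, so this amounts to |img - orig| scaled by the per-channel stds.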

How about nulling out the 4th layer?

In [521]:
net.maximizeForLayer(4)
Category: deer
Original:
Layer: 4, stride: 32
Layer: 4, stride: 16
Layer: 4, stride: 8
Layer: 4, stride: 4
Layer: 4, stride: 2
Layer: 4, stride: 1
Conv layer: 1
Conv layer: 2
Conv layer: 3
Conv layer: 4
Conv layer: 5

There are 2 phenomena worth noting here. First, Adam can't really push layer 4 down to absolute zero (though it can when the stride is 2 or more). Second, layer 5 enters a really, really weird state where pretty much everything lights up. This is, again, because of the migrating-from-edge phenomenon we saw in the previous post. Once the layer feeding it has been nulled out, the edges, where the zero padding sits, are the only source of gradient. We may generalize this to generate sharp noise images that make the next layer take on a uniform value.
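The edge effect is easy to reproduce in isolation. Feed a constant input to a 3×3 convolution with zero padding. The interior comes out constant, but border pixels see some padding zeros and come out different. The sketch uses an all-ones kernel with no bias purely to make the numbers easy to follow; it is not a kernel from the trained network.

```python
import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 5, 5)       # a perfectly uniform "activation map"
kernel = torch.ones(1, 1, 3, 3)  # illustrative kernel, no bias
out = F.conv2d(x, kernel, padding=1)
print(out[0, 0])
# tensor([[4., 6., 6., 6., 4.],
#         [6., 9., 9., 9., 6.],
#         [6., 9., 9., 9., 6.],
#         [6., 9., 9., 9., 6.],
#         [4., 6., 6., 6., 4.]])
```

Corners see 4 real inputs, edges 6 and the interior 9. Any spatial structure left in the next layer therefore comes only from the border, which is consistent with the migrating-from-edge picture.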

Why might we want to do that? Well, judging from the class probabilities, our network likes frogs a whole lot. Why does it do that? We have effectively nulled out the 4th layer, and the 5th layer takes on uniform, constant values. The fully connected layers after them therefore have no information to work with, so they fall back on a single-class default policy, which in our case is frogs. Knowing this could help us defend against other networks that are actively trying to fool ours. Imagine someone takes a normal image and figures out how to combine it with one of these nulling images. They could then force the network into the same category every time and exploit that as an attack vector.
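The "default policy" argument can be checked directly. If an intermediate layer's output is forced to a constant, everything after it computes the same thing for every input, so every image gets the same prediction. Below is a minimal sketch on a toy, untrained model; the model and the choice of layer are illustrative. It uses a forward hook that returns zeros, because a forward hook that returns a tensor overrides the module's output. In the real network the nulled values pass through batch norm and conv5 before reaching the classifier, but a constant input still yields an output that doesn't depend on the image.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # "earlier layers"
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),                   # layer we null out
    nn.Flatten(), nn.Linear(8 * 16 * 16, 10),                   # "fully connected part"
).eval()

# Replace the second ReLU's output with zeros, as if that layer were perfectly nulled
handle = toy[4].register_forward_hook(lambda mod, inp, out: torch.zeros_like(out))
with torch.no_grad():
    preds = toy(torch.randn(16, 3, 32, 32)).argmax(dim=1)
handle.remove()
print(preds.unique())  # a single class: the toy model's "default policy"
```

Here the logits reduce to the final layer's bias, so the "default" class is simply the largest bias entry.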

In [522]:
net.maximizeForLayer(3)
Category: ship
Original:
Layer: 3, stride: 32
Layer: 3, stride: 16
Layer: 3, stride: 8
Layer: 3, stride: 4
Layer: 3, stride: 2
Layer: 3, stride: 1
Conv layer: 1
Conv layer: 2
Conv layer: 3
Conv layer: 4
Conv layer: 5

Layer 3 behaves pretty much the same as layer 4: the network is still attracted to frogs, and the layers after it still have to find their gradients from the edges. This time, though, it's noticeably harder and takes longer to train when the stride is 1 or 2. It's also really surprising that almost all of layer 3 is null now, which wasn't the case with layer 4. It could simply be that convolutions don't create very complex terrain, unlike fully connected layers and other fancier layers. The convolution in layer 3 sort of sets things in stone, so there's less wiggle room to null layer 4's output. In contrast, right before layer 3 sit layers like max pooling and batch norm, which can introduce more complex terrain and make it easier to null it out. This might also explain why layer 5 can be dampened so much.

In [523]:
net.maximizeForLayer(2)
Category: dog
Original:
Layer: 2, stride: 32
Layer: 2, stride: 16
Layer: 2, stride: 8
Layer: 2, stride: 4
Layer: 2, stride: 2
Layer: 2, stride: 1
Conv layer: 1
Conv layer: 2
Conv layer: 3
Conv layer: 4
Conv layer: 5

The image needed to null layer 2 is pretty much low-frequency noise. This makes sense: only conv1 and a ReLU come before layer 2, so without the more complex terrain that pooling and batch norm introduce, the image gets locked into simple patterns.
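"Low-frequency" can be checked numerically with a 2D Fourier transform: measure what fraction of an image's energy (after removing its mean) lies close to zero frequency. The sketch below assumes PyTorch 1.8 or later for `torch.fft`. The radius cutoff of 4 and the helper name `low_freq_fraction` are arbitrary choices.

```python
import math
import torch

def low_freq_fraction(img, radius=4):
    """Fraction of spectral energy within `radius` of zero frequency, for a (C, H, W) image."""
    x = img - img.mean(dim=(-2, -1), keepdim=True)  # drop the DC component per channel
    power = torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1)).abs() ** 2
    h, w = img.shape[-2:]
    fy = torch.arange(h) - h // 2  # after fftshift, zero frequency sits at index h // 2
    fx = torch.arange(w) - w // 2
    dist = (fy.view(-1, 1) ** 2 + fx.view(1, -1) ** 2).float().sqrt()
    mask = (dist <= radius).float()
    return ((power * mask).sum() / power.sum()).item()

i = torch.arange(32, dtype=torch.float32)
smooth = torch.cos(2 * math.pi * i / 32).view(1, 32, 1).expand(3, 32, 32)  # one slow wave
checker = ((i.view(-1, 1) + i.view(1, -1)) % 2 * 2 - 1).expand(3, 32, 32)  # finest pattern
print(round(low_freq_fraction(smooth), 3), round(low_freq_fraction(checker), 3))
```

In the notebook, one could compare `low_freq_fraction(img[0].detach().cpu())` against the same measure on `orig[0]`.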

The network also doesn't seem to like frogs that much anymore. My guess is that the nulled layer is no longer the smooth-terrain chunk right before the fully connected layers, so the final layer can take on a multi-class default policy. However, for the same reason that layer 5 takes on a single-class default policy when we null out layers 3 and 4, layers 3 and 4 must take on single-class default policies here too. Also note that, just as before, we have cut off every source of information about the image, so the default policy should be really, really consistent. This is demonstrated very strongly in the extra-results post.

Now that we have gained these insights, what properties should the image that nulls out the first layer have? First, it must be incredibly uniform. Again, convolutions don't introduce complex terrain, so if the output is uniform, the input should be uniform as well. Second, it should snap layer 5 into a multi-class default policy, and that policy should be completely different from the one we got when nulling layer 2. Third, because the information stream has been cut off, the layers after it have to migrate gradients from the edges. And finally, finding the image should be hard, and the loss is expected to plateau at a high value, as there is pretty much no wiggle room at all. Let's see:

In [525]:
net.maximizeForLayer(1)
Category: auto
Original:
Layer: 1, stride: 32
Layer: 1, stride: 16
Layer: 1, stride: 8
Layer: 1, stride: 4
Layer: 1, stride: 2
Layer: 1, stride: 1
Conv layer: 1
Conv layer: 2
Conv layer: 3
Conv layer: 4
Conv layer: 5

Yep, it seems to behave just like that. I don't know about you, but I kinda have a feel for convolutional layers now, after all of that experimentation. Another interesting phenomenon is that while the main image is uniformly white (as we predicted), the closer you get to the edges, the darker it is. This may just be the network's way of dealing with the zero padding the convolutions add (padding=1).

Some ideas that have sprung up on how to take these results further, to really see what they can do:

  • It seems like we can control deep layers (like layer 5 in our experiment) just by touching the input, and the signal we give it doesn't get attenuated. How deep can we go?
  • We now know that messing around with the smooth-terrain zone next to the last layer can force the network into its single-class default policy. Can we procedurally generate adversarial examples that actually fool our network?
  • Can we do the above with just inputs and outputs, and no gradient information in between? If we can, then pretty much all ML vision systems have a huge built-in security hole that could be exploited on existing systems.
  • We know that nulling the last layer doesn't change the image much, yet it has a huge impact on classification accuracy. Can we generate these images on the fly and train the network on them? This is sort of just a data augmentation technique, but it feels like it might be really, really effective. If we do this repeatedly, the network should be able to withstand adversarial networks.
  • After doing the above, what will the adversarial examples look like? I expect them to be of low quality and unable to defeat the network anymore. But maybe the high-dimensional terrain has its own way of squeezing out very unlikely adversarial images, and this process can never end.
  • We might want to pit humans against the deformed images, to see where human capabilities stand. Knowing that, we might be able to compare humans and networks and predict which behavior to expect at which computational scale.
  • Why should a single-class default policy even exist at all? My intuition says all default policies should span multiple classes. To resolve this, we can measure the mean and standard deviation of everything. That would show whether, when the 3rd or 4th layer is frozen, the final layer's outputs become extreme enough to force all other classes down.
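The on-the-fly augmentation idea in the list could start from something like the sketch below. It is hypothetical and untested against the real `Net`. It copies a batch and runs a few Adam steps that shrink a layer's summed output, the same objective as `maximizeForStride` with stride 1, then returns the perturbed batch for training with the original labels. `nulling_perturbation` and `layer_fn` are made-up names. `torch.autograd.grad` is used so that the model's parameter gradients are left untouched.

```python
import torch
import torch.nn as nn
import torch.optim as optim

def nulling_perturbation(layer_fn, imgs, steps=20, lr=0.003):
    """Return a detached copy of imgs nudged so that layer_fn(imgs).sum() shrinks.

    layer_fn maps an image batch to non-negative activations, e.g. (hypothetically)
    lambda x: net.pickOut(x, 5). Call it with the model in eval mode so batch norm
    running statistics are not updated.
    """
    adv = imgs.detach().clone().requires_grad_(True)
    opt = optim.Adam([adv], lr=lr)
    for _ in range(steps):
        loss = layer_fn(adv).sum()
        (grad,) = torch.autograd.grad(loss, [adv])  # gradient w.r.t. the images only
        adv.grad = grad
        opt.step()
    return adv.detach()

torch.manual_seed(0)
layer = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()).eval()
batch = torch.randn(4, 3, 32, 32)
adv = nulling_perturbation(layer, batch)
with torch.no_grad():
    before, after = layer(batch).sum().item(), layer(adv).sum().item()
print(after < before)
# a training step could then use torch.cat([batch, adv]) with the labels repeated twice
```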

After this, I went ahead and ran 3 more experiments: one that maximizes the kernel outputs, one that minimizes their standard deviation, and one that maximizes their standard deviation. As it turns out, I couldn't come up with any more insights. Std max looks like max, and std min looks like min. Affinity to a certain class is still common, and the probability distributions are sometimes close to those of other scenarios (the same multi-class default policy). Sometimes they're just random, though. Anyway, check out those links if you're interested. Next time, we will try to implement some of the ideas above.
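The three extra experiments differ only in the scalar that Adam minimizes. The sketch below shows what those objectives could look like, as functions of a layer's activations. The exact formulas behind the linked results aren't given here, so these are assumptions; for instance, it is unknown whether the std was taken per channel or over the whole layer. `optimize_image` is a hypothetical generalization of `maximizeForStride`, and a random conv + ReLU stands in for the trained network.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Adam minimizes each objective; maximizing something means minimizing its negative
OBJECTIVES = {
    "min":     lambda acts: acts.sum(),   # the objective used in this post
    "max":     lambda acts: -acts.sum(),
    "std_min": lambda acts: acts.std(),
    "std_max": lambda acts: -acts.std(),
}

def optimize_image(layer_fn, img, objective, steps=100, lr=0.003):
    """Run Adam on a copy of the image to minimize objective(layer_fn(image))."""
    img = img.detach().clone().requires_grad_(True)
    opt = optim.Adam([img], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        objective(layer_fn(img)).backward()
        opt.step()
    return img.detach()

torch.manual_seed(0)
layer = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()).eval()
for p in layer.parameters():
    p.requires_grad_(False)  # only the image is optimized
start = torch.randn(1, 3, 32, 32)
results = {name: optimize_image(layer, start, f) for name, f in OBJECTIVES.items()}
with torch.no_grad():
    for name, out_img in results.items():
        acts = layer(out_img)
        print(f"{name:8s} sum={acts.sum().item():9.2f} std={acts.std().item():.3f}")
```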