Unboxing the black box: Transpose convolution, custom stride and padding

Table of contents:

- Custom stride
- Custom padding
- Transpose convolution
- Benchmarks

Last time, we saw how to create nn.Conv2d from scratch. I hope that was easy to understand, because if you got that, this post is pretty much the same thing, with minor cosmetic and functional changes.

Let's quickly import everything we need right away:

In [4]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F

import torchvision.datasets as datasets
import torchvision
import torch.optim as optim
import time

We're also gonna copy over large chunks of the code we wrote there:

In [5]:
identity = torch.Tensor([[0, 0, 0],
                         [0, 1, 0],
                         [0, 0, 0]]).cuda()
emboss = torch.Tensor([[-2, -1, 0],
                       [-1, 1, 1],
                       [0, 1, 2]]).cuda()
sharpen = torch.Tensor([[0, -1, 0],
                        [-1, 5, -1],
                        [0, -1, 0]]).cuda()
edgeDetect = torch.Tensor([[0, 1, 0],
                           [1, -4, 1],
                           [0, 1, 0]]).cuda()
In [7]:
class Conv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # same initialization bounds as the built-in layer: U(-1/sqrt(fan_in), 1/sqrt(fan_in))
        kSqrt = 1/np.sqrt(in_channels*kernel_size**2)
        self.kernel = nn.Parameter(torch.rand(out_channels, in_channels, kernel_size, kernel_size)*2*kSqrt-kSqrt)
        self.bias = nn.Parameter(torch.rand(out_channels).unsqueeze(-1).unsqueeze(-1)*2*kSqrt-kSqrt)
    def forward(self, imgs):
        return conv4(imgs, self.kernel) + self.bias # conv4() is the convolution we built in the previous post
In [9]:
def advancedKernel(inKernel):
    kernel = torch.cuda.FloatTensor(3, 3, 3, 3).fill_(0)
    kernel[0, 0] = inKernel # in red to out red channel
    kernel[1, 1] = inKernel # in green to out green channel
    kernel[2, 2] = inKernel # in blue to out blue channel
    return kernel

Custom stride

Creating the stride effect is extremely simple, really. The only thing that has changed from conv4() is the last line, where we return the transformed tensor: we simply slice it with the desired stride:

In [462]:
def conv5(imgs, kernel, stride=[1, 1]):
    if type(stride) == int: stride = [stride, stride]
    imgs = imgs.permute(0, 2, 3, 1)     # NCHW -> NHWC
    kernel = kernel.permute(2, 3, 1, 0) # (out, in, kH, kW) -> (kH, kW, in, out)
    ks, inChannels, outChannels = kernel.shape[0], kernel.shape[2], kernel.shape[3]
    padding = ks - 1 # how much a valid convolution shrinks the image
    samples, height, width, _ = imgs.shape
    transformed = torch.cuda.FloatTensor(samples, height-padding, width-padding, outChannels).zero_()
    for yKernel in range(ks):
        for xKernel in range(ks):
            transformed += (imgs @ kernel[yKernel, xKernel])[:, yKernel:height-padding+yKernel, xKernel:width-padding+xKernel, :]
    return transformed[:, ::stride[0], ::stride[1], :].permute(0, 3, 1, 2) # the stride is just a slice over the full result
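One consequence worth spelling out before we test it: slicing the full output with the stride gives a spatial size of (H - k) // s + 1, which is exactly what the built-in produces. For example, with a 3×3 kernel and stride 2 over an 860×1600 image (the one we're about to load):

# strided convolution output size, no padding: (in - kernel_size) // stride + 1
(860 - 3) // 2 + 1, (1600 - 3) // 2 + 1   # (429, 799)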

Let's test things out just to make sure. Here's our beloved Rick Sanchez again:

In [719]:
rick = torch.Tensor(np.asarray(Image.open('rick_sanchez.jpg'))).permute(2, 0, 1).unsqueeze(0).cuda()
plt.figure(num=None, figsize=(10, 6), dpi=350); plt.imshow(rick[0].permute(1, 2, 0).cpu()/256)
rick.shape
Out[719]:
torch.Size([1, 3, 860, 1600])
In [463]:
transformed1 = conv5(rick, advancedKernel(edgeDetect), stride=2)[0].permute(1, 2, 0)
plt.figure(num=None, figsize=(10, 6), dpi=350)
plt.imshow(transformed1.cpu()/256); pass
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Let's see the built-in implementation:

In [464]:
transformed2 = F.conv2d(rick, advancedKernel(edgeDetect), stride=2)[0].permute(1, 2, 0)
plt.figure(num=None, figsize=(10, 6), dpi=350)
plt.imshow(transformed2.cpu()/256); pass
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Note the dimensions of the image: they have been roughly halved. Let's see if the two results are exactly identical:

In [433]:
torch.abs(transformed1 - transformed2).sum()
Out[433]:
tensor(0., device='cuda:0')

Yep, they're the same. If you have a keen eye, you might notice that we could potentially gain some performance when the stride is large by slicing the image down first and only then applying the kernel. I've tested this, and we don't seem to gain anything.
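If you're curious, here's a rough sketch of what that "slice first" variant could look like (conv5Strided is just a throwaway name for this note, not something used anywhere else):

def conv5Strided(imgs, kernel, stride=[1, 1]):
    sY, sX = (stride, stride) if type(stride) == int else stride
    imgs = imgs.permute(0, 2, 3, 1)
    kernel = kernel.permute(2, 3, 1, 0)
    ks, outChannels = kernel.shape[0], kernel.shape[3]
    samples, height, width, _ = imgs.shape
    outH, outW = (height - ks) // sY + 1, (width - ks) // sX + 1
    transformed = torch.cuda.FloatTensor(samples, outH, outW, outChannels).zero_()
    for yK in range(ks):
        for xK in range(ks):
            # slice the image down to the pixels we'll actually keep, then apply the kernel
            transformed += imgs[:, yK:yK+outH*sY:sY, xK:xK+outW*sX:sX, :] @ kernel[yK, xK]
    return transformed.permute(0, 3, 1, 2)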

With that out of the way, let's benchmark the built-in function. This is quite a tricky and involved process, because computers nowadays are crazy fast and clever: if they feel like they can cache stuff, they will, and that will mess up our results. Just graphing raw data won't cut it, and if we don't take special care, the timing measurements won't be reproducible. We first define a timing function. Notice the sleep period between measurements to "cool down", which makes sure the kernel (fyi, "kernel" here means the software system, not the 4d tensor) is ready to take the load:

In [416]:
def timeit(function: callable, times: int=10):
    result = []
    for i in range(times):
        before = time.time()
        function(i) # the trial index is passed along so the caller can vary its inputs
        result.append(time.time()-before)
        time.sleep(0.1) # cool-down period between measurements
    return torch.Tensor(result)

Then we can create a function to graph stuff quickly. Notice that we add kernelNoise to our kernel; this is to prevent PyTorch from reusing cached results. We also drop the first sample, because at that point the kernel is still warming up, so pretty much every first measurement will be way slower than the rest:

In [417]:
def performance(function: callable, kernel, start=1):
    plt.figure(num=None, figsize=(10, 6), dpi=350)
    kernelNoise = torch.rand(10, 3, 3, 3, 3).cuda()*0.001
    plt.plot((timeit(lambda i: function(rick, kernel+kernelNoise[i], stride=1))*1000)[start:], "-r")
    plt.plot((timeit(lambda i: function(rick, kernel+kernelNoise[i], stride=2))*1000)[start:], "-g")
    plt.plot((timeit(lambda i: function(rick, kernel+kernelNoise[i], stride=3))*1000)[start:], "-b")
    plt.plot((timeit(lambda i: function(rick, kernel+kernelNoise[i], stride=4))*1000)[start:], "-y")
    plt.plot((timeit(lambda i: function(rick, kernel+kernelNoise[i], stride=5))*1000)[start:], "-c")
    plt.legend(["Stride 1", "Stride 2", "Stride 3", "Stride 4", "Stride 5"])
    plt.xlabel("Trial")
    plt.ylabel("Time (ms)")
    plt.grid(True)

Let's measure the performance of the built-in implementation:

In [439]:
performance(F.conv2d, advancedKernel(edgeDetect))

And then our implementation:

In [437]:
performance(conv5, advancedKernel(edgeDetect))

The built-in implementation doesn't seem to get faster as the stride gets larger, which sort of tells me it's doing the exact same thing we do: a regular convolution first, then slicing away the unnecessary parts.

Custom padding

I hope it is stupid obvious how to implement this: you just need to initialize a slightly bigger tensor, copy the existing tensor over, and you're done. But just for completeness, let's do it. Because this is stupid simple, I'll just use PyTorch's nn.ConstantPad2d function (a manual sketch of the bigger-tensor approach follows the demo below):

In [455]:
rick.shape, nn.ConstantPad2d([1, 1, 1, 1], 0)(rick).shape
Out[455]:
(torch.Size([1, 3, 860, 1600]), torch.Size([1, 3, 862, 1602]))
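For the record, here's a minimal sketch of the "bigger tensor" approach described above; padManual is just a name made up for this note:

def padManual(imgs, padding):
    # padding is [left, right, top, bottom], mirroring nn.ConstantPad2d
    l, r, t, b = padding
    n, c, h, w = imgs.shape
    out = imgs.new_zeros(n, c, t + h + b, l + w + r) # the slightly bigger tensor
    out[:, :, t:t+h, l:l+w] = imgs                   # copy the original into the middle
    return out

# padManual(rick, [1, 1, 1, 1]).shape should match nn.ConstantPad2d([1, 1, 1, 1], 0)(rick).shape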

Here's the implementation. I changed bits here and there to make it generalize better to two dimensions (conv5() sort of expects the incoming kernel to be square), and the padding is applied to the input before convolving, just like the built-in does:

In [664]:
def conv6(imgs, kernel, stride=[1, 1], padding=[0, 0, 0, 0]):
    sY, sX = (stride, stride) if type(stride) == int else stride
    if type(padding) == int: padding = [padding] * 4 # [left, right, top, bottom], like nn.ConstantPad2d
    imgs = nn.ConstantPad2d(padding, 0)(imgs) # pad the input before convolving, like the built-in
    imgs = imgs.permute(0, 2, 3, 1)
    kernel = kernel.permute(2, 3, 1, 0)
    ksY, ksX, inChannels, outChannels = kernel.shape
    pY, pX = ksY - 1, ksX - 1 # internal padding, to be used at different places
    samples, height, width, _ = imgs.shape
    transformed = torch.cuda.FloatTensor(samples, height-pY, width-pX, outChannels).zero_()
    for yK in range(ksY): # yKernel
        for xK in range(ksX): # xKernel
            transformed += (imgs @ kernel[yK, xK])[:, yK:height-pY+yK, xK:width-pX+xK, :]
    return transformed[:, ::sY, ::sX, :].permute(0, 3, 1, 2)

Let's test it out:

In [702]:
rick.shape, conv6(rick, advancedKernel(edgeDetect), padding=1).shape
Out[702]:
(torch.Size([1, 3, 860, 1600]), torch.Size([1, 3, 860, 1600]))

Yep, it's working.
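Since the padding goes on the input, the values (not just the shape) should also line up with the built-in. I haven't run this check here, but it would look something like:

a = conv6(rick, advancedKernel(edgeDetect), stride=2, padding=1)
b = F.conv2d(rick, advancedKernel(edgeDetect), stride=2, padding=1)
torch.abs(a - b).sum() # should come out as 0 if the two agree exactly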

Transpose convolution

Let's review what they are first. This is another of the videos from that Udacity course:

In [748]:
from IPython.lib.display import YouTubeVideo
YouTubeVideo('hnnLAC1Q0zg', width=800, height=450)
Out[748]:

The idea is pretty similar to the normal convolution case. Pretty much the same thing happens everywhere, except for a tiny change to the kernel and a slightly different indexing scheme for the transformed tensor. The general idea is still to multiply the original image with one element of the kernel at a time, but this time, instead of slicing the multiplied image, we slice the target tensor, to scatter the result into specific spots:

In [641]:
def tConv6(imgs, kernel, stride=[1, 1], padding=[0, 0]):
    sY, sX = (stride, stride) if type(stride) == int else stride
    pdY, pdX = (padding, padding) if type(padding) == int else padding
    imgs = imgs.permute(0, 2, 3, 1)
    kernel = kernel.permute(2, 3, 0, 1) # (in, out, kH, kW) -> (kH, kW, in, out); channel order is flipped compared to conv6()
    ksY, ksX, inChannels, outChannels = kernel.shape
    pY, pX = ksY - 1, ksX - 1 # internal padding, to be used at different places
    samples, height, width, _ = imgs.shape
    transformed = torch.cuda.FloatTensor(samples, (height-1)*sY+1+pY, (width-1)*sX+1+pX, outChannels).zero_()
    for yK in range(ksY): # yKernel
        for xK in range(ksX): # xKernel
            # scatter the weighted image into every stride-th location of the target
            transformed[:, yK:yK+height*sY:sY, xK:xK+width*sX:sX, :] += (imgs @ kernel[yK, xK])
    out = transformed.permute(0, 3, 1, 2)
    return out[:, :, pdY:out.shape[2]-pdY, pdX:out.shape[3]-pdX] # like the built-in, padding crops the output from both sides
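A quick note on the shape arithmetic in there: with stride s and a k×k kernel (and nothing cropped off), an input of size H comes out as (H - 1) * s + k. We'll see exactly this a bit further down, when a 429×799 image becomes 859×1599:

# transpose convolution output size, no cropping: (in - 1) * stride + kernel_size
(429 - 1) * 2 + 3, (799 - 1) * 2 + 3   # (859, 1599)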

I hope it looks familiar to you. If not, you can always try rerunning this notebook. Let's make sure we're numerically identical to the built-in implementation:

In [703]:
kernel = advancedKernel(edgeDetect)
a = tConv6(rick, kernel, stride=2)
b = F.conv_transpose2d(rick, kernel, stride=2)
torch.abs(a - b).sum()
Out[703]:
tensor(0., device='cuda:0')

Yep, it's working as expected. Let's have fun with this and see what happens when you pass images through it. Let's scale the image down using a normal convolution with stride 2 first:

In [746]:
embossedRick = conv6(rick, advancedKernel(emboss), stride=2, padding=0)
edgyRick = conv6(rick, advancedKernel(edgeDetect), stride=2, padding=0)
plt.figure(num=None, figsize=(10, 6), dpi=350)
plt.imshow(edgyRick[0].permute(1, 2, 0).cpu()/256)
edgyRick.shape
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[746]:
torch.Size([1, 3, 429, 799])

We've seen this edgy Rick a lot already, so nothing new here. We also made an embossed version above, for later. Now let's apply the transpose convolution:

In [714]:
plt.figure(num=None, figsize=(10, 6), dpi=350)
a = tConv6(edgyRick, advancedKernel(edgeDetect), stride=2)
plt.imshow(a[0].permute(1, 2, 0).cpu()/256)
a.shape
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[714]:
torch.Size([1, 3, 859, 1599])

Hmmmm, interesting: we're almost back to the original size (859×1599, following the formula from earlier). How about passing this through another transpose layer, but this time with stride 1?

In [713]:
plt.figure(num=None, figsize=(10, 6), dpi=350)
a = tConv6(edgyRick, advancedKernel(edgeDetect), stride=2)
a = tConv6(a, advancedKernel(edgeDetect), stride=1)
plt.imshow(a[0].permute(1, 2, 0).cpu()/256)
a.shape
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[713]:
torch.Size([1, 3, 861, 1601])

How about another layer, also with stride 1?

In [716]:
plt.figure(num=None, figsize=(10, 6), dpi=350)
a = tConv6(edgyRick, advancedKernel(edgeDetect), stride=2)
a = tConv6(a, advancedKernel(edgeDetect), stride=1)
a = tConv6(a, advancedKernel(edgeDetect), stride=1)
plt.imshow(a[0].permute(1, 2, 0).cpu()/256)
a.shape
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[716]:
torch.Size([1, 3, 863, 1603])

You can see how this can help with upsampling an image. As the image passes through more and more transpose convolutions, the influence of each input pixel widens, until it touches every pixel of the output. It really does have the feel of upsampling, while the regular convolution has the feel of distilling information down.

There are definitely weird behaviors though:

In [738]:
plt.figure(num=None, figsize=(10, 6), dpi=350)
plt.imshow(tConv6(embossedRick, advancedKernel(edgeDetect), stride=2)[0].permute(1, 2, 0).cpu()/256)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[738]:
<matplotlib.image.AxesImage at 0x7fdb1faa58d0>
In [737]:
plt.figure(num=None, figsize=(10, 6), dpi=350)
plt.imshow(tConv6(embossedRick, advancedKernel(identity), stride=2)[0].permute(1, 2, 0).cpu()/256)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[737]:
<matplotlib.image.AxesImage at 0x7fdb1fb30990>
In [736]:
plt.figure(num=None, figsize=(10, 6), dpi=350)
plt.imshow(tConv6(embossedRick, advancedKernel(emboss), stride=2)[0].permute(1, 2, 0).cpu()/256)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[736]:
<matplotlib.image.AxesImage at 0x7fdb1fbbca50>

There doesn't seem to be any meaning (that we care about) in what the transpose convolution actually does to an image, though.

Benchmarks

Finally, let's benchmark these with the tools we developed earlier:

In [741]:
performance(tConv6, advancedKernel(edgeDetect))
In [742]:
performance(F.conv_transpose2d, advancedKernel(edgeDetect))

Our implementation is around 6 times slower than the state of the art. Considering how long it takes to do convolutions the vanilla way, I'd say this is a success.

I still find it very strange that the performance across strides is pretty much the same for the built-in function. I mean, surely they must have thought about taking advantage of larger strides?

Now what's left is to build a module around tConv6(), but that is essentially identical to the normal convolution case:

In [747]:
class ConvTranspose2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        kSqrt = 1/np.sqrt(in_channels*kernel_size**2)
        self.kernel = nn.Parameter(torch.rand(in_channels, out_channels, kernel_size, kernel_size)*2*kSqrt-kSqrt)
        self.bias = nn.Parameter(torch.rand(out_channels).unsqueeze(-1).unsqueeze(-1)*2*kSqrt-kSqrt)
    def forward(self, imgs):
        return tConv6(imgs, self.kernel) + self.bias

Pretty much the only changes are swapping in_channels and out_channels in the kernel shape and replacing the normal convolution with the transpose one. Pretty straightforward. I haven't tested the code above, but I'm confident it'll do as expected.
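If you do want a quick sanity check, an (untested) smoke test on random data might look something like this:

tconv = ConvTranspose2d(3, 8).cuda() # 3 input channels, 8 output channels
x = torch.rand(2, 3, 32, 32).cuda()  # a dummy batch
tconv(x).shape                       # expect torch.Size([2, 8, 34, 34]), since (32-1)*1 + 3 = 34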

That's all I have for today. Hope you guys learned something new and see you in the next post.