Date: June 25th, 2020
Today, we will try to recreate Adam, a pretty badass optimizer that can deal with a lot of internal network instability. This blackbox post will include some analysis, so it can feel like a deep dive post too. Let's first import a bunch of things and load MNIST. Everything is pretty standard:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision
import torch.optim as optim
import time
dl = torch.utils.data.DataLoader(datasets.MNIST("datasets/", download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])), batch_size=1000, shuffle=True)
dl_test = torch.utils.data.DataLoader(datasets.MNIST("datasets/", download=True, train=False,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])), batch_size=100, shuffle=True)
imgs, labels = next(iter(dl))
print(f"Images shape: {imgs.shape}, labels shape: {labels.shape}")
print(f"Image range: ({imgs.min()}, {imgs.max()}), average: {imgs.mean()}, std: {imgs.std()}")
print(f"Labels: {labels[:15]}")
plt.figure(num=None, figsize=(10, 6), dpi=350)
for i in range(15): plt.subplot(1, 15, i+1); plt.imshow(imgs[i][0]); plt.axis("off")
Now let's think about this. Suppose we have actually created Adam: how can we tell it apart from other optimizers? While writing a previous post about basic blackboxes, I found something interesting while trying to recreate nn.CrossEntropyLoss. I tried to create other loss functions for categorical recognition. I made LinearLoss ($1-x$) and InverseLoss ($x^{-\alpha}$) just to demonstrate that we don't necessarily have to use negative log loss ($-\log(x)$). The interesting bit is that $x^{-1}$ doesn't work with regular SGD, but $x^{-0.5}$ works fine. If I use Adam instead of SGD, it can handle both $x^{-1}$ and $x^{-0.5}$.
After observing this phenomenon, I tried to explain why $-\log(x)$ is still better than $x^{-0.5}$, and the TL;DR version is that $-\log(x)$ works well with Softmax. With that in mind, it seems like we can intentionally make the optimizers' lives harder and tell them apart that way. We should expect optimizers to struggle more as $\alpha$ increases. So, let's test it out.
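One way to see why a larger $\alpha$ should be harder, sketched with a single hypothetical 10-class sample whose correct class (index 0) is scored badly. Through Softmax, the gradient of $-\log(x)$ with respect to the correct class's logit is $x-1$, so its magnitude never exceeds 1, while for $x^{-\alpha}$ it is $-\alpha x^{-\alpha}(1-x)$, which blows up as $x\to0$ and grows quickly with $\alpha$. This ignores the post's $+0.0001$ and $-1$ offsets; the logits and the helper `grad_wrt_correct_logit` are made up for illustration:

```python
import torch

# hypothetical logits for one sample: the correct class (index 0) is scored badly
z = torch.tensor([-3.0] + [0.0] * 9, requires_grad=True)

def grad_wrt_correct_logit(loss_of_p):
    """|dL/dz_0| when the loss is a function of the Softmax probability p of class 0."""
    p = torch.softmax(z, dim=0)[0]
    (g,) = torch.autograd.grad(loss_of_p(p), z)
    return g[0].abs().item()

grads = {"-log(x)": grad_wrt_correct_logit(lambda p: -torch.log(p))}
for alpha in [0.5, 1.0, 5.75]:
    grads[f"x^-{alpha}"] = grad_wrt_correct_logit(lambda p, a=alpha: p ** -a)

for name, g in grads.items():
    print(f"{name:>8}: |dL/dz_0| = {g:.3g}")
```

The $-\log(x)$ gradient stays below 1, while the $x^{-5.75}$ one is many orders of magnitude larger, which is the kind of scale an optimizer has to cope with here.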
I don't just copy the code directly from the basics blackbox post. I take its core algorithm and wrap it so that it's a valid module:
class CrossEntropyLoss(nn.Module):
    def __init__(self): super().__init__()
    def forward(self, output, correctLabels):
        # row-wise Softmax, then the negative log of every probability
        nlX = -torch.log(torch.exp(output)/torch.exp(output).sum(1).unsqueeze(-1))
        # pick each row's correct-class entry and average over the batch
        return nlX[range(len(correctLabels)), correctLabels].sum()/len(correctLabels)

class InverseLoss(nn.Module):
    def __init__(self, exponent=-0.55): super().__init__(); self.exponent = exponent
    def forward(self, output, correctLabels):
        # note: .sum() without a dim normalizes over the whole batch, not per row
        # (unlike CrossEntropyLoss above), so these "probabilities" are tiny
        batchSize = len(correctLabels); output = torch.exp(output)/torch.exp(output).sum()
        # the 0.0001 is to prevent division by 0
        return (torch.pow(output[torch.arange(batchSize), correctLabels]+0.0001, self.exponent)-1).sum()/batchSize

class CustomNet1(nn.Module):
    def __init__(self, lossFunction=nn.CrossEntropyLoss()):
        super().__init__(); self.fc1 = nn.Linear(28 * 28, 100); self.fc2 = nn.Linear(100, 10)
        self.relu = nn.ReLU(); self.lossFunction = lossFunction
    def forward(self, x): return self.fc2(self.relu(self.fc1(x)))
    # note: this shadows nn.Module.train(mode), which is fine as long as .eval() is never called
    def train(self, optimizer, epochs):
        count, losses, accuracies = 0, [], []
        for epoch in range(epochs):
            for imgs, labels in dl:
                loss = self.lossFunction(self(imgs.view(-1, 28*28).cuda()), labels.cuda())
                loss.backward(); optimizer.step(); optimizer.zero_grad(); count += 1
                if count % 10 == 0:  # every 10 batches: log the loss, evaluate on one test batch
                    losses.append(loss.item()); test_imgs, test_labels = next(iter(dl_test))
                    with torch.no_grad():
                        output = self(test_imgs.cuda().view(-1, 28*28))
                    # test batches hold 100 images, so the number correct is the accuracy in %
                    accuracies.append((torch.argmax(output, dim=1) == test_labels.cuda()).sum().item())
        return losses, accuracies
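A quick sanity check that the hand-written CrossEntropyLoss agrees with the built-in one. This is a sketch on fake CPU data (random logits and labels, assumed here), reusing the same formula as the class above:

```python
import torch
import torch.nn as nn

class CrossEntropyLoss(nn.Module):
    """Same formula as above: row-wise Softmax, then mean negative log-likelihood."""
    def forward(self, output, correctLabels):
        nlX = -torch.log(torch.exp(output) / torch.exp(output).sum(1).unsqueeze(-1))
        rows = torch.arange(len(correctLabels))
        return nlX[rows, correctLabels].sum() / len(correctLabels)

torch.manual_seed(0)
logits = torch.randn(8, 10)           # fake batch: 8 samples, 10 classes
labels = torch.randint(0, 10, (8,))   # fake labels

custom = CrossEntropyLoss()(logits, labels)
builtin = nn.CrossEntropyLoss()(logits, labels)
print(custom.item(), builtin.item())  # should agree up to float rounding
```

The explicit `torch.exp` is fine for MNIST-sized logits, but it overflows in float32 once a logit goes past roughly 88; `torch.log_softmax` computes the same quantity without that problem.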
Let's now copy over the vanilla SGD optimizer from the blackbox basics post again:
class CustomSGD:
    def __init__(self, parameters, lr=0.01): self.parameters = list(parameters); self.lr = lr
    def zero_grad(self):
        for param in self.parameters:
            if param.grad is not None: param.grad.data.zero_()
    def step(self):
        with torch.no_grad():
            for param in self.parameters:
                # plain gradient descent: w <- w - lr * grad
                if param.grad is not None: param -= self.lr * param.grad
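To make sure the copied optimizer really is vanilla SGD, we can compare it step for step against `torch.optim.SGD` (no momentum) on a toy quadratic. The `run` helper and the objective are assumptions made up for this check:

```python
import torch

class CustomSGD:
    """Copy of the optimizer above (plain SGD, no momentum)."""
    def __init__(self, parameters, lr=0.01):
        self.parameters = list(parameters); self.lr = lr
    def zero_grad(self):
        for param in self.parameters:
            if param.grad is not None: param.grad.zero_()
    def step(self):
        with torch.no_grad():
            for param in self.parameters:
                if param.grad is not None: param -= self.lr * param.grad

def run(make_opt, steps=10):
    """Hypothetical helper: minimize ||w - 1||^2 from a fixed random start."""
    torch.manual_seed(0)
    w = torch.randn(3, requires_grad=True)
    opt = make_opt([w])
    for _ in range(steps):
        opt.zero_grad()
        loss = ((w - 1.0) ** 2).sum()
        loss.backward()
        opt.step()
    return w.detach()

mine = run(lambda ps: CustomSGD(ps, lr=0.1))
ref = run(lambda ps: torch.optim.SGD(ps, lr=0.1))
print(mine, ref)  # should be identical up to float rounding
```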
Let's have a convenience function for plotting stuff:
def plot(losses, accuracies):
    plt.figure(num=None, figsize=(10, 3), dpi=350)
    plt.subplot(1, 2, 1); plt.grid(True); plt.plot(losses); plt.title("Loss")
    plt.subplot(1, 2, 2); plt.grid(True); plt.plot(accuracies); plt.title("Accuracy (%)")
So with this insight, let's benchmark several scenarios:
Let's try it out with nothing strange, just to make sure everything still works as predicted:
net = CustomNet1(nn.CrossEntropyLoss()).cuda()
plot(*net.train(optim.Adam(net.parameters(), 0.01), 2))
And one more control case, this time with our own CrossEntropyLoss, because we're doing good science:
net = CustomNet1(CrossEntropyLoss()).cuda()
plot(*net.train(optim.Adam(net.parameters(), 0.01), 2))
Now let's bring on the challenge. At the $-0.5$ level, SGD still works:
net = CustomNet1(InverseLoss(-0.5)).cuda()
plot(*net.train(CustomSGD(net.parameters(), 0.01), 2))
Now let's up the challenge a little bit:
net = CustomNet1(InverseLoss(-1)).cuda()
plot(*net.train(CustomSGD(net.parameters(), 0.01), 2))
As we can see, $x^{-1}$ really gives SGD a hard time.
net = CustomNet1(InverseLoss(-1)).cuda()
plot(*net.train(optim.SGD(net.parameters(), 0.01), 2))
Here, the built-in SGD has the same behavior.
net = CustomNet1(InverseLoss(-0.3)).cuda()
plot(*net.train(optim.SGD(net.parameters(), 0.01, momentum=0.9, nesterov=True), 2))
Interestingly, this is even worse than vanilla SGD: it can't even deal with $x^{-0.5}$, only with the easier loss function $x^{-0.3}$. I did not expect this at all, as Nesterov momentum is supposed to give SGD a pretty major boost, so I think this phenomenon deserves more attention.
Now let's give Adam a challenge:
net = CustomNet1(InverseLoss(-5.75)).cuda()
plot(*net.train(optim.Adam(net.parameters(), 0.01), 2))
Doing pretty well so far. Note the loss's scale: the values are insanely high, on the order of $10^{21}$. Training typically breaks down at such extreme scales, yet Adam can still learn. And why $-5.75$ and not some other value? Well, I experimented around and found that $-5.75$ is pretty much the upper limit for Adam. If I push it to $-5.8$ or $-6$, then Adam can't find a way through.
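One plausible reason Adam survives loss values this large: its update divides $\hat{m}_t$ by $\sqrt{\hat{v}_t}$, so multiplying the whole loss by a constant $c$ scales both by $c$ and the step barely changes, while SGD's step grows linearly with $c$. A minimal sketch on a toy quadratic (not the post's network; the `run` helper is hypothetical):

```python
import torch

def run(opt_cls, scale, steps=3, lr=0.01):
    """Minimize scale * ||w||^2 for a few steps from a fixed start; return the weights."""
    torch.manual_seed(0)
    w = torch.randn(5, requires_grad=True)
    opt = opt_cls([w], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = scale * (w ** 2).sum()   # same objective, multiplied by a constant
        loss.backward()
        opt.step()
    return w.detach()

adam_small, adam_huge = run(torch.optim.Adam, 1.0), run(torch.optim.Adam, 1e6)
sgd_small, sgd_huge = run(torch.optim.SGD, 1.0), run(torch.optim.SGD, 1e6)

print((adam_small - adam_huge).abs().max())  # ~0: Adam barely notices the scale
print(sgd_small.abs().max(), sgd_huge.abs().max())  # SGD's weights explode at the large scale
```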
Let's test this on other optimizers as well, just to get a feel for where they stand:
net = CustomNet1(InverseLoss(-5.75)).cuda()
plot(*net.train(optim.Adagrad(net.parameters(), 0.01), 2))
I experimented around and saw that Adagrad's limit is slightly below $-5.75$. With something like $-5$, Adagrad can reliably train. Let's just run it for completeness:
net = CustomNet1(InverseLoss(-5)).cuda()
plot(*net.train(optim.Adagrad(net.parameters(), 0.01), 2))
Interestingly, for some reason, Rprop actually managed to get reasonably good results at a much higher threshold of $-8.5$. At $-9$ and beyond, it starts to become unstable and unreliable. This phenomenon deserves more attention.
net = CustomNet1(InverseLoss(-8.5)).cuda()
plot(*net.train(optim.Rprop(net.parameters(), 0.01), 2))
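A plausible explanation, offered as a hypothesis: Rprop only uses the sign of each gradient component (each weight's step size grows while the sign stays the same and shrinks when it flips), so the loss's scale doesn't enter the update at all. A minimal sketch on a toy quadratic, with a hypothetical helper `run_rprop`:

```python
import torch

def run_rprop(scale, steps=20):
    """Minimize scale * ||w - 3||^2 with Rprop from a fixed start; return the weights."""
    torch.manual_seed(0)
    w = torch.randn(5, requires_grad=True)
    opt = torch.optim.Rprop([w], lr=0.01)
    for _ in range(steps):
        opt.zero_grad()
        (scale * ((w - 3.0) ** 2).sum()).backward()
        opt.step()
    return w.detach()

a, b = run_rprop(1.0), run_rprop(1e8)
print((a - b).abs().max())  # the same sign pattern gives the same trajectory
```

If this is the reason, then Rprop's eventual failure at $-9$ would come from something other than scale, such as gradients overflowing or losing precision, but that is a guess.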
And Adadelta, at $-6$ (with a learning rate of 1):
net = CustomNet1(InverseLoss(-6)).cuda()
plot(*net.train(optim.Adadelta(net.parameters(), 1), 2))
This is slightly better than Adagrad. Adadelta and Adagrad are basically the same idea, except that Adadelta adapts better over time: Adagrad divides by the square root of the running sum of all past squared gradients, so its effective learning rate shrinks monotonically and will eventually freeze the network, while Adadelta uses a decaying average instead. We only train for 2 epochs though, so the effect is not very noticeable with our performance metric.
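A tiny sketch of that difference, assuming a constant gradient of 1 and a base learning rate of 0.01. It only models the denominators: Adagrad's running sum versus the decaying average of squared gradients that Adadelta (like RMSprop) uses; Adadelta's numerator, a running RMS of past updates, is left out:

```python
import math

lr, g, rho, eps = 0.01, 1.0, 0.9, 1e-10   # constant gradient g = 1 (assumption)
sum_sq, avg_sq = 0.0, 0.0
adagrad_lr, decaying_lr = [], []
for t in range(1, 101):
    sum_sq += g ** 2                              # Adagrad: sum of ALL past squared gradients
    avg_sq = rho * avg_sq + (1 - rho) * g ** 2    # Adadelta/RMSprop-style decaying average
    adagrad_lr.append(lr / (math.sqrt(sum_sq) + eps))
    decaying_lr.append(lr / (math.sqrt(avg_sq) + eps))

print(adagrad_lr[0], adagrad_lr[-1])    # shrinks like 1/sqrt(t): about 0.01 -> 0.001
print(decaying_lr[0], decaying_lr[-1])  # settles back near 0.01 and stays there
```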
Also, it's interesting how Adam, Adagrad and Adadelta all cluster around the $-5.7$ value. Although they share similar characteristics, their actual algorithms are quite different, so maybe we have discovered a new invariant, and it may be worthwhile to investigate this further.
Now that we can tell optimizers apart, we'll be able to check that the optimizer we create is actually Adam, so let's jump right in. Let's look at some interesting metrics and observations from the original paper:
This is Adam's overall performance. On regular training where the updates are not sparse, like image recognition, Adam works really, really well and is pretty much on par with SGD with Nesterov momentum (see the original paper if you're interested). Adagrad doesn't do so well here, because it's built for sparse datasets, so it converges more slowly. On the IMDB bag-of-words model, the opposite happens: Adagrad converges pretty rapidly, and SGD with Nesterov does poorly.
Before Adam, SGD with Nesterov was the hot thing people used for non-sparse problems, and Adagrad for sparse ones. So for Adam to just swoop in out of the blue and be the best of both worlds is... disgusting. How can there be any competition left for other optimizers?
This is the algorithm itself:
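For reference, here is the same algorithm (Algorithm 1 in the Adam paper) written out in LaTeX, with $g_t^2$ meaning the elementwise square and the paper's defaults $\alpha=0.001$, $\beta_1=0.9$, $\beta_2=0.999$, $\epsilon=10^{-8}$. It is transcribed for convenience, so double-check it against the paper if anything looks off:

$$
\begin{aligned}
&\textbf{Require: } \alpha \text{ (step size)},\ \beta_1, \beta_2 \in [0, 1),\ f(\theta) \text{ (stochastic objective)},\ \theta_0 \text{ (initial parameters)}\\
&m_0 \leftarrow 0,\quad v_0 \leftarrow 0,\quad t \leftarrow 0\\
&\textbf{while } \theta_t \text{ not converged:}\\
&\quad t \leftarrow t + 1\\
&\quad g_t \leftarrow \nabla_\theta f_t(\theta_{t-1})\\
&\quad m_t \leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1)\cdot g_t\\
&\quad v_t \leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2)\cdot g_t^2\\
&\quad \hat{m}_t \leftarrow m_t/(1-\beta_1^t)\\
&\quad \hat{v}_t \leftarrow v_t/(1-\beta_2^t)\\
&\quad \theta_t \leftarrow \theta_{t-1} - \alpha\cdot \hat{m}_t/(\sqrt{\hat{v}_t}+\epsilon)\\
&\textbf{return } \theta_t
\end{aligned}
$$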
Okay, so what the hell is going on here? Why are there so many variables and terms? How do they all work together, what's the intuition, and how could someone like you or me come up with this on our own?
First, notice that $\beta_1$ and $\beta_2$ are always in the range $[0, 1)$. Also notice that $\beta_1$ is used in the $m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot g_t$ equation. More specifically, it's used as a pair of $\beta_1$ and $1-\beta_1$. This looks exactly like linear interpolation. For example, if you have the numbers 200 and 500, then the number that's 30% of 200 and 70% of 500 is $200\cdot0.3+500\cdot0.7=410$, where $\beta = 0.3, 1-\beta = 0.7$. If $\beta=0$, meaning 0% of 200 and 100% of 500, the result is 500, which makes sense, as we're taking 100% of 500.
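The interpolation arithmetic, and the fact that Adam's $m_t$ update has exactly this shape, can be checked directly. `torch.lerp(start, end, weight)` computes `start + weight * (end - start)`; the moment and gradient values below are arbitrary:

```python
import torch

# 30% of 200 and 70% of 500, written as beta * a + (1 - beta) * b
beta, a, b = 0.3, 200.0, 500.0
mix = beta * a + (1 - beta) * b
print(mix)  # 410, up to float rounding

# lerp(b, a, beta) = b + beta * (a - b) is the same interpolation
lerped = torch.lerp(torch.tensor(b), torch.tensor(a), beta)
print(lerped)

# Adam's first-moment update has the same shape: m_t = lerp(g_t, m_{t-1}, beta1)
beta1 = 0.9
m_prev, g = torch.tensor([0.5, -1.0]), torch.tensor([2.0, 3.0])
m_t = beta1 * m_prev + (1 - beta1) * g
print(torch.allclose(m_t, torch.lerp(g, m_prev, beta1)))
```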
So, this is probably a linear interpolation. This also means that $\beta_1$ and $\beta_2$ are unitless quantities, because they're just ratios, and that $m_t, m_{t-1}, g_t$ should all have the same unit, whatever that is. Playing the same game, we can conclude that $v_t, v_{t-1}, g_t^2$ should have the same unit.
We also have the equations $\hat{m}_t\leftarrow m_t/(1-\beta_1^t)$ and $\hat{v}_t\leftarrow v_t/(1-\beta_2^t)$. Recall that $\beta_1$ and $\beta_2$ are unitless, so we can also conclude that $\hat{m}_t$ has the same unit as $m_t$, and likewise for $\hat{v}_t$ and $v_t$. Let's now analyze how the terms $\beta_1^t$ and $\beta_2^t$ in these corrections behave:
t = torch.linspace(1, 1000, 100); plt.figure(num=None, figsize=(10, 6), dpi=350)  # 100 points from 1 to 1000
plt.plot(t, 0.9**t, label=r'$\beta_1^t$'); plt.plot(t, 0.999**t, label=r'$\beta_2^t$')
plt.legend(); plt.grid(); plt.xlabel("Iterations"); plt.ylabel("Value"); pass
So, with the default values, $\beta_1^t$ decays much, much quicker than $\beta_2^t$, and even after 1000 iterations, $\beta_2^t=0.999^{1000}\approx0.37$ is still sizeable. As you can see, even without domain knowledge, we can figure out a lot from the equations alone. Now, I'll try to explain the more complicated bits.
The convention is that $\theta$ denotes a 1d array of all weights. In most implementations, it's not as simple as that, as we have layers like nn.Linear, nn.Conv2d, etc. with multidimensional weights. But in most papers, people just squish everything down to a 1d array, because you can do that pretty easily, and the geometric relationships within a multidimensional weight tensor are regarded as unimportant. For Adam this is harmless, since every operation in the algorithm is elementwise.
$\nabla_\theta f_t(\theta_{t-1})$ is just the gradient of the loss function with respect to the weights. It's important to note that $f_t$ is the (stochastic) loss at time step $t$, for example on the $t$-th minibatch, but it is evaluated at the weights $\theta_{t-1}$ from the previous step. This is also a 1d array, as there is only 1 loss value, and the weights are arranged into a 1d array.
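To check the claim that flattening doesn't matter, here is a sketch that takes one `torch.optim.Adam` step on a small nn.Linear and compares it with the same step written by hand on the flattened 1d vector $\theta$, using `torch.nn.utils.parameters_to_vector`. The layer sizes, data and learning rate are arbitrary choices for the example:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

torch.manual_seed(0)
net = nn.Linear(4, 3)                       # weight (3, 4) and bias (3,): 15 numbers in total
x, y = torch.randn(8, 4), torch.randn(8, 3)

theta0 = parameters_to_vector(net.parameters()).detach().clone()
opt = torch.optim.Adam(net.parameters(), lr=0.1)
loss = ((net(x) - y) ** 2).mean()
loss.backward()
g = parameters_to_vector([p.grad for p in net.parameters()]).detach()  # gradient as a 1d array
opt.step()

# the first Adam step (t = 1, m_0 = v_0 = 0), written on the flattened vector
beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 0.1
m = (1 - beta1) * g
v = (1 - beta2) * g ** 2
m_hat = m / (1 - beta1)
v_hat = v / (1 - beta2)
theta1_flat = theta0 - lr * m_hat / (v_hat.sqrt() + eps)

theta1 = parameters_to_vector(net.parameters()).detach()
print(torch.allclose(theta1, theta1_flat, atol=1e-6))
```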
The algorithm also mentions moments. "Moment" is actually a concept in mathematics (you can read more here), but in simple terms, the first moment is the mean, and the second (central) moment is the variance. Normally, the mean is defined as $\mu = \frac{1}{n}\sum^n_{i=1}x_i$ and the variance as $Var=\sigma^2=\frac{1}{n}\sum^n_{i=1}(x_i-\mu)^2$. Adam doesn't average over a fixed set of samples like this. Instead, every weight keeps its own exponential moving average of its gradients over time, because the goal is to provide every weight with its own personalized learning rate, so that different sections of the network can learn at different rates and really adapt to the weight landscape. They also got rid of the $\mu$ term when calculating the variance, so $v_t$ tracks $E[g^2]$, an "uncentered" variance, you might say. Because the moments used are not standard, we can't really apply familiar statistical methods to them, so we should do our own analysis.
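The bias correction falls out of this moving-average view. A short derivation, assuming every gradient $g_i$ comes from the same distribution so that $\mathbb{E}[g_i]=\mathbb{E}[g]$ (the paper's own argument adds an error term for when this doesn't hold). Unrolling the recursion from $m_0=0$ and using the geometric sum $(1-\beta_1)\sum_{k=0}^{t-1}\beta_1^k=1-\beta_1^t$:

$$
\begin{aligned}
m_t &= (1-\beta_1)\sum_{i=1}^{t}\beta_1^{\,t-i}\,g_i\\
\mathbb{E}[m_t] &= \mathbb{E}[g]\,(1-\beta_1)\sum_{i=1}^{t}\beta_1^{\,t-i} = \mathbb{E}[g]\,(1-\beta_1^t)\\
\mathbb{E}[\hat{m}_t] &= \frac{\mathbb{E}[m_t]}{1-\beta_1^t} = \mathbb{E}[g]
\end{aligned}
$$

The same steps with $g_i^2$ and $\beta_2$ give $\mathbb{E}[\hat{v}_t]=\mathbb{E}[g^2]$. Without the correction, early estimates are dragged toward the zero initialization: with $\beta_1=0.9$, $\mathbb{E}[m_1]$ is only $0.1\,\mathbb{E}[g]$, hence the factor $1/(1-0.9)=10$ on the first step and $1/(1-0.9^2)\approx5.26$ on the second.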
t = torch.linspace(1, 1000, 100); beta1 = 0.9**t; beta2 = 0.999**t; noise = torch.rand(len(t))*2 - 1
def graph(gradient, limit=None):
    m = [0]; v = [0]
    # the recursion takes one step per point of t (the points are ~10 iterations apart),
    # using the constant betas beta1[0] = 0.9 and beta2[0] = 0.999
    for i in range(len(t)):
        m.append(beta1[0] * m[-1] + (1-beta1[0]) * gradient[i])
        v.append(beta2[0] * v[-1] + (1-beta2[0]) * gradient[i]**2)
    m = torch.Tensor(m[1:]); v = torch.Tensor(v[1:])
    # bias correction divides by 1 - beta**t, with t taken from the x-axis values
    mHat = m/(1-beta1); vHat = v/(1-beta2); r = mHat/(vHat**0.5+1e-8)
    plt.figure(figsize=(10, 6), dpi=350)
    if limit is not None: plt.ylim(limit)
    plt.plot(t, beta1, label=r'$\beta_1^t$'); plt.plot(t, beta2, label=r'$\beta_2^t$')
    plt.plot(t, m, label=r'$m$'); plt.plot(t, v, label=r'$v$')
    plt.plot(t[::2], mHat[::2], ".", label=r'$\hat{m}$')
    plt.plot(t, vHat, label=r'$\hat{v}$'); plt.plot(t, r, label=r'$\hat{m}/\sqrt{\hat{v}}$')
    plt.plot(t[::7], gradient[::7], "o", label=r'$\nabla$')
    plt.legend(); plt.grid(); plt.xlabel("Iterations"); plt.ylabel("Value"); plt.show()
The function above expects a gradient to be passed in. It calculates the moments ($m_t, v_t$) and the bias-corrected moments ($\hat{m}_t, \hat{v}_t$), then the effective learning rate $\hat{m}_t/\sqrt{\hat{v}_t}$ (the factor that multiplies the base learning rate in the update), and graphs all of the results, so we can just focus on what's important. So let's look at how Adam behaves:
graph(torch.ones(100))
Okay, interesting, let's see a few more:
This is exactly the same as the above, but with some noise added, to verify that Adam can take a little bit of it and be fine. Also notice that $m\approx\hat{m}\;\forall t>20$, and that $m$ and $\nabla$ aren't really close to each other, but $E[m]\approx E[\nabla]\;\forall t>200$. This should feel familiar and comfortable, as $m$ is the first moment.
graph(1 + noise * 0.2)
The idea is just to see how slowly decreasing gradients act, but it turns out it kind of looks like the $\nabla=1$ case.
graph(t**-0.2)
This, again, demonstrates that the general shape is roughly preserved and can tolerate some noise.
graph(t**-0.2+noise*0.1)
This demonstrates what happens when the noise ball we converge to dips below 0 some of the time (roughly 10% of the time), which is supposed to happen quite a while into training. Because only about 10% of the gradients dip below the 0 line, nothing interesting happens, and the first moments ($m_t, \hat{m}_t$) don't turn negative.
graph(t**-0.2+noise*0.4); (t**-0.2+noise*0.4 > 0).sum()
This demonstrates that the learning rate converges to 0 if the gradient converges to 0, reinforcing a pattern that seems obvious from the equations.
graph(np.log(1000) - torch.log(t), limit=(-0.2, 3))
Here, we're doing the same thing as above, but the gradient's sign flips randomly, to simulate how gradients change sign all the time during training, while the magnitude envelope decreases monotonically. This demonstrates that the final learning rate also converges to 0. Also, see that the green line is kicked up quite violently at the beginning. This is due to the bias correction: $\hat{m}_1=m_1/(1-\beta_1^1)=10m_1$ on the first time step and $\hat{m}_2=m_2/(1-\beta_1^2)\approx5.26m_2$ on the second, which amplifies whatever value is there at the beginning.
If the first gradient is actually negative, then this will kick it to quite a low number. This looks like a deliberate design choice by Adam's authors (the paper introduces the bias correction to counteract $m$ and $v$ starting at 0), and it's quite sensible: it serves as a sort of initial acceleration phase that sets the moments to appropriate values right away. Then, as time moves on and the mean gradient shrinks relative to its spread, $\hat{m}_t/\sqrt{\hat{v}_t}$ typically decreases, which simulates annealing (the paper calls this a form of automatic annealing). It's pretty cool how they managed to slip that in.
graph((np.log(1000) - torch.log(t)) * noise, limit=(-0.5, 1.5))
((np.log(1000) - torch.log(t)) * noise > 0).sum()
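How about we do the same thing, but now squish the gradient really, really low? Will the learning rate immediately converge to 0?
Looks like it takes quite a long time to converge to 0. In fact, the shape shouldn't change at all: multiplying every gradient by a constant $c$ multiplies $\hat{m}_t$ by $c$ and $\sqrt{\hat{v}_t}$ by $|c|$, so the ratio $\hat{m}_t/\sqrt{\hat{v}_t}$ stays the same as long as $\sqrt{\hat{v}_t}\gg\epsilon$. On top of that, the learning rate always starts out at $\pm1$ (on the first step, $\hat{m}_1=g_1$ and $\sqrt{\hat{v}_1}=|g_1|$), and with high momentum values, it's expected to take a long time to come down. Nothing surprising here, but it's comforting to know that it will always start out like that, so we can reason about it somewhat.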
graph((np.log(1000) - torch.log(t)) * noise * 0.02)
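To check the scaling argument numerically, here is a minimal scalar sketch. As an assumption of the sketch, it takes exactly one Adam step per gradient value and uses that same step count in the bias correction; it draws its own random noise, so it won't reproduce the exact curves above. `adam_ratio` is a hypothetical helper:

```python
import torch

def adam_ratio(grads, beta1=0.9, beta2=0.999, eps=1e-8):
    """m_hat / (sqrt(v_hat) + eps) after each step, taking one Adam step per gradient value."""
    m = torch.tensor(0.0)
    v = torch.tensor(0.0)
    ratios = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g ** 2
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        ratios.append(m_hat / (v_hat.sqrt() + eps))
    return torch.stack(ratios)

torch.manual_seed(0)
steps = torch.arange(1, 1001, dtype=torch.float32)
noise = torch.rand(1000) * 2 - 1
grads = (torch.log(torch.tensor(1000.0)) - torch.log(steps)) * noise

r = adam_ratio(grads)
r_squished = adam_ratio(grads * 0.02)   # same gradients, 50x smaller
print((r - r_squished).abs().max())     # tiny: only eps breaks the symmetry
print(adam_ratio(torch.ones(5)))        # constant gradient: the ratio stays at ~1
```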
This just looks cool, so I figured I'd throw it in the mix too.
graph(3 - torch.log(t), limit=(-5, 2))
This is a pure noise environment, which should be typical of a network that has already been trained. As we can see, as long as the gradient's magnitude doesn't decrease, the learning rate stays big and won't converge. The learning rate still starts out at $\pm1$ though, so it's still trying to drag itself over to the middle.
graph(noise)
Now that we have a feel for Adam, can we see potential pros and cons of it? Let's talk about them.
Pros:
Cons:
Additional possible insights:
Okay, let's just create this, shall we?
class CustomAdam:
    def __init__(self, parameters, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.parameters = list(parameters)  # take the parameters out of the generator
        self.lr = lr; self.beta1 = beta1; self.beta2 = beta2; self.t = 0; self.eps = eps
        # one first-moment (m) and second-moment (v) buffer per parameter, on the same device
        self.m = [torch.zeros_like(param) for param in self.parameters]
        self.v = [torch.zeros_like(param) for param in self.parameters]
    def zero_grad(self):
        for param in self.parameters:
            if param.grad is not None: param.grad.data.zero_()
    def step(self):
        self.t += 1
        with torch.no_grad():
            for i in range(len(self.parameters)):
                p = self.parameters[i]
                if p.grad is None: continue
                self.m[i] = self.m[i]*self.beta1 + (1-self.beta1)*p.grad     # m_t
                self.v[i] = self.v[i]*self.beta2 + (1-self.beta2)*p.grad**2  # v_t
                mHat = self.m[i] / (1-self.beta1**self.t)  # bias-corrected first moment
                vHat = self.v[i] / (1-self.beta2**self.t)  # bias-corrected second moment
                p.sub_(self.lr * mHat/(torch.sqrt(vHat) + self.eps))
net = CustomNet1(CrossEntropyLoss()).cuda()
plot(*net.train(CustomAdam(net.parameters(), 0.01), 2))
net = CustomNet1(InverseLoss(-5.5)).cuda()
plot(*net.train(CustomAdam(net.parameters(), 0.01), 2))
Yep, it seems like the thing we created is truly Adam.
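A stricter check than comparing loss curves is to run CustomAdam and `torch.optim.Adam` side by side on the same tiny problem and compare the weights. This sketch uses a compact copy of the class above, and the toy regression setup (`train`, layer sizes, data) is made up for the check:

```python
import torch

class CustomAdam:
    """Compact copy of CustomAdam above (same update rule)."""
    def __init__(self, parameters, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.parameters = list(parameters)
        self.lr, self.beta1, self.beta2, self.eps, self.t = lr, beta1, beta2, eps, 0
        self.m = [torch.zeros_like(p) for p in self.parameters]
        self.v = [torch.zeros_like(p) for p in self.parameters]
    def zero_grad(self):
        for p in self.parameters:
            if p.grad is not None: p.grad.zero_()
    def step(self):
        self.t += 1
        with torch.no_grad():
            for i, p in enumerate(self.parameters):
                if p.grad is None: continue
                self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * p.grad
                self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * p.grad ** 2
                m_hat = self.m[i] / (1 - self.beta1 ** self.t)
                v_hat = self.v[i] / (1 - self.beta2 ** self.t)
                p.sub_(self.lr * m_hat / (torch.sqrt(v_hat) + self.eps))

def train(make_opt, steps=50):
    """Hypothetical helper: fit a small linear layer to random targets, return its weights."""
    torch.manual_seed(0)
    net = torch.nn.Linear(4, 2)
    x, y = torch.randn(32, 4), torch.randn(32, 2)
    opt = make_opt(net.parameters())
    for _ in range(steps):
        opt.zero_grad()
        loss = ((net(x) - y) ** 2).mean()
        loss.backward()
        opt.step()
    return [p.detach().clone() for p in net.parameters()]

mine = train(lambda ps: CustomAdam(ps, lr=0.01))
ref = train(lambda ps: torch.optim.Adam(ps, lr=0.01))
print(max((a - b).abs().max().item() for a, b in zip(mine, ref)))  # ~0, float rounding only
```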