Learning simple functions, network exhaustion point

In the last post, I left a lot of potential on the table, and it really, really bugs me not to follow up right away. I mean, it's quite hard to just stop implementing these ideas, you know; they taste so damn good.

So what I really wanted to do right away is to get at least some metrics on network capability, so we're not aiming blind when doing these things. Splitting the ReLU output into different segments is what really, really pushed me. So let's get the party started:

In [336]:
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
In [337]:
class FunctionDataset(Dataset):
    def __init__(self, function: callable, start: float=-5, stop: float=5, samples: int=300):
        self.function = function
        self.start = start
        self.stop = stop
        self.samples = samples
    def __len__(self):
        return self.samples
    def __getitem__(self, index):
        # map index 0..samples-1 linearly onto [start, stop)
        x = index/self.samples * (self.stop - self.start) + self.start
        return x, self.function(x)

We have some new function definitions; let's declare these upfront:

In [338]:
bs = 1280
expF = lambda x: torch.exp(x)
expNF = lambda x: torch.exp(-x)
logF = lambda x: torch.log(x)
invF = lambda x: 1 / x
linF = lambda x: 2 * x + 8
sinF = lambda x: torch.sin(x)
stepF = lambda x: x > 0
expDl = DataLoader(FunctionDataset(lambda x: np.exp(x), samples=10000), batch_size=bs)
expNDl = DataLoader(FunctionDataset(lambda x: np.exp(-x), samples=10000), batch_size=bs)
exp7Dl = DataLoader(FunctionDataset(lambda x: np.exp(x), samples=10000, stop=7), batch_size=bs)
logDl = DataLoader(FunctionDataset(lambda x: np.log(x), samples=10000), batch_size=bs)
invDl = DataLoader(FunctionDataset(invF, samples=10001), batch_size=bs)
linDl = DataLoader(FunctionDataset(linF, samples=10000), batch_size=bs)
sinDl = DataLoader(FunctionDataset(lambda x: np.sin(x), samples=10000), batch_size=bs)
stepDl = DataLoader(FunctionDataset(stepF, samples=10000), batch_size=bs)
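
Quick sanity check on one of these loaders, just to peek at what a batch looks like (not part of the original flow, and nothing below depends on it):

In [ ]:
# one batch from the exp dataloader: x and y are 1-D tensors of length bs
x, y = next(iter(expDl))
print(x.shape, y.shape)
print(x[:3], y[:3])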

Now this is the fun part. Most of the code is the same as last time, but notice the plot function (the new feature is on by default):

In [485]:
# simple (#batch, 1) to (#batch, 1)
class NN(nn.Module):
    def __init__(self, hiddenDim=10, hiddenLayers=2, dropout_p=0, activation=nn.ReLU()):
        super().__init__()
        self.fc_begin = nn.Linear(1, hiddenDim)
        self.fc1 = nn.Linear(hiddenDim, hiddenDim)
        self.fc2 = nn.Linear(hiddenDim, hiddenDim)
        self.fc3 = nn.Linear(hiddenDim, hiddenDim)
        self.fc4 = nn.Linear(hiddenDim, hiddenDim)
        self.fc_end = nn.Linear(hiddenDim, 1)
        self.activation = activation
        self.dropout = nn.Dropout(dropout_p)
        self.totalLosses = []
        self.hiddenLayers = hiddenLayers
        pass
    def forward(self, x):
        x = self.dropout(self.activation(self.fc_begin(x)))
        # really ad-hoc way of doing this; an nn.ModuleList would be the cleaner way to bunch these up into a list
        if self.hiddenLayers >= 1: x = self.dropout(self.activation(self.fc1(x)))
        if self.hiddenLayers >= 2: x = self.dropout(self.activation(self.fc2(x)))
        if self.hiddenLayers >= 3: x = self.dropout(self.activation(self.fc3(x)))
        if self.hiddenLayers >= 4: x = self.dropout(self.activation(self.fc4(x)))
        x = self.fc_end(x)
        return x
    
    def train(self, dl, lossFunction=nn.MSELoss(), optimizer=None, lr=0.01, epochs=500):
        if optimizer == None:
            optimizer = optim.Adam(self.parameters(), lr=lr)
        for epoch in range(epochs):
            totalLoss = 0
            for x, y in dl:
                optimizer.zero_grad()
                x = x.view(-1, 1).float().cuda()
                output = self.forward(x)
                loss = lossFunction(output, y.view(-1, 1).float().cuda())
                loss.backward()
                totalLoss += loss.item()
                optimizer.step()
            totalLoss /= dl.batch_size
            self.totalLosses.append(totalLoss)
    def plot(self, x, function: callable=None, partitionSegments: bool=True):
        plt.figure(num=None, figsize=(10, 3), dpi=350)
        x = x.view(-1, 1)
        if function != None:
            plt.plot(x.cpu(), function(x).cpu(), "--")
        if partitionSegments:
            x, colors = x.cuda(), ["r", "g", "b", "c", "k"]
            last, segmentNumber, epsilon = 1, 0, 1e-3#/(x[1]-x[0])/30
            y = self(x).detach().view(-1)
            x = x.view(-1)
            # finite-difference slope between consecutive sample points; a new linear
            # segment starts wherever this slope changes by more than epsilon
            gradients = (y - y.roll(1))/(x - x.roll(1))
            transition = torch.abs(gradients - gradients.roll(1)) > epsilon
            for step in range(2, len(x)):
                if transition[step]:
                    color = colors[segmentNumber%len(colors)]
                    if segmentNumber%len(colors) == 1:
                        plt.plot(x[last-1:step].cpu(), y[last-1:step].cpu(), "-"+color, lw=4)
                    plt.plot(x[last-1:step].cpu(), y[last-1:step].cpu(), "-"+color)
                    segmentNumber += 1
                    last = step
            else: # for-else: runs after the loop finishes; plot the final remaining segment
                color = colors[segmentNumber%len(colors)]
                if segmentNumber%len(colors) == 1:
                    plt.plot(x[last-1:step].cpu(), y[last-1:step].cpu(), "-"+color, lw=4)
                plt.plot(x[last-1:step].cpu(), y[last-1:step].cpu(), "-"+color)
        else:
            plt.plot(x.squeeze().cpu(), self(x.cuda()).detach().cpu().squeeze(), ".")
        plt.legend(["Real", "Learned"] if function != None else ["Learned"])
        plt.show()
    def plotLosses(self, begin=0, end=0):
        plt.figure(num=None, figsize=(10, 3), dpi=350)
        if end == 0:
            end = len(self.totalLosses)
        plt.plot(range(len(self.totalLosses))[begin:end], self.totalLosses[begin:end])
        plt.legend(["Loss"])
        plt.show()

Network dynamics

It's pretty complicated, not particularly accurate, and it can glitch out if you make the parameters too extreme (because of floating point accuracy limits), but check this out:

In [508]:
net = NN().cuda()
net.plot(torch.linspace(-10, 10, 300))

I made it so that the colors cycle every 5 segments, and all green segments are bold, just to make the density of segments in a particular interval easy to see. Let's train and see how it fares. We're now considering the $(-2, 5)$ interval instead.

In [509]:
interval = torch.linspace(-2, 5, 300)
for i in range(4):
    net.plot(interval)
    net.train(expDl, epochs=30)
    print(f"Epoch: {i*30+30}")
for i in range(3):
    net.plot(interval)
    net.train(expDl, epochs=60, lr=0.001)
    print(f"Epoch: {i*60+60+120}")
net.plot(interval, expF)
net.plotLosses()
Epoch: 30
Epoch: 60
Epoch: 90
Epoch: 120
Epoch: 180
Epoch: 240
Epoch: 300
In [511]:
net.plot(torch.linspace(-10, 10, 300))
net.plot(torch.linspace(-20, 20, 300))

Damn, this is dope. The algorithm has segmented the interval pretty nicely. On the interval $(0, 5)$ in the last 3 graphs, after training has settled, there are around 4 green segments, so we know the segmentation must be roughly right.
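
Since the whole point was to get some actual metrics on network capability, here is a minimal sketch of a segment counter, reusing the same slope-change test as plot. countSegments and its epsilon threshold are just throwaway helpers for this sketch; nothing else in the post uses them:

In [ ]:
def countSegments(net, x, epsilon=1e-3):
    # count approximately-linear pieces of the learned function on a 1-D grid,
    # by looking for jumps in the finite-difference slope (same test as NN.plot)
    x = x.view(-1).cuda()
    y = net(x.view(-1, 1)).detach().view(-1)
    gradients = (y - y.roll(1)) / (x - x.roll(1))
    transitions = torch.abs(gradients - gradients.roll(1)) > epsilon
    # skip the first 2 entries, which wrap around because of roll()
    return transitions[2:].sum().item() + 1

countSegments(net, torch.linspace(0, 5, 300))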

A weird phenomenon is happening though. If you look at epoch 30, just after it has snapped into place, there don't seem to be too many segments. Then as time moves on, it seems to "call in" network capability from the far right. This is not what I was expecting. I expected the capability to be centered around 0, as the initialization predictions in the last post tell us, and then to spread out slowly to accommodate places where the geometry is whack.
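
One rough way to peek at where that capability actually sits is to look at the first-layer breakpoints: a first-layer ReLU unit switches on/off where its pre-activation crosses zero, i.e. at $x = -b_i / w_i$. A quick sketch (first layer only; the deeper layers fold these breakpoints further, so this is only a partial picture):

In [ ]:
# breakpoints of the first-layer ReLU units: where w_i * x + b_i = 0
w = net.fc_begin.weight.detach().view(-1)
b = net.fc_begin.bias.detach()
print(sorted((-b / w).cpu().tolist()))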

Let's look at the snapping more closely, to see what's going on:

In [490]:
net = NN().cuda()
interval = torch.linspace(-2, 7, 300)
for i in range(8):
    net.plot(interval)
    net.train(expDl, epochs=5)
    print(f"Epoch: {i*5+5}")
net.plot(interval)
net.plotLosses()
Epoch: 5
Epoch: 10
Epoch: 15
Epoch: 20
Epoch: 25
Epoch: 30
Epoch: 35
Epoch: 40
In [491]:
net.plot(torch.linspace(-10, 10, 600))
net.plot(torch.linspace(-20, 20, 600))

It's sort of all over the place. The part on the right regularly toggles between having lots of segments and having no segments at all. Maybe this is just a consequence of these networks truly being black boxes. Still, I'm interested in whether we can bunch these segments together and exhaust the network.

Freezing the network using the inverse function

Let's make it exhaust its capability around 0 on the inverse function, then introduce another inverse function with the singularity at $x=2$. Seems like a good plan. Let's train on the regular inverse first:

In [492]:
net = NN().cuda()
interval = torch.linspace(-5, 5, 300)
for i in range(6):
    net.plot(interval)
    net.train(invDl, epochs=30)
    print(f"Epoch: {i*30+30}")
net.plot(interval, invF)
net.plotLosses()
Epoch: 30
Epoch: 60
Epoch: 90
Epoch: 120
Epoch: 150
Epoch: 180

Then train on the shifted inverse:

In [493]:
inv2F = lambda x: 1 / (x - 2)
inv2Dl = DataLoader(FunctionDataset(inv2F, samples=10001), batch_size=bs)
interval = torch.linspace(-3, 7, 300)
for i in range(3):
    net.plot(interval)
    net.train(inv2Dl, epochs=15)
    print(f"Epoch: {i*30+30}")
for i in range(3):
    net.plot(interval)
    net.train(inv2Dl, epochs=100, lr=0.003)
    print(f"Epoch: {i*100+100+90}")
net.plot(interval, inv2F)
net.plotLosses()
Epoch: 30
Epoch: 60
Epoch: 90
Epoch: 190
Epoch: 290
Epoch: 390

Yeah, it's trying very hard, but it is effectively frozen. If I wanted to freeze it even more, I could also train it on the constant function $y=-0.25$ over the interval $(-3, 0)$. That way, pretty much everywhere would be stuck at a constant value.
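
I'm not actually running that extra freezing step here, but as a sketch it would look something like this (constDl is a made-up name, reusing FunctionDataset with a constant lambda):

In [ ]:
# sketch of the extra "freeze it even more" step: pin (-3, 0) to the constant -0.25
constDl = DataLoader(FunctionDataset(lambda x: -0.25, start=-3, stop=0, samples=10000), batch_size=bs)
net.train(constDl, epochs=30)
net.plot(torch.linspace(-3, 7, 300), inv2F)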

Exhausting the network using the sine function

Let's also try out sine, first on the $(-5, 5)$ interval, because it seems like that alone will exhaust everything quickly, and we won't have to use this trick:

In [494]:
sinDl = DataLoader(FunctionDataset(lambda x: np.sin(x), -5, 5, samples=10000), batch_size=bs)
net = NN().cuda()
interval = torch.linspace(-5, 5, 300)
for i in range(6):
    net.plot(interval)
    net.train(sinDl, epochs=15)
    print(f"Epoch: {i*15+15}")
net.plot(interval, sinF)
net.plotLosses()
Epoch: 15
Epoch: 30
Epoch: 45
Epoch: 60
Epoch: 75
Epoch: 90

It seems to still be able to approximate it. And on the $(-10, 10)$ interval:

In [495]:
sinDl = DataLoader(FunctionDataset(lambda x: np.sin(x), -10, 10, samples=10000), batch_size=bs)
net = NN().cuda()
interval = torch.linspace(-10, 10, 300)
for i in range(6):
    net.plot(interval)
    net.train(sinDl, epochs=15)
    print(f"Epoch: {i*15+15}")
net.plot(interval, sinF)
net.plotLosses()
Epoch: 15
Epoch: 30
Epoch: 45
Epoch: 60
Epoch: 75
Epoch: 90

Yep, just as expected, it couldn't do it. Can we do better with more hidden units per hidden layer?

In [496]:
sinDl = DataLoader(FunctionDataset(lambda x: np.sin(x), -10, 10, samples=10000), batch_size=bs)
net = NN(hiddenDim=30).cuda()
interval = torch.linspace(-10, 10, 300)
for i in range(6):
    net.plot(interval)
    net.train(sinDl, epochs=30)
    print(f"Epoch: {i*30+30}")
net.plot(interval, sinF)
net.plotLosses()
Epoch: 30
Epoch: 60
Epoch: 90
Epoch: 120
Epoch: 150
Epoch: 180

Yeah it's better, but the segments are still bunched together near the origin, and will have a hard time going over the humps to make things better. Will it help if we specifically train it on the $(5, 15)$ interval?

In [497]:
sinDl = DataLoader(FunctionDataset(lambda x: np.sin(x), 5, 15, samples=10000), batch_size=bs)
interval = torch.linspace(-10, 10, 300)
for i in range(6):
    net.plot(interval)
    net.train(sinDl, epochs=30)
    print(f"Epoch: {i*30+30}")
net.plot(interval, sinF)
net.plotLosses()
Epoch: 30
Epoch: 60
Epoch: 90
Epoch: 120
Epoch: 150
Epoch: 180

Yeah, that's sort of terrible. The space of all the weights doesn't have any well-defined relationship with the input variable, so we can't really expect the network to learn a representation and retain it while we do other things. This also sort of reinforces that when the underlying distribution is entirely different, not even in the neighborhood of the old one, these networks really, really need to be trained from scratch.

Another important revelation is that initialization is more important than people usually give it credit for. In the past, I leaned more towards "this architecture will do better than that because of x, y and z", or "this can relay the gradient signal better because of x, y and z", and not so much towards these fundamental concepts and problems.

A possible "hack" around this is to build networks with a sense of history and passage of time in them. I'm talking about RNNs, LSTMs and transformers. However, this feels like admitting we don't know how to make this black box work, so we're just going to use a new, more complicated black box to deal with it. We are not thinking simply enough, and we're basically running an extensive search through the space of all possible network architectures, which is a losing battle right there. You've sort of failed even before you've begun. So although doing this may give better results, it doesn't give us any new insights.

Will the network recover from a distributional shift if we use leaky ReLU instead?

Leaky ReLU

The setup should be familiar. We train a network to approximate $e^x$, then try to approximate $e^{-x}$. This is (nearly) the exact same function from the previous post:

In [498]:
def transferLearning1(activation=nn.ReLU()):
    interval = torch.linspace(-5, 5, 300)
    net = NN(activation=activation).cuda()
    print("Before training")
    net.plot(interval)
    net.train(expDl, epochs=180)
    print("After training for 180 epochs")
    for i in range(6):
        net.plot(interval)
        net.train(expNDl, epochs=30)
        print(f"Epoch: {i*30+180+30}")
    net.plot(interval, expNF)
    net.plotLosses()

Let's do regular ReLU:

In [500]:
transferLearning1()
Before training
After training for 180 epochs
Epoch: 210
Epoch: 240
Epoch: 270
Epoch: 300
Epoch: 330
Epoch: 360

Frozen. Okay, we sort of expected that. Let's do leaky ReLU:

In [501]:
transferLearning1(nn.LeakyReLU())
Before training
After training for 180 epochs
Epoch: 210
Epoch: 240
Epoch: 270
Epoch: 300
Epoch: 330
Epoch: 360

Okay yeah, now the default behavior is to learn the new function quickly. Let's test the extreme case of using the 2 inverse functions from above. Training on the regular inverse:

In [502]:
net = NN(activation=nn.LeakyReLU()).cuda()
interval = torch.linspace(-5, 5, 300)
for i in range(6):
    net.plot(interval)
    net.train(invDl, epochs=30)
    print(f"Epoch: {i*30+30}")
net.plot(interval, invF)
net.plotLosses()
Epoch: 30
Epoch: 60
Epoch: 90
Epoch: 120
Epoch: 150
Epoch: 180

Then the shifted inverse:

In [503]:
inv2F = lambda x: 1 / (x - 2)
inv2Dl = DataLoader(FunctionDataset(inv2F, samples=10001), batch_size=bs)
interval = torch.linspace(-3, 7, 300)
for i in range(3):
    net.plot(interval)
    net.train(inv2Dl, epochs=15)
    print(f"Epoch: {i*30+30}")
for i in range(3):
    net.plot(interval)
    net.train(inv2Dl, epochs=100, lr=0.003)
    print(f"Epoch: {i*100+100+90}")
net.plot(interval, inv2F)
net.plotLosses()
Epoch: 30
Epoch: 60
Epoch: 90
Epoch: 190
Epoch: 290
Epoch: 390

Nope, still frozen, but at least it's trying harder than regular ReLU. So there are limits to how much distributional shift leaky ReLUs can take.

Why don't people always use leaky ReLU then? First, regular ReLU is fast: it's essentially a single max-with-zero, which maps to one compare/select per element. Leaky ReLU needs that compare plus a multiply, so we should expect it to be a bit slower. Also, if you're planning on making really high-end custom accelerator chips, like the one Tesla is making, you will want to simplify as much as possible and throw out all the customizable stuff that drags the speed of your chip down.
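
If you want to sanity-check that speed intuition, here is a throwaway micro-benchmark sketch (the synchronize calls matter for GPU timing, and in a full network either activation is usually cheap next to the linear layers anyway):

In [ ]:
import time
x = torch.randn(10_000_000, device="cuda")
for name, f in [("ReLU", nn.ReLU()), ("LeakyReLU", nn.LeakyReLU())]:
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(100):
        f(x)
    torch.cuda.synchronize()
    print(name, time.time() - t0)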

ReLU also sort of introduces an amount of healthy "brain damage". This is like using dropout, but in a weird and permanent way. The general thinking is, if you have the weight initialization down and do everything else roughly right, you can use this brain damage phenomenon as a replacement for dropout. Personally, this has never been too convincing to me, and I think most people just tried ReLU on their project, it worked out of the box, and now they don't want to change anything, or they never have to deal with distributional shift.
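
The "brain damage" here is basically dead ReLU units: units whose pre-activation stays negative over the whole input range, so they never fire and never receive a gradient. A minimal sketch that checks this for the first layer only (deadFraction and its sampling grid are just made up for this sketch):

In [ ]:
def deadFraction(net, x):
    # fraction of first-layer units whose pre-activation never goes above 0 on the grid
    with torch.no_grad():
        pre = net.fc_begin(x.view(-1, 1).cuda())  # shape (samples, hiddenDim)
    return (pre <= 0).all(dim=0).float().mean().item()

# e.g. deadFraction(net, torch.linspace(-5, 5, 1000)) on one of the plain-ReLU nets above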

Step function

Finally, let's just see how both ReLU and leaky ReLU deal with the step function. Seems like it should be very, very easy.

In [504]:
stepDl = DataLoader(FunctionDataset(lambda x: x > 0, samples=10000), batch_size=bs)
net = NN().cuda()
interval = torch.linspace(-5, 5, 300)
for i in range(6):
    net.plot(interval)
    net.train(stepDl, epochs=30)
    print(f"Epoch: {i*30+30}")
net.plot(interval, stepF)
net.plotLosses()
Epoch: 30
Epoch: 60
Epoch: 90
Epoch: 120
Epoch: 150
Epoch: 180

Then train it with the step at 2:

In [505]:
step2F = lambda x: x > 2
step2Dl = DataLoader(FunctionDataset(step2F, samples=10000), batch_size=bs)
interval = torch.linspace(-3, 7, 300)
for i in range(3):
    net.plot(interval)
    net.train(step2Dl, epochs=15)
    print(f"Epoch: {i*30+30}")
for i in range(3):
    net.plot(interval)
    net.train(step2Dl, epochs=100, lr=0.003)
    print(f"Epoch: {i*100+100+90}")
net.plot(interval, step2F)
net.plotLosses()
Epoch: 30
Epoch: 60
Epoch: 90
Epoch: 190
Epoch: 290
Epoch: 390

Okay cool. Can we now freeze this by switching back and forth between these 2 functions?

In [506]:
net = NN().cuda()
interval = torch.linspace(-5, 5, 300)
for i in range(4):
    net.plot(interval)
    net.train(stepDl, epochs=200)
    print(f"Step at 0 - {i}")
    net.plot(interval)
    net.train(step2Dl, epochs=200)
    print(f"Step at 2 - {i}")
net.plot(interval, step2F)
net.plotLosses()
Step at 0 - 0
Step at 2 - 0
Step at 0 - 1
Step at 2 - 1
Step at 0 - 2
Step at 2 - 2
Step at 0 - 3
Step at 2 - 3

Okay, so regular ReLU can adapt, if we don't intentionally try to freeze it. Still, we can see that in all the step-at-2 graphs the transition is not quite vertical, while in all the step-at-0 graphs it shoots straight up. But data in the real world may be even harsher to the network than we are. Let's see how leaky ReLU fares:

In [507]:
net = NN(activation=nn.LeakyReLU()).cuda()
interval = torch.linspace(-5, 5, 300)
for i in range(4):
    net.plot(interval)
    net.train(stepDl, epochs=200)
    print(f"Step at 0 - {i}")
    net.plot(interval)
    net.train(step2Dl, epochs=200)
    print(f"Step at 2 - {i}")
net.plot(interval, step2F)
net.plotLosses()
Step at 0 - 0
Step at 2 - 0
Step at 0 - 1
Step at 2 - 1
Step at 0 - 2
Step at 2 - 2
Step at 0 - 3
Step at 2 - 3

Okay wow, that jumps around a lot! Still, it's more or less able to predict the function, and I think the not-being-frozen capability is worth the extra noise.

Afterword

Again, there are many things we haven't tried out yet, and many unanswered questions:

  • How can we resolve freezing in the inverse case? Can we solve it at all?
  • Leaky ReLU sort of has this jittering effect when doing the step function. Is this a sign of overfitting? It sort of makes sense, because the model there has around 30 hidden units in total, and the segments sort of don't like to be near each other.
  • We can try lowering the hidden dimension to see if it's truly overfitting, and if it is, this could prove to be an exciting testbed for testing solutions.
  • We can use the technique above to test both dropout and L2 regularization, because dropout seems like intentionally reducing the network's capability after having put in a lot of capability in the first place. Sort of circular, so L2 should be better (see the sketch after this list).
  • Segments are still bunched together near the origin, and will have a hard time going over the humps. What can we do to encourage the network to spread its segments out more?
  • How do these networks compare to the brain? How many calculations/s, how much energy is consumed? What functions do we expect to be possible from a certain network capability?
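
The dropout vs. L2 comparison mentioned above could look roughly like this. It's only a sketch: sinDl10, netDrop and netL2 are made-up names, the weight_decay value is a placeholder, and the train method above already accepts a custom optimizer:

In [ ]:
# sketch: same sine-on-(-10, 10) setup, once with dropout, once with L2 via weight decay
sinDl10 = DataLoader(FunctionDataset(lambda x: np.sin(x), -10, 10, samples=10000), batch_size=bs)
netDrop = NN(hiddenDim=30, dropout_p=0.2).cuda()
netDrop.train(sinDl10, epochs=90)
netL2 = NN(hiddenDim=30).cuda()
netL2.train(sinDl10, epochs=90, optimizer=optim.Adam(netL2.parameters(), lr=0.01, weight_decay=1e-4))
netDrop.plot(torch.linspace(-10, 10, 300), sinF)
netL2.plot(torch.linspace(-10, 10, 300), sinF)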