Why ResNet Works?

ResNet is one of the best performance boosters for neural networks. There are lots of articles on the topic, but few of them discuss why ResNet works.

I’d like to show concretely why ResNet works.

1. The Problem to Solve

The supervised machine learning problem can be better understood as the function approximation problem. For example,

Given input/output pairs \(\{ x_i, y_i \} \), we want to find a function \(f \) such that \(y_i = f(x_i) \) under an error metric.

There are many ways to select the function \(f \) (or the model, if you like). The neural network is one of them. The biggest advantage of a neural network is that it can approximate very complex functions. However, it also has problems: it overfits the data and it is slow to converge.

Thankfully, ResNet helps prevent overfitting and trains faster.

Let’s look at a simple example.

2. An Example

Let’s start with a super simple example. Consider this function approximation problem,

given: x = 1, y = 1 (yes! only one pair of data)

want: find \(f\) such that \(f(x) = y \)

Model 1

Let’s start with this model,

\(f_1(x) = \theta_4 \theta_3 \theta_2 \theta_1 x \)

The \(\theta \)s will be initialized with small values.

And let’s use the L2 loss,

\(cost( \{ \theta \} ) = || f_1(x) - y ||^2 = || f_1(1) - 1 ||^2\)

Let’s do the optimization in PyTorch.

import torch

INIT_WEIGHTS = [0.01, -0.01, 0.01, -0.01]
x = 1
y = 1
learning_rate = 1e-2
device = torch.device("cpu")
dtype = torch.float32


def test1():
    print("test1 start!")
    # four 1x1 "layers", each just a scalar weight
    weights = [torch.tensor([[INIT_WEIGHTS[i]]], device=device,
                            dtype=dtype, requires_grad=True) for i in range(4)]

    for t in range(100):
        # forward pass: f1(x) = w4 * w3 * w2 * w1 * x
        z = x
        for w in weights:
            z = w * z
        y_pred = z

        # L2 loss
        loss = (y_pred - y).pow(2).sum()
        loss.backward()

        if t % 10 == 0:
            print('iter:', t, 'loss:', loss.item(), 'y_pred:', y_pred.item())

        # vanilla gradient descent step
        with torch.no_grad():
            for w in weights:
                w -= learning_rate * w.grad
                w.grad.zero_()
Oops!! It doesn’t converge:

test1 start!
iter: 0 loss: 1.0 y_pred: 9.99999993922529e-09
iter: 10 loss: 1.0 y_pred: 1.0000780648056207e-08
iter: 20 loss: 1.0 y_pred: 1.0001564021422382e-08
iter: 30 loss: 1.0 y_pred: 1.0002346506610138e-08
iter: 40 loss: 1.0 y_pred: 1.0003128991797894e-08
iter: 50 loss: 1.0 y_pred: 1.000391147698565e-08
iter: 60 loss: 1.0 y_pred: 1.0004694850351825e-08
iter: 70 loss: 1.0 y_pred: 1.0005475559182742e-08
iter: 80 loss: 1.0 y_pred: 1.0006258932548917e-08
iter: 90 loss: 1.0 y_pred: 1.0007041417736673e-08

More iterations?

iter: 19400 loss: 0.9999996423721313 y_pred: 1.9921735372463445e-07
iter: 19500 loss: 0.9999996423721313 y_pred: 2.0652491627970448e-07
iter: 19600 loss: 0.9999995231628418 y_pred: 2.142421919870685e-07
iter: 19700 loss: 0.9999995231628418 y_pred: 2.2240011787744152e-07
iter: 19800 loss: 0.9999995231628418 y_pred: 2.310330842192343e-07
iter: 19900 loss: 0.9999995231628418 y_pred: 2.4017822397581767e-07

Nope.

Model 2

Let’s try another model. Define,

\(g_i(z) = \theta_i z + z\)

then,

\(f_2(x) = g_4(g_3(g_2(g_1(x)))) - 1\)

The \(\theta \)s will be initialized with small values, and we will use the L2 loss again. The purpose of the minus one is to make the initial value of \(f_2\) close to that of \(f_1\), so that the comparison between the two models is fair.

In PyTorch:

INIT_WEIGHTS = [0.01, -0.01, 0.01, -0.01]
x = 1
y = 1
learning_rate = 1e-2

def test2():
    print("test2 start!")
    weights = [torch.tensor(INIT_WEIGHTS[i], requires_grad=True) for i in range(4)]

    for t in range(200):
        # forward pass: each layer is g(z) = w*z + z = (w + 1)*z
        z = x
        for w in weights:
            z = (w + 1.) * z
        y_pred = z - 1

        loss = (y_pred - y).pow(2).sum()
        loss.backward()

        if t % 20 == 0:
            print('iter:', t, 'loss:', loss.item(), 'y_pred:', y_pred.item())

        with torch.no_grad():
            for w in weights:
                w -= learning_rate * w.grad
                w.grad.zero_()

Interestingly, it converges quickly.

test2 start!
iter: 0 loss: 1.0004000663757324 y_pred: -0.00019997358322143555
iter: 20 loss: 0.0005264150095172226 y_pred: 0.9770562648773193
iter: 40 loss: 1.91551094985698e-08 y_pred: 0.9998615980148315
iter: 60 loss: 6.963318810448982e-13 y_pred: 0.9999991655349731
iter: 80 loss: 1.4210854715202004e-14 y_pred: 0.9999998807907104
iter: 100 loss: 1.4210854715202004e-14 y_pred: 0.9999998807907104
iter: 120 loss: 1.4210854715202004e-14 y_pred: 0.9999998807907104
iter: 140 loss: 1.4210854715202004e-14 y_pred: 0.9999998807907104
iter: 160 loss: 1.4210854715202004e-14 y_pred: 0.9999998807907104
iter: 180 loss: 1.4210854715202004e-14 y_pred: 0.9999998807907104

Shocking result! What’s going on here???

3. The ResNet

The above example shows the gist of ResNet.

Let \(h_i(z) = \theta_i z\).

In model 1, \(f_1(x) = h_4(h_3(h_2(h_1(x))))\).

In model 2, let \(g_i(z) = \theta_i z + z\); then \(f_2(x) = g_4(g_3(g_2(g_1(x)))) - 1\).

Model 2 is simply model 1 with residual connections (plus a constant shift).

The Residual Block
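
To make the connection concrete, here is a minimal sketch of a residual block in PyTorch. It is not the original ResNet block (which uses convolutions, batch normalization, and ReLU); the single linear layer here is just an illustrative stand-in for the residual branch.

import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Minimal residual block: output = F(x) + x."""

    def __init__(self, dim):
        super().__init__()
        # the residual branch F(x); a single linear layer for simplicity
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        # skip connection: add the input back to the branch output
        return self.fc(x) + x


block = ResidualBlock(8)
print(block(torch.randn(2, 8)).shape)  # torch.Size([2, 8])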

But… why is ResNet better?

4. Why ResNet is Better?

Let’s consider model 2 again,

\(f_2(x) = g_4(g_3(g_2(g_1(x)))) \) where \(g_i(z) = \theta_i z + z\). (We drop the constant \(-1\) from here on, since it does not affect the gradients.)

Rewrite \(g_i(z) = (\theta_i + 1) z\),

\(f_2(x) = (\theta_4 + 1) (\theta_3 + 1) (\theta_2 + 1) (\theta_1 + 1) x\)

It’s a polynomial. Let’s expand it.

\(f_2(x) = \theta_1 \theta_2 \theta_3 \theta_4 x \\
+ \theta_2 \theta_3 \theta_4 x + \theta_1 \theta_3 \theta_4 x + \theta_1 \theta_2 \theta_4 x + \theta_1 \theta_2 \theta_3 x \\
+ \theta_1 \theta_2 x + \theta_1 \theta_3 x + \theta_1 \theta_4 x + \theta_2 \theta_3 x + \theta_2 \theta_4 x + \theta_3 \theta_4 x \\
+ \theta_1 x + \theta_2 x + \theta_3 x + \theta_4 x
+ x \)
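
If you want to double-check the expansion, a quick sketch with sympy (assuming sympy is installed) reproduces it:

import sympy as sp

t1, t2, t3, t4, x_sym = sp.symbols('theta1 theta2 theta3 theta4 x')
f2 = (t1 + 1) * (t2 + 1) * (t3 + 1) * (t4 + 1) * x_sym
print(sp.expand(f2))
# theta1*theta2*theta3*theta4*x + theta1*theta2*theta3*x + ... + theta1*x + theta2*x + theta3*x + theta4*x + x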

Now, we can do some analysis.

Less prone to overfitting

Let’s compare \(f_2(x)\) with \(f_1(x) = \theta_1 \theta_2 \theta_3 \theta_4 x \).

We can see that \(f_2(x)\) contains \(f_1(x)\). In fact, \(f_2(x)\) is a polynomial expansion over the layers, and \(f_1(x)\) is the highest-order term in that expansion.

overfitting & underfitting from Andrew Ng

We can argue that, for \(f_2(x)\), the lower-order terms help to prevent overfitting. For \(f_1(x)\), however, the only term \( \theta_1 \theta_2 \theta_3 \theta_4 x \) is prone to overfitting.

Better Gradient Flow

First, let’s look at

\(f_1(x) = \theta_1 \theta_2 \theta_3 \theta_4 x \),

Consider its gradient w.r.t. a parameter,

\(\frac {\partial f_1}{\partial \theta_1} = \theta_2 \theta_3 \theta_4 x \)

Given our small-value initialization, where

INIT_WEIGHTS = [0.01, -0.01, 0.01, -0.01]
x = 1
...

We can compute the gradient,

\(\frac {\partial f_1}{\partial \theta_1} = 10^{-6} \)

A tiny number!
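
We can verify this with a tiny standalone autograd check (a sketch, separate from the training code above):

import torch

w = [torch.tensor(v, requires_grad=True) for v in [0.01, -0.01, 0.01, -0.01]]
f1 = w[0] * w[1] * w[2] * w[3] * 1.0   # f1(x) with x = 1
f1.backward()
print(w[0].grad)  # tensor(1.0000e-06), i.e. theta2 * theta3 * theta4 * x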

This is the vanishing gradient problem: the gradient is a product of many small numbers, and with more layers it shrinks exponentially.

Recall from the previous experiment: even for this super simple problem, training did not converge for \(f_1(x)\)!

Now, let’s consider \(f_2(x)\),

\(f_2(x) = \theta_1 \theta_2 \theta_3 \theta_4 x \\
+ \theta_2 \theta_3 \theta_4 x + \theta_1 \theta_3 \theta_4 x + \theta_1 \theta_2 \theta_4 x + \theta_1 \theta_2 \theta_3 x \\
+ \theta_1 \theta_2 x + \theta_1 \theta_3 x + \theta_1 \theta_4 x + \theta_2 \theta_3 x + \theta_2 \theta_4 x + \theta_3 \theta_4 x \\
+ \theta_1 x + \theta_2 x + \theta_3 x + \theta_4 x
+ x \)

Consider its gradient w.r.t. \(\theta_1\),

\(\frac {\partial f_2}{\partial \theta_1} = x + \theta_2 x + \theta_3 x + \theta_4 x + …\)

Given our small value initialization,

\(\frac {\partial f_2}{\partial \theta_1} \approx x = 1\)

As we can see, in \(f_2\), the gradients from the lower-order terms are reasonably large, which is sufficient for the gradient descent algorithm.
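
The same kind of quick autograd check works for \(f_2\) (again, a standalone sketch):

import torch

w = [torch.tensor(v, requires_grad=True) for v in [0.01, -0.01, 0.01, -0.01]]
z = torch.tensor(1.0)                  # x = 1
for wi in w:
    z = (wi + 1.0) * z                 # residual layer: g(z) = wi*z + z
z.backward()
print(w[0].grad)  # ~0.9899, i.e. (theta2+1)(theta3+1)(theta4+1)*x, close to 1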

Another, equivalent, way to see this property is to look at the computational graph directly: thanks to the residual connections, gradients can flow from the cost to any parameter directly.

Plot the cost

We can plot the cost of \(f_1\) and \(f_2\) along a parameter to see the gradient visually.

code

import numpy as np
import matplotlib.pyplot as plt


def plot_loss():
    # reuses INIT_WEIGHTS, x and y defined above
    weights = INIT_WEIGHTS[:]

    def loss1(x, W):
        # cost of f1(x) = w1*w2*w3*w4*x
        w1, w2, w3, w4 = W
        y_pred = w1*w2*w3*w4*x
        loss = (y_pred - y) ** 2
        return loss

    def loss2(x, W):
        # cost of f2(x) = (w1+1)(w2+1)(w3+1)(w4+1)*x
        z = x
        for w in W:
            z = (w + 1) * z
        y_pred = z
        loss = (y_pred - y) ** 2
        return loss

    loss1_all = []
    loss2_all = []
    w1_range = np.linspace(-3, 3, 100)
    for w1 in w1_range:
        # sweep theta1 while keeping the other weights at their initial values
        weights[0] = w1
        loss1_all.append(loss1(x, weights))
        loss2_all.append(loss2(x, weights))

    plt.plot(w1_range, loss1_all, 'r', label='f1 cost')
    plt.plot(w1_range, loss2_all, 'b', label='f2 cost')
    plt.xlabel('theta 1')
    plt.ylabel('cost')
    plt.title('cost function of f1 and f2 along theta1')
    plt.legend()
    plt.show()

The Polynomial Expansion

We can also do this function approximation using the polynomial expansion of \(f_2\) directly.

max_iters = 200
print_interval = 20


def test3():
    print("test3 start!")
    # reuses INIT_WEIGHTS, x, y and learning_rate defined above
    weights = [torch.tensor(INIT_WEIGHTS[i], requires_grad=True) for i in range(4)]

    w1, w2, w3, w4 = weights

    for t in range(max_iters):
        # f2 written out as its polynomial expansion (minus 1, as before)
        y_pred = w1*w2*w3*w4*x \
            + w1*w2*w3*x + w1*w2*w4*x + w1*w3*w4*x + w2*w3*w4*x\
            + w1*w2*x + w1*w3*x + w1*w4*x + w2*w3*x + w2*w4*x + w3*w4*x\
            + w1*x + w2*x + w3*x + w4*x + x
        y_pred = y_pred - 1
        loss = (y_pred - y).pow(2).sum()
        loss.backward()

        if t % print_interval == 0:
            print('iter:', t, 'loss:', loss.item(), 'y_pred:', y_pred.item())

        with torch.no_grad():
            for w in weights:
                w -= learning_rate * w.grad
                w.grad.zero_()

The output is essentially the same as for \(f_2\).

test3 start!
iter: 0 loss: 1.0004000663757324 y_pred: -0.00019997358322143555
iter: 20 loss: 0.0005264040664769709 y_pred: 0.9770565032958984
iter: 40 loss: 1.9188121314073214e-08 y_pred: 0.999861478805542
iter: 60 loss: 6.963318810448982e-13 y_pred: 0.9999991655349731
iter: 80 loss: 0.0 y_pred: 1.0
iter: 100 loss: 0.0 y_pred: 1.0
iter: 120 loss: 0.0 y_pred: 1.0
iter: 140 loss: 0.0 y_pred: 1.0
iter: 160 loss: 0.0 y_pred: 1.0
iter: 180 loss: 0.0 y_pred: 1.0

5. Discussion

This simple example covers the gist of ResNet, but there are still gaps. Let’s discuss them.

Multiple Input & Output Functions?

Well, in terms of gradients, multiple-input & multiple-output functions behave the same as single-input & single-output functions: the scalar weights become weight matrices, and the gradient w.r.t. an early layer still involves a product of many small factors.
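
As a rough illustration of this claim, repeat the model-1 experiment with small, randomly initialized weight matrices (the layer width 8 and the initialization scale are arbitrary choices for this sketch):

import torch

torch.manual_seed(0)
# four stacked linear layers with small random weights, no residual connections
weights = [(0.01 * torch.randn(8, 8)).requires_grad_() for _ in range(4)]
x = torch.randn(8)

z = x
for W in weights:
    z = W @ z
z.sum().backward()

print(weights[0].grad.abs().max())  # a tiny number: the gradient essentially vanishes, as in the scalar case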

How about Nonlinearities?

For the gradient descent algorithm, we linearize the function and go downhill. We can argue that, after linearization, the previous vanishing-gradient and polynomial analyses still hold.

Moreover, the vanishing gradient problem directly informs how to choose the nonlinear layer.

Consider the Sigmoid function,

https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e

Look at the derivative of the sigmoid: its value is very small! The largest value is only 0.25. For a 4-layer network, the gradient w.r.t. the first layer is at most \((0.25)^4 \approx 0.004 \).
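
A quick numerical check of that 0.25 bound (a standalone sketch):

import torch

x = torch.linspace(-10, 10, 1001, requires_grad=True)
torch.sigmoid(x).sum().backward()
print(x.grad.max())   # ~0.25, the largest possible slope of the sigmoid (at x = 0)
print(0.25 ** 4)      # 0.00390625, the best case after 4 sigmoid layers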

Consider the ReLU function,

https://www.machinecurve.com/index.php/2020/01/24/overview-of-activation-functions-for-neural-networks/

The derivative of ReLU is 1 for positive inputs. For a 4-layer network, the gradient w.r.t. the first layer can be \((1)^4 = 1\). Much better!

Overcoming vanishing gradient by vanishing gradient

In NLP, because we want to map all previous language tokens to an output, the vanishing gradient is a big issue. To overcome the vanishing gradient problem, people apply the vanishing gradient!

Consider a GRU unit,

The output of a GRU, \(h_t\), is a weighted average of the previous output \(h_{t-1}\) and the current activation \(\hat{h}_t\): \(h_t = (1 - z_t) h_{t-1} + z_t \hat{h}_t\), where the weight \(z_t\) is given by a sigmoid function.

Let’s consider 3 cases.

Case 1. The sigmoid \(z_t\) is neither 1 nor 0. Unrolling over time, \(h_t\) contains a chain of multiplications \((1 - z_t)(1 - z_{t-1})(1 - z_{t-2}) \cdots h_0\), which suffers from vanishing gradients.

Case 2. The sigmoid \(z_t\) is 0. We have \(h_t = h_{t-1}\). It's a residual connection! The gradient can flow directly to previous units.

Case 3. The sigmoid \(z_t\) is 1. We have \(h_t = \hat{h}_t\). The gradient flows to the input \(x_t\). Since these connections are short, they don't suffer from vanishing gradients.

In sum, when the sigmoid \(z_t\) saturates at 0 or 1 (i.e., its own gradient vanishes), the vanishing gradient problem over the sequence goes away. Cool!!
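
Here is a toy sketch of the update that makes the three cases concrete. It ignores the reset gate and the actual gating networks; the gate values are plugged in by hand and the names are made up for illustration.

import torch

def gru_style_update(h_prev, h_hat, z):
    # convex combination controlled by the gate z in [0, 1]
    return (1 - z) * h_prev + z * h_hat

h_prev = torch.tensor(2.0)    # previous output h_{t-1}
h_hat = torch.tensor(-3.0)    # current activation \hat{h}_t

print(gru_style_update(h_prev, h_hat, torch.tensor(0.5)))  # case 1: a multiplicative chain builds up over time
print(gru_style_update(h_prev, h_hat, torch.tensor(0.0)))  # case 2: z = 0, h_t = h_{t-1} (residual path)
print(gru_style_update(h_prev, h_hat, torch.tensor(1.0)))  # case 3: z = 1, h_t = \hat{h}_t (short path to the input)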

More on Polynomial Expansion from ResNet

From the previous discussion, we can see that adding residual connections is equivalent to performing a polynomial expansion of the original layers. The polynomial has both higher-order and lower-order terms.

Fact 1: The lower-order terms are easier to train than the higher-order terms, e.g. near the small initialization, the gradients of the lower-order terms are exponentially larger than those of the higher-order terms.

Fact 2: The higher-order terms are more expressive. For complex mappings in real applications, the higher-order terms play a more important role.

These two facts imply that, near convergence, learning is dominated by noise from the lower-order terms, while what we actually want is to update the higher-order terms.

I guess what we could do is freeze the lower-order terms by removing the residual connections. By doing so, we limit the noise from the lower-order terms and can increase the learning rate exponentially. In the end, (I hope) we could train the network exponentially faster.
