Road to ML Engineer #15 - Unstable Gradients

Last Edited: 8/27/2024

This blog post introduces the unstable gradient problem in deep learning.

ML

Unstable Gradients

Regardless of which optimizer we choose, we need to compute the gradients of a loss function with respect to the parameters via backpropagation. For every layer the gradient passes through, this involves multiplying the partial derivatives $\frac{dz_h}{dh_{h-1}} \frac{dh_{h-1}}{dz_{h-1}}$. The former is $w_h$, and the latter is $h_{h-1}(1-h_{h-1})$ for the sigmoid function.

If these values are below 1, the gradients get exponentially smaller, or vanish, as we propagate back through many layers, which can lead to virtually no adjustments to the weights. If these values are above 1, the gradients grow exponentially, or explode, causing the weight updates to overshoot. In both cases, the model struggles to adjust the weights adequately and learn. Hence, we refer to these as the vanishing/exploding gradient problem, or the unstable gradient problem, which we constantly battle when building deep neural networks.
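To get a feel for how fast this compounds, here is a minimal sketch that multiplies the same per-layer factor across a growing number of layers. It assumes, for simplicity, that every layer scales the gradient by an identical factor; the values 0.25 and 1.5 are arbitrary illustrations (0.25 is the largest value the sigmoid derivative can take).

```python
def gradient_scale(factor, n_layers):
    """Scale of a gradient after backpropagating through n_layers,
    assuming every layer multiplies it by the same factor."""
    return factor ** n_layers


for n_layers in (1, 10, 50):
    print(f"{n_layers:2d} layers | factor 0.25: {gradient_scale(0.25, n_layers):.3e}"
          f" | factor 1.5: {gradient_scale(1.5, n_layers):.3e}")
```

Even modest deviations from 1 turn into factors of many orders of magnitude after a few dozen layers.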

Activation Function

One approach to solving the unstable gradient problem is through the activation function. We have been using the sigmoid function to mimic neurons in our brain, but its derivative has the following shape:

The derivative is highest at zero, but even there it is only 0.25, well below 1. Moreover, when the inputs are large in magnitude, the sigmoid saturates near 0 or 1, where its derivative approaches 0. Hence, the sigmoid function is prone to the vanishing gradient problem. Instead of the sigmoid function, we can use an alternative non-linear function with a better-shaped derivative, such as the ReLU function.
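A quick numerical check of these claims: the sketch below evaluates the sigmoid derivative $h(1-h)$ at a few hand-picked inputs, showing the 0.25 peak at zero and how quickly the derivative shrinks once the sigmoid saturates.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sigmoid_derivative(x):
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)), i.e. h(1 - h)
    s = sigmoid(x)
    return s * (1.0 - s)


x = np.array([-10.0, -5.0, -2.0, 0.0, 2.0, 5.0, 10.0])
for xi, d in zip(x, sigmoid_derivative(x)):
    print(f"x = {xi:6.1f}  sigmoid'(x) = {d:.6f}")
```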

$$\text{ReLU}(x) = \max(0, x)$$

The ReLU function has a derivative of 1 for positive inputs and 0 for negative inputs, so gradients pass through active units without shrinking. It is also cheap to compute. The following is the implementation in the FNNClassifier class:

# Forward Pass
def relu(self, X):
    return np.maximum(0, X)
 
# Backward Pass
def fit(self, X, y, epochs=100, verbose=True):
    ...
    # Backpropagate the error
    if i > 0:
        # dz_h/dh_{h-1} = self.W[i].T
        # dh_{h-1}/dz_{h-1} = (activations[i] > 0).astype(int) for h = ReLU
        delta = np.matmul(delta, self.W[i].T) * (activations[i] > 0).astype(int)
    ...
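To check that the masked update `delta @ W.T * (h > 0)` really is the ReLU gradient, one option is to compare it against finite differences on a tiny standalone network. This is a sketch outside the class: the two-layer network, its sizes, and the toy loss (the sum of the outputs, so the output delta is all ones) are assumptions made purely for the check.

```python
import numpy as np


def relu(x):
    return np.maximum(0, x)


rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))    # batch of 4 inputs with 3 features
W0 = rng.normal(size=(3, 5))   # input -> hidden
W1 = rng.normal(size=(5, 2))   # hidden -> output


def loss_fn(W0):
    # Toy loss: sum of the network outputs, so dL/dz_out is all ones
    return np.sum(relu(X @ W0) @ W1)


# Forward pass, keeping the hidden activations
h1 = relu(X @ W0)

# Backward pass with the same rule as the FNNClassifier snippet:
# delta @ W.T, masked by the ReLU derivative (h > 0)
delta = np.ones((4, 2))
delta = np.matmul(delta, W1.T) * (h1 > 0).astype(int)
grad_W0 = X.T @ delta

# Central finite differences for comparison
eps = 1e-6
num_grad_W0 = np.zeros_like(W0)
for idx in np.ndindex(W0.shape):
    W_plus, W_minus = W0.copy(), W0.copy()
    W_plus[idx] += eps
    W_minus[idx] -= eps
    num_grad_W0[idx] = (loss_fn(W_plus) - loss_fn(W_minus)) / (2 * eps)

print("max abs difference:", np.max(np.abs(grad_W0 - num_grad_W0)))
```

The difference should be at the level of floating-point noise, since the loss is piecewise linear in `W0`.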

ReLU is the default activation function for hidden layers, and FNNClassifier uses it for all of them. There are other variants, such as Leaky ReLU, which keep a slope of 1 for positive inputs but use a small nonzero slope for negative inputs, retaining non-linearity while still passing gradients for negative values. Since our goal is better learning rather than imitating real brains, we are free to use any of these non-linear activation functions.
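A minimal sketch of Leaky ReLU and its derivative, written as standalone functions rather than FNNClassifier methods. The negative slope `alpha=0.01` is an assumed, commonly used value, not something specified above.

```python
import numpy as np


def leaky_relu(x, alpha=0.01):
    # Identity for positive inputs, small slope alpha for negative inputs
    return np.where(x > 0, x, alpha * x)


def leaky_relu_derivative(x, alpha=0.01):
    # 1 for positive inputs, alpha (instead of 0) for negative inputs
    return np.where(x > 0, 1.0, alpha)


x = np.array([-3.0, -0.5, 0.5, 3.0])
print(leaky_relu(x))
print(leaky_relu_derivative(x))
```

In a backward pass like the one above, `leaky_relu_derivative` would take the place of the `(activations[i] > 0)` mask, so negative units still pass a small gradient.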

Weight Initialization

For both sigmoid and ReLU activation functions, the derivatives depend on the values of the activations ($h_{h-1}(1-h_{h-1})$ for sigmoid and $h_{h-1} > 0$ for ReLU). Therefore, we want the variance of the activations to stay the same across all layers, as in the input layer, so that most gradients propagate through non-extreme values and the unstable gradient problem is avoided. Assuming that the activation is linear (sigmoid is approximately linear around the center, and ReLU is linear for positive values) and that the biases are zero, the variance of $h_h$ can be written as follows:

$$Var(h_h) = Var\left(\sum_{i=1}^{n_{h-1}} w_{hi} h_{h-1}\right)$$

Assuming that the weights and activations are mutually independent and identically distributed, the variance of the sum becomes the sum of the variances:

$$
\begin{aligned}
Var(h_h) &= \sum_{i=1}^{n_{h-1}} Var(w_{hi} h_{h-1}) \\
&= n_{h-1} Var(w_{hi} h_{h-1})
\end{aligned}
$$

The variance of the product of two independent variables gives us:

$$Var(h_h) = n_{h-1} \left( E(w_{hi})^2 Var(h_{h-1}) + Var(w_{hi}) E(h_{h-1})^2 + Var(w_{hi}) Var(h_{h-1}) \right)$$

Provided that the weights $w$ have zero mean, the first term vanishes:

$$
\begin{aligned}
Var(h_h) &= n_{h-1} \left( Var(w_{hi}) E(h_{h-1})^2 + Var(w_{hi}) Var(h_{h-1}) \right) \\
&= n_{h-1} \left( Var(w_{hi}) E(h_{h-1})^2 + Var(w_{hi}) \left( E((h_{h-1})^2) - E(h_{h-1})^2 \right) \right) \\
&= n_{h-1} Var(w_{hi}) E((h_{h-1})^2)
\end{aligned}
$$

If we assume that $h_{h-1}$ has zero mean, $E(h_{h-1}^2) = Var(h_{h-1})$ holds for a linear activation function. However, since ReLU outputs zero for negative inputs, only the positive half of a zero-mean, symmetric input contributes, so $E(h_{h-1}^2)$ is halved, becoming $\frac{1}{2}Var(h_{h-1})$ for ReLU (here $Var(h_{h-1})$ refers to the variance before the ReLU is applied). Hence, the variance for ReLU can be expressed as:

$$
\begin{aligned}
Var(h_h) &= n_{h-1} \left( Var(w_{hi}) E(h_{h-1})^2 + Var(w_{hi}) Var(h_{h-1}) \right) \\
&= n_{h-1} \left( Var(w_{hi}) E(h_{h-1})^2 + Var(w_{hi}) \left( E((h_{h-1})^2) - E(h_{h-1})^2 \right) \right) \\
&= n_{h-1} Var(w_{hi}) E((h_{h-1})^2) \\
&= \frac{1}{2} n_{h-1} Var(w_{hi}) Var(h_{h-1})
\end{aligned}
$$

From the above, to keep the variance the same across layers, $Var(h_h) = Var(h_{h-1})$, we need $\frac{1}{2} n_{h-1} Var(w_{hi})$ to be 1. Hence, we can initialize the weights such that $Var(w_{hi}) = \frac{2}{n_{h-1}}$ for ReLU activation functions. This weight initialization method for ReLU is called He initialization. For the sigmoid function, we assume a linear activation function and set the weights to have a variance of $\frac{1}{n_{h-1}}$, which is called Xavier initialization. (The original Xavier, or Glorot, initialization averages the fan-in and fan-out and uses $\frac{2}{n_{h-1} + n_h}$; the fan-in-only form $\frac{1}{n_{h-1}}$ is also known as LeCun initialization.)
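The derivation can be sanity-checked empirically. The sketch below pushes Gaussian inputs through a deep ReLU network with zero biases and compares the variance of the last layer's pre-activations with that of the first layer, under three weight scales. The depth (10), width (512), batch size (1000), and the helper name `variance_ratio` are arbitrary choices for illustration.

```python
import numpy as np


def variance_ratio(weight_std, n_layers=10, width=512, batch=1000, seed=0):
    """Push Gaussian inputs through a deep ReLU network with zero biases and
    return Var(z_last) / Var(z_first), where z are the pre-activations."""
    rng = np.random.default_rng(seed)
    h = rng.normal(size=(batch, width))
    variances = []
    for _ in range(n_layers):
        W = rng.normal(0.0, weight_std, size=(width, width))
        z = h @ W
        variances.append(z.var())
        h = np.maximum(0, z)  # ReLU
    return variances[-1] / variances[0]


n = 512  # fan-in n_{h-1} of every layer in this sketch
for name, std in [("std = 1", 1.0),
                  ("Var(w) = 1/n (Xavier-style)", (1 / n) ** 0.5),
                  ("Var(w) = 2/n (He)", (2 / n) ** 0.5)]:
    print(f"{name:28s} Var(z_last) / Var(z_first) = {variance_ratio(std, width=n):.3g}")
```

With He initialization the ratio should stay close to 1. With $\frac{1}{n_{h-1}}$ each ReLU layer roughly halves the variance, and with unit-variance weights the variance explodes, matching the $\frac{1}{2} n_{h-1} Var(w_{hi})$ factor above.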

Both methods aim to solve the unstable gradient problem by keeping the variance of the activations consistent across layers, so that $\frac{dh_{h-1}}{dz_{h-1}}$ does not collapse toward zero. They also help with the exploding gradient problem by preventing the weights, and hence $\frac{dz_h}{dh_{h-1}}$, from becoming too large. The following is the implementation of weight initialization in the FNNClassifier class:

def initialize_weight(self, input_dim, hidden_dims, output_dim):
    self.W = []
    self.b = []
    
    # He initialization for ReLU
    for i in range(len(hidden_dims)):
        if i == 0:
            self.W.append(np.random.normal(0, (2/input_dim)**0.5, size=(input_dim, hidden_dims[i])))
        else:
            self.W.append(np.random.normal(0, (2/hidden_dims[i-1])**0.5, size=(hidden_dims[i-1], hidden_dims[i])))
        self.b.append(np.zeros(hidden_dims[i]))
 
    # Xavier initialization for Sigmoid or Softmax
    self.W.append(np.random.normal(0, (1/hidden_dims[-1])**0.5, size=(hidden_dims[-1], output_dim)))
    self.b.append(np.zeros(output_dim))

The above method is called inside the __init__ method in the class.

L1 & L2 Regularization

While regularization is primarily used to prevent models from overfitting, it also pushes the weights toward zero, which indirectly helps prevent the exploding gradient problem. We can use either L1 or L2 regularization, each of which adds a corresponding penalty term to the loss function. (If you want to learn more about them, I recommend checking out "Road to ML Engineer #7 - Regularization.")
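As a reminder of what each penalty contributes to the gradient, here is a minimal sketch using a hypothetical helper `penalty_and_grad`. Scaling the L2 term by $\frac{\lambda}{2}$ is an assumed convention, chosen so that its gradient is exactly $\lambda W$, which is the form used by the weight-decay line in FNNClassifier's `fit` method.

```python
import numpy as np


def penalty_and_grad(W, lam, kind="l2"):
    """Return the regularization penalty and its gradient w.r.t. W.
    L1: lam * sum(|W|), gradient lam * sign(W).
    L2: (lam / 2) * sum(W**2), gradient lam * W."""
    if kind == "l1":
        return lam * np.sum(np.abs(W)), lam * np.sign(W)
    if kind == "l2":
        return 0.5 * lam * np.sum(W ** 2), lam * W
    raise ValueError(f"unknown penalty: {kind}")


W = np.array([[0.5, -2.0], [0.0, 4.0]])
for kind in ("l1", "l2"):
    penalty, grad = penalty_and_grad(W, lam=0.1, kind=kind)
    print(kind, penalty, grad.tolist())
```

The L1 gradient has the same magnitude for every nonzero weight (pushing small weights to exactly zero), while the L2 gradient is proportional to the weight (shrinking large weights the most).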

In the context of optimizers, the L2 penalty coefficient is called weight decay. The following shows how weight decay is implemented in FNNClassifier:

def fit(self, X, y, epochs=100, verbose=True):
    ...
 
    for i in range(len(self.W) - 1, -1, -1):
        grad_W = np.matmul(activations[i].T, delta)
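        # Gradient of the L2 penalty (weight_decay / 2) * ||W||^2;
        # AdamW skips this and decays the weights directly instead
        if self.weight_decay and self.optimizer != "AdamW":
            grad_W += self.weight_decay * self.W[i]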
 
    ...

With SGD, adding the weight decay term to the gradient is exactly equivalent to L2 regularization: both shrink every weight by the same factor $(1 - \eta\lambda)$ at each step, where $\eta$ is the learning rate (`lr`) and $\lambda$ is `weight_decay`. With Adam, however, the two are no longer equivalent, because the first and second moments are computed from the regularized gradients, so the penalty is rescaled by the adaptive denominator along with the loss gradient. The modified version of Adam, AdamW, decouples the adaptive momentum from the regularization by applying weight decay directly to the weights rather than to the gradients. The following is the implementation of AdamW in FNNClassifier:
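The difference is easy to see numerically. The sketch below takes a single update step from hypothetical weights `w` and a hypothetical loss gradient `g`, comparing L2-in-the-gradient against decoupled decay for SGD and for Adam's first step (`t = 1`, using the same `beta_2=0.99` default as FNNClassifier; `adam_first_step` is a helper made up for this illustration).

```python
import numpy as np

lr, weight_decay = 0.1, 0.01
w = np.array([1.0, -2.0, 3.0])  # hypothetical weights
g = np.array([0.5, 0.0, -4.0])  # hypothetical loss gradient

# SGD: adding weight_decay * w to the gradient (L2) vs. shrinking the weights directly
sgd_l2 = w - lr * (g + weight_decay * w)
sgd_decoupled = (1 - lr * weight_decay) * w - lr * g
print("SGD equivalent: ", np.allclose(sgd_l2, sgd_decoupled))


def adam_first_step(grad, lr, beta_1=0.9, beta_2=0.99, eps=1e-8):
    """Size of Adam's very first update (t = 1) for a given gradient."""
    m_hat = (1 - beta_1) * grad / (1 - beta_1)       # bias-corrected first moment
    s_hat = (1 - beta_2) * grad ** 2 / (1 - beta_2)  # bias-corrected second moment
    return lr * m_hat / (np.sqrt(s_hat) + eps)


# Adam with L2 in the gradient vs. AdamW with decoupled decay
adam_l2 = w - adam_first_step(g + weight_decay * w, lr)
adamw = (1 - lr * weight_decay) * w - adam_first_step(g, lr)
print("Adam equivalent:", np.allclose(adam_l2, adamw))
print("second weight:", adam_l2[1], "vs", adamw[1])
```

The SGD comparison prints `True` and the Adam comparison prints `False`. For the second weight, whose loss gradient is zero, Adam normalizes the tiny decay term into a full-sized step (from -2.0 to about -1.9), while AdamW shrinks it only by the factor $(1 - \eta\lambda)$ (to -1.998).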

def __init__(self, input_dim, hidden_dims, output_dim, lr=0.001,
             batch_size=16, beta_1=0.9, beta_2=0.99, weight_decay=None, optimizer="SGD"):
    ...

    # Initialize Adam & AdamW variables: first (m) and second (s) moment estimates
    if self.optimizer in ("Adam", "AdamW"):
        self.epsilon = 1e-8
        self.m_W = [np.zeros_like(w) for w in self.W]
        self.m_b = [np.zeros_like(b) for b in self.b]
        self.s_W = [np.zeros_like(w) for w in self.W]
        self.s_b = [np.zeros_like(b) for b in self.b]

def fit(self, X, y, epochs=100, verbose=True):
    ...
    # Adam-specific parameter updates for layer i
    if self.optimizer in ("Adam", "AdamW"):
        # Note: in Adam, t usually counts update steps (mini-batches), not epochs
        t = epoch + 1
        self.m_W[i] = self.beta_1 * self.m_W[i] + (1 - self.beta_1) * grad_W
        self.s_W[i] = self.beta_2 * self.s_W[i] + (1 - self.beta_2) * (grad_W ** 2)
        self.m_b[i] = self.beta_1 * self.m_b[i] + (1 - self.beta_1) * grad_b
        self.s_b[i] = self.beta_2 * self.s_b[i] + (1 - self.beta_2) * (grad_b ** 2)

        # Bias correction for the zero-initialized moments
        m_W_hat = self.m_W[i] / (1 - self.beta_1 ** t)
        s_W_hat = self.s_W[i] / (1 - self.beta_2 ** t)
        m_b_hat = self.m_b[i] / (1 - self.beta_1 ** t)
        s_b_hat = self.s_b[i] / (1 - self.beta_2 ** t)

        # AdamW: decoupled weight decay applied directly to the parameters
        if self.optimizer == "AdamW" and self.weight_decay:
            self.W[i] -= self.lr * self.weight_decay * self.W[i]
            self.b[i] -= self.lr * self.weight_decay * self.b[i]
        self.W[i] -= self.lr * m_W_hat / (np.sqrt(s_W_hat) + self.epsilon)
        self.b[i] -= self.lr * m_b_hat / (np.sqrt(s_b_hat) + self.epsilon)

By using all of the above techniques against the unstable gradient problem (ReLU activation, weight initialization, and regularization with AdamW), the learning curve looks like the one below.

AdamW Loss

The learning curve now has an ideal shape, indicating gradual and smooth learning achieved through these countermeasures against the unstable gradient problem.

Dropout & Gradient Clipping

Another regularization approach specific to neural networks is dropout, where we randomly drop out (set to zero) a proportion of the activations from the previous layer, drawing a new random selection at every forward pass during training. To compensate for the deactivated neurons, the remaining active neurons are scaled up by the reciprocal of the probability of keeping a neuron. (If 25% are dropped, the active neurons are scaled up by $1/0.75 \approx 1.33$.) The goal is to inject random noise into the activations to prevent overfitting and help the model generalize to unseen data. It can also indirectly help with the vanishing gradient problem by scaling up some activations and gradients $\frac{dh_{h-1}}{dz_{h-1}}$.
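A minimal sketch of this "inverted" dropout as two standalone functions (`dropout_forward` and `dropout_backward` are hypothetical names, not part of FNNClassifier). A fresh mask is drawn on every training call, survivors are scaled by `1 / (1 - p_drop)`, and nothing is dropped at inference time.

```python
import numpy as np


def dropout_forward(h, p_drop, rng, train=True):
    """Inverted dropout: zero each activation with probability p_drop and
    scale the survivors by 1 / (1 - p_drop) so the expected value is unchanged."""
    if not train or p_drop == 0.0:
        return h, np.ones_like(h)
    keep_prob = 1.0 - p_drop
    mask = (rng.random(h.shape) < keep_prob) / keep_prob
    return h * mask, mask


def dropout_backward(dout, mask):
    # Gradients flow only through the kept units, scaled by the same factor
    return dout * mask


rng = np.random.default_rng(0)
h = np.ones((1000, 100))
out, mask = dropout_forward(h, p_drop=0.25, rng=rng)
print(out.mean())      # close to 1.0: the expected activation is preserved
print(np.unique(out))  # only 0 and 1 / 0.75
```

Because the scaling happens during training, the inference path can use the activations unchanged.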

The simplest solution to the exploding gradient problem is to cap the gradient at a maximum threshold (either element-wise or by its overall norm), which is called gradient clipping. I will not cover either technique in detail here or implement them for FNNClassifier, but I encourage you to try them out yourself.
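If you want a starting point, here is a minimal sketch of clipping by global norm. The helper `clip_by_global_norm` is hypothetical; in FNNClassifier it could, for instance, be applied to the per-layer `grad_W` and `grad_b` arrays just before the optimizer update (an assumption, not the author's code).

```python
import numpy as np


def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so that their combined L2 norm
    does not exceed max_norm. The direction of the update is preserved."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm


grads = [np.array([[3.0, 0.0]]), np.array([4.0])]  # combined norm = 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(norm_before, norm_after)
```

Clipping by norm keeps the gradient's direction and only shortens it; the element-wise alternative, `np.clip(g, -c, c)`, is simpler but can change the direction.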

Batch Normalization

Weight initialization aims to keep the distributions of the activations consistent with the distribution of the input layer, which is ideally assumed to have a mean of 0 and adequate variance (typically set to 1). This implies the importance of data normalization during preprocessing, which can be done using the following:

$$
\begin{aligned}
\mu &= \frac{1}{n} \sum_{i=1}^{n} x_i \\
\sigma^2 &= \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2 \\
\hat{x}_i &= \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}
\end{aligned}
$$

where $\epsilon$ is a small value that prevents division by zero. Weight initialization sets the weights in a specific way to keep the variance of the distributions consistent, but there is no reason we can't apply the same normalization as above to batches of activations at every layer to maintain the variance directly. This method is called batch normalization, and it can be expressed as follows:

$$
\begin{aligned}
B &= \{x_1, x_2, \ldots, x_m\} \\
\mu_B &= \frac{1}{m} \sum_{i=1}^{m} x_i \\
\sigma_B^2 &= \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2 \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \\
y_i &= \gamma \hat{x}_i + \beta
\end{aligned}
$$

The normalization steps are almost identical, except that we apply them to each batch and add a linear function with learnable parameters $\gamma$ and $\beta$. This final linear function learns the best scale ($\gamma$) and shift ($\beta$) for the normalized activations rather than fixing them at a variance of 1 and a mean of 0. This simple, intuitive approach has proven very effective at stabilizing gradients while maintaining learning speed, making the network far less sensitive to the techniques above, such as careful weight initialization. It is also robust because it has few hyperparameters (only the momentum of the running statistics and $\epsilon$ in the implementation below) and does not depend on the choice of activation function. The following is the code implementation:

def batch_norm_forward(self, X, gamma, beta, running_mean, running_var, momentum=0.9, train=True):
    if train:
        batch_mean = np.mean(X, axis=0)
        batch_var = np.var(X, axis=0)
        X_normalized = (X - batch_mean) / np.sqrt(batch_var + self.epsilon)
        out = gamma * X_normalized + beta
        
        # Update running statistics
        running_mean = momentum * running_mean + (1 - momentum) * batch_mean
        running_var = momentum * running_var + (1 - momentum) * batch_var
    else:
        # Use running statistics for inference
        X_normalized = (X - running_mean) / np.sqrt(running_var + self.epsilon)
        out = gamma * X_normalized + beta
 
    return out, running_mean, running_var
 
def batch_norm_backward(self, dout, X, gamma, beta, mean, var):
    N = X.shape[0]
 
    X_normalized = (X - mean) / np.sqrt(var + self.epsilon)
    dgamma = np.sum(dout * X_normalized, axis=0)
    dbeta = np.sum(dout, axis=0)
 
    dX_normalized = dout * gamma
    dvar = np.sum(dX_normalized * (X - mean) * -0.5 * np.power(var + self.epsilon, -1.5), axis=0)
    dmean = np.sum(dX_normalized * -1 / np.sqrt(var + self.epsilon), axis=0) + dvar * np.sum(-2 * (X - mean), axis=0) / N
    dX = dX_normalized / np.sqrt(var + self.epsilon) + dvar * 2 * (X - mean) / N + dmean / N
 
    return dX, dgamma, dbeta
 
FNNClassifier.batch_norm_forward = batch_norm_forward
FNNClassifier.batch_norm_backward = batch_norm_backward

The above code implements the forward and backward passes of the batch normalization layer. The backward pass is relatively straightforward, except for computing the derivative of the loss with respect to the input before normalization, which is needed for further backpropagation. Unfortunately, this computation is beyond the scope of this article (it involves the vector equivalent of the product rule, since the batch mean and variance themselves depend on every input in the batch). If you're interested in the math, I recommend trying to derive it yourself or checking out the article "Deriving Batch-Norm Backprop Equations" by Yeh, C. (2024).
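One way to gain confidence in the `dX` formula without working through the full derivation is a finite-difference gradient check. The sketch below re-implements the two methods as plain functions (`bn_forward` and `bn_backward` are hypothetical names; `self` is dropped, `beta` is omitted from the backward signature, and `eps=1e-5` is an assumed constant) and compares `dX` against numerical derivatives of a toy loss `sum(out * R)` with a random upstream gradient `R`.

```python
import numpy as np

EPS = 1e-5


def bn_forward(X, gamma, beta, eps=EPS):
    mean = X.mean(axis=0)
    var = X.var(axis=0)
    X_hat = (X - mean) / np.sqrt(var + eps)
    return gamma * X_hat + beta, mean, var


def bn_backward(dout, X, gamma, mean, var, eps=EPS):
    N = X.shape[0]
    X_hat = (X - mean) / np.sqrt(var + eps)
    dgamma = np.sum(dout * X_hat, axis=0)
    dbeta = np.sum(dout, axis=0)
    dX_hat = dout * gamma
    dvar = np.sum(dX_hat * (X - mean) * -0.5 * (var + eps) ** -1.5, axis=0)
    dmean = np.sum(-dX_hat / np.sqrt(var + eps), axis=0) + dvar * np.sum(-2 * (X - mean), axis=0) / N
    dX = dX_hat / np.sqrt(var + eps) + dvar * 2 * (X - mean) / N + dmean / N
    return dX, dgamma, dbeta


rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
gamma, beta = rng.normal(size=3), rng.normal(size=3)
R = rng.normal(size=(8, 3))  # random upstream gradient: loss = sum(out * R)


def loss(X):
    out, _, _ = bn_forward(X, gamma, beta)
    return np.sum(out * R)


_, mean, var = bn_forward(X, gamma, beta)
dX, dgamma, dbeta = bn_backward(R, X, gamma, mean, var)

# Central finite differences with respect to every input entry
h = 1e-6
dX_num = np.zeros_like(X)
for idx in np.ndindex(X.shape):
    Xp, Xm = X.copy(), X.copy()
    Xp[idx] += h
    Xm[idx] -= h
    dX_num[idx] = (loss(Xp) - loss(Xm)) / (2 * h)

print("max abs difference:", np.max(np.abs(dX - dX_num)))
```

A side property worth noticing: each column of `dX` sums to (numerically) zero, because shifting every input in a column by the same constant leaves the normalized output unchanged.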

Final Results

After implementing optimizers, activation functions, weight initialization, L2 regularization, and batch normalization, the FNNClassifier ended up looking like the code in the appendix at the bottom of this article. It turned out to be quite robust to a wide range of hyperparameters and datasets. However, some parts are not customizable, such as the activation function, regularization, and weight initialization for each hidden layer, and where to apply batch normalization.

This can be improved by modularizing each layer as a Layer class with layer-wise configurations and combining the layers into a general Model that performs backpropagation using the selected optimizer and loss. Such modular building blocks, along with optimized implementations of every model component, are already available in frameworks such as TensorFlow, PyTorch, and MXNet. Therefore, from this article onwards, we will mainly use these libraries to implement neural network models instead of building them from scratch.

Resources