MLエンジニアへの道 #15 - 勾配消失・勾配爆発

Last Edited: 8/27/2024

このブログ記事では、ディープラーニングにおける勾配消失・勾配爆発問題を紹介します。

ML

勾配消失・勾配爆発

どの最適化アルゴリズムを選んだとしても、コスト関数の勾配をパラメータに対して計算する必要があります。 この過程は、逆伝播によって行われますが、その際には、偏導関数dzhdhh1dhh1dzh1\frac{dz_{h}}{dh_{h-1}}\frac{dh_{h-1}}{dz_{h-1}}を掛け算します。 前者の部分は whw_hであり、後者の部分はシグモイド関数の場合hh1(1hh1)h_{h-1}(1-h_{h-1})です。

これらの値が1未満の場合、逆伝播の際に勾配が指数関数的に小さくなり、最終的には消失してしまう可能性があります。これにより、 重みの調整がほとんど行われなくなります。逆に、これらの値が1を超えると、勾配が指数関数的に大きくなり、爆発してしまい、 オーバーシュートを引き起こす可能性があります。どちらの場合でも、モデルは重みを適切に調整して学習するのが難しくなります。 そのため、これを勾配消失・勾配爆発問題と呼び、ディープニューラルネットワークを構築する際には常にこの問題と戦わなければなりません。

活性化関数

不安定な勾配問題を解決するための一つのアプローチとして、活性化関数があります。私たちは脳のニューロンを模倣するためにシグモイド関数を使用してきましたが、 その導関数は以下のような形を持っています。

導関数はゼロの時に最大となりますが、それでも1未満です。導関数は、モデルが0または1の活性化を持つことを促し、 その結果、導関数が0に近づいていきます。したがって、シグモイド関数は勾配消失問題に陥りやすいです。この問題を 回避するために、シグモイド関数の代わりに、より良い導関数の形を持つ別の非線形関数、例えばReLU関数を使用 することができます。

ReLU(x)=max(0,x) \text{ReLU}(x) = max(0, x)

ReLU関数は、正の値に対して導関数が1で、負の値に対して導関数が0であるため、計算が簡単で、過学習のリスクが低いです。 以下は FNNClassifer クラスでの実装例です。

# Forward Pass
def relu(self, X):
    return np.maximum(0, X)
 
# Backward Pass
def fit(self, X, y, epochs=100, verbose=True):
    ...
    # Backpropagate the error
    if i > 0:
        # dz_t/dh_t-1 = self.W[i].T
        # dh_t-1/dz_t-1 = (activations[i] > 0).astype(int) for h = relu
        delta = np.matmul(delta, self.W[i].T) * (activations[i] > 0).astype(int)
    ...

上でReLU関数は、全ての隠れ層におけるデフォルトの活性化関数とされています。他にも、負の値に対しても非線形性と勾配 を保つために異なる傾きを持つLeaky ReLUなどの関数があります。これらの他の非線形活性化関数を使用することで、 脳を模倣することよりも、より良い学習を優先させることができます。

重みの初期化

シグモイド関数とReLU関数の両方において、それらの導関数は活性化の値 (シグモイド関数はhh1(1hh1)h_{h-1}(1-h_{h-1})、ReLU関数はhh1>0h_{h-1} > 0) に依存します。したがって、すべての層にわたって活性化の分散が入力層のように極端ではない値の分散と同じになるようにして、勾配消失・勾配爆発問題 を防ぎたいです。活性化が線形である(実際シグモイド関数は中心付近で線形であり、ReLU関数は正の値で線形である)と仮定し、バイアスがゼロであるとすると、 hhh_h の分散は次のように表現できます:

Var(hh)=Var(i=1nh1whihh1) Var(h_h) = Var(\sum_{i=1}^{n_{h-1}} w_{hi} h_{h-1})

重みと zz は互いに独立しているため、これを次のように表せます:

Var(hh)=i=1nh1Var(whihh1)=nh1Var(whihh1) Var(h_h) = \sum_{i=1}^{n_{h-1}} Var (w_{hi} h_{h-1}) \\ = n_{h-1} Var (w_{hi} h_{h-1})

2つの独立した変数の積の分散は次のようになります:

Var(hh)=nh1(E(whi)2Var(hh1)+Var(whi)E(hh1)2+Var(whi)Var(hh1)) Var(h_h) = n_{h-1} (E(w_{hi})^2 Var(h_{h-1})+ Var(w_{hi}) E(h_{h-1})^2 + Var(w_{hi})Var(h_{h-1}))

重み ww の平均がゼロであると仮定すると:

Var(hh)=nh1(Var(whi)E(hh1)2+Var(whi)Var(hh1))=nh1(Var(whi)E(hh1)2+Var(whi)(E((hh1)2)E(hh1)2))=nh1Var(whi)E((hh1)2) Var(h_h) = n_{h-1} (Var(w_{hi}) E(h_{h-1})^2 + Var(w_{hi})Var(h_{h-1})) \\ = n_{h-1} (Var(w_{hi}) E(h_{h-1})^2 + Var(w_{hi})(E((h_{h-1})^2) - E(h_{h-1})^2)) \\ = n_{h-1} Var(w_{hi}) E((h_{h-1})^2)

hh1h_{h-1} がゼロ平均であると仮定すると、E(hh12)=Var(hh1)E(h_{h-1}^2) = Var(h_{h-1}) は線形活性化関数に対応します。ただし、 ReLUは負の値に対してゼロであるため、E(hh12)E(h_{h-1}^2) は半分にする必要があり、これによりReLUの場合、活性化の値の分散は 12Var(hh1)\frac{1}{2}Var(h_{h-1}) と等しくなります。したがって、ReLUの時の分散は次のように表現できます:

Var(hh)=nh1(Var(whi)E(hh1)2+Var(whi)Var(hh1))=nh1(Var(whi)E(hh1)2+Var(whi)(E((hh1)2)E(hh1)2))=12nh1Var(whi)Var(hh1) Var(h_h) = n_{h-1} (Var(w_{hi}) E(h_{h-1})^2 + Var(w_{hi})Var(h_{h-1})) \\ = n_{h-1} (Var(w_{hi}) E(h_{h-1})^2 + Var(w_{hi})(E((h_{h-1})^2) - E(h_{h-1})^2)) \\ = \frac{1}{2} n_{h-1} Var(w_{hi}) Var(h_{h-1})

上記のように、層間で分散を同じに保つためには、Var(hh)=Var(hh1)Var(h_h) = Var(h_{h-1}) を満たすために 12nhVar(whi)\frac{1}{2} n_h Var(w_{hi}) が1である必要があります。したがって、ReLU活性化関数に対して Var(whi)Var(w_{hi})2nh1\frac{2}{n_{h-1}} になるように重みを初期化できます。 この重み初期化手法はHe初期化と呼ばれます。シグモイド関数の場合、線形活性化関数を仮定し、重みの分散を 1nh1\frac{1}{n_{h-1}} に設定します。 これはXavier初期化と呼ばれます。

これらの両方は、重みの初期化の例であり、活性化の分散を維持し、dhh1dzh1\frac{dh_{h-1}}{dz_{h-1}} が十分に広がるようにすることで 勾配問題の解決を目論みます。また、重みとdzhdhh1\frac{dz_{h}}{dh_{h-1}}が大きくなりすぎるのを防ぐことで、間接的に勾配爆発問題に も対処できます。以下は、FNNClassifierクラスにおける重み初期化の実装例です:

def initialize_weight(self, input_dim, hidden_dims, output_dim):
    self.W = []
    self.b = []
    
    # He initialization for ReLU
    for i in range(len(hidden_dims)):
        if i == 0:
            self.W.append(np.random.normal(0, (2/input_dim)**0.5, size=(input_dim, hidden_dims[i])))
        else:
            self.W.append(np.random.normal(0, (2/hidden_dims[i-1])**0.5, size=(hidden_dims[i-1], hidden_dims[i])))
        self.b.append(np.zeros(hidden_dims[i]))
 
    # Xavier initialization for Sigmoid or Softmax
    self.W.append(np.random.normal(0, (1/hidden_dims[len(hidden_dims)-1])**0.5, size=(hidden_dims[len(hidden_dims)-1], output_dim)))
    self.b.append(np.zeros(output_dim))

上記のメソッドは、クラスの __init__メソッド内で呼び出されます。

L1およびL2正則化

正則化は主にモデルがデータに対して過学習するのを防ぐために使用されますが、それは間接的に重みがゼロまたはゼロに近くなる ことを促し、勾配爆発問題を防ぐのにも役立ちます。L1またはL2正則化のどちらかを使用することができ、これにより損失関数に 対応するペナルティ項が追加されます。(詳細を知りたい場合は、「Road to ML Engineer #7 - Regularization」をチェックすることをお勧めします。)

最適化アルゴリズムの文脈におけるL2ペナルティ率は、重み減衰(weight decay)と呼ばれます。以下は、 FNNClassifierでの重み減衰のコード実装を示しています。

def fit(self, X, y, epochs=100, verbose=True):
    ...
 
    for i in range(len(self.W) - 1, -1, -1):
        grad_W = np.matmul(activations[i].T, delta)
        if (self.weight_decay):
            grad_W += self.weight_decay * self.W[i]
 
    ...

これがSGDに適用される場合、重み減衰はL2正則化と全く同じプロセスになります。しかし、Adamに適用された場合、 重み減衰はL2正則化と同等ではなくなります。これは、ファーストモーメントとセカンドモーメントの両方が正則化後 の勾配に依存するためです。Adamの改良版であるAdamWは、適応モーメンタムと正則化を切り離すために、 勾配ではなく重みに重み減衰を適用します。以下は、FNNClassifierにおけるAdamWの実装です。

def __init__(self, input_dim, hidden_dims, output_dim, lr=0.001, 
                 batch_size=16, beta_1=0.9, beta_2=0.99, weight_decay=None, optimizer="SGD"):
        ...
 
        # Initialize Adam & AdamW variables
        if self.optimizer == "Adam" or self.optimizer == "AdamW":
            self.epsilon = 1e-8
            self.m_W = [np.zeros_like(w) for w in self.W]
            self.m_b = [np.zeros_like(b) for b in self.b]
            self.s_W = [np.zeros_like(w) for w in self.W]
            self.s_b = [np.zeros_like(b) for b in self.b]
 
def fit(self, X, y, epochs=100, verbose=True):
    # Adam-specific parameter updates
    if self.optimizer == "Adam" or self.optimizer == "AdamW":
        t = epoch + 1
        self.m_W[i] = self.beta_1 * self.m_W[i] + (1 - self.beta_1) * grad_W
        self.s_W[i] = self.beta_2 * self.s_W[i] + (1 - self.beta_2) * (grad_W ** 2)
        self.m_b[i] = self.beta_1 * self.m_b[i] + (1 - self.beta_1) * grad_b
        self.s_b[i] = self.beta_2 * self.s_b[i] + (1 - self.beta_2) * (grad_b ** 2)
 
        m_W_hat = self.m_W[i] / (1 - self.beta_1 ** t)
        s_W_hat = self.s_W[i] / (1 - self.beta_2 ** t)
        m_b_hat = self.m_b[i] / (1 - self.beta_1 ** t)
        s_b_hat = self.s_b[i] / (1 - self.beta_2 ** t)
 
        if self.optimizer == "AdamW" and self.weight_decay:
            self.W[i] -= self.lr * self.weight_decay * self.W[i]
            self.b[i] -= self.lr * self.weight_decay * self.b[i]
        self.W[i] -= self.lr * m_W_hat / (np.sqrt(s_W_hat) + self.epsilon)
        self.b[i] -= self.lr * m_b_hat / (np.sqrt(s_b_hat) + self.epsilon)

勾配消失・勾配爆発に対して、ReLU活性化関数、重みの初期化、AdamWを使用した正則化など、上記のすべての技術を 使用することで、学習は以下のように推移するようになります。

AdamW Loss

学習曲線は理想的な形をしており、これらの勾配消失・勾配爆発問題への対策により、段階的かつスムーズな学習が実現されている ことを示しています。

ドロップアウトと勾配クリッピング

ニューラルネットワークに特有の正則化アプローチとして、ドロップアウト正則化があります。これは、前の層からの 各エポックごとにニューロンの活性の一部をランダムに選んで無効化する方法です。無効化されたニューロンを補うために、 残りの有効なニューロンは、ドロップアウトされる確率の逆数でスケーリングされます。(25%がドロップアウトされた場合、 有効なニューロンは4倍にスケーリングされます。)これは、データにランダムなノイズを導入して過学習を防ぎ、 モデルが未見のデータに対して汎化できるようにすることを目的としています。また、活性や勾配dhh1dzh1\frac{dh_{h-1}}{dz_{h-1}}を スケーリングアップすることで、間接的に勾配消失問題の助けにもなります。

勾配爆発問題に対する最も簡単な解決策は、勾配に最大の閾値を設定することで、これは勾配クリッピングと呼ばれます。 ここでは詳細を紹介せず、またFNNClassifierに対して実装しませんが、ぜひご自身で試してみることをお勧めします。

バッチ正規化

重みの初期化は、活性化の分布を入力層の分布と一致させることを目指しています。入力層は、理想的には平均0で適切な分散 (通常は1に設定される)を持つと仮定されます。これにより、前処理中のデータの正規化の重要性が浮き彫りになります。 データの正規化は一般に以下のように行うことができます。

μ=1ni=1nxiσ2=1ni=1n(xiμ)2x^=xiμσ2+ϵ \mu = \frac{1}{n} \sum_{i=1}^{n} x_i \\ \sigma^2 = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2 \\ \hat{x} = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}

ここで、ϵ\epsilonは分母がゼロになるのを防ぐための小さな値です。重みの初期化は、分布の分散を一定に保つように重みを 初期化しますが、分散を保つために各層の活性化のバッチに上と同じようなシンプルな正規化の方法を適用することを妨げるもの はありません。この方法はバッチ正規化 (Batch Normalization)と呼ばれ、以下のように表現されます:

B={x1,x2,...,xm}μB=1mi=1mxiσB2=1mi=1m(xiμB)2x^=xiμσB2+ϵyi=γx^+β B = \{x_1, x_2, ..., x_m\} \\ \mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i \\ \sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2 \\ \hat{x} = \frac{x_i - \mu}{\sqrt{\sigma_B^2 + \epsilon}} \\ y_i = \gamma \hat{x} + \beta

正規化のステップはほぼ同一ですが、バッチに適用し、学習可能なパラメータγ\gammaβ\betaを用いた追加の線形関数を使用 します。最後の線形関数は、各バッチに対して平均を0に固定するのではなく、最適な平均を学習するためのものです。この直感的で単純な アプローチは、勾配を安定させながら学習速度を維持するのに非常に効果的であることが証明されており、前述の技術が不要にさえ なります。また、バッチ正規化はハイパーパラメータを持たず、活性化関数の選択にも依存しないため、最も堅牢な技術の1つと言えます。 以下にコードの実装を示します:

def batch_norm_forward(self, X, gamma, beta, running_mean, running_var, momentum=0.9, train=True):
    if train:
        batch_mean = np.mean(X, axis=0)
        batch_var = np.var(X, axis=0)
        X_normalized = (X - batch_mean) / np.sqrt(batch_var + self.epsilon)
        out = gamma * X_normalized + beta
        
        # Update running statistics
        running_mean = momentum * running_mean + (1 - momentum) * batch_mean
        running_var = momentum * running_var + (1 - momentum) * batch_var
    else:
        # Use running statistics for inference
        X_normalized = (X - running_mean) / np.sqrt(running_var + self.epsilon)
        out = gamma * X_normalized + beta
 
    return out, running_mean, running_var
 
def batch_norm_backward(self, dout, X, gamma, beta, mean, var):
    N = X.shape[0]
 
    X_normalized = (X - mean) / np.sqrt(var + self.epsilon)
    dgamma = np.sum(dout * X_normalized, axis=0)
    dbeta = np.sum(dout, axis=0)
 
    dX_normalized = dout * gamma
    dvar = np.sum(dX_normalized * (X - mean) * -0.5 * np.power(var + self.epsilon, -1.5), axis=0)
    dmean = np.sum(dX_normalized * -1 / np.sqrt(var + self.epsilon), axis=0) + dvar * np.sum(-2 * (X - mean), axis=0) / N
    dX = dX_normalized / np.sqrt(var + self.epsilon) + dvar * 2 * (X - mean) / N + dmean / N
 
    return dX, dgamma, dbeta
 
FNNClassifer.batch_norm_forward = batch_norm_forward
FNNClassifer.batch_norm_backward = batch_norm_backward

上記のコードは、バッチ正規化層の順伝播と逆伝播を表しています。逆伝播は比較的単純ですが、さらなる前の層への逆伝播 のために正規化前の入力に対する損失の導関数を計算する部分は例外的に複雑です。残念ながら、この計算はこの記事の範囲 を超えています(ベクトルの積の法則の同等物が関係します)。数学に興味がある場合は、自分で導関数を導くか、 Yeh, C.(2024年)による記事「Deriving Batch-Norm Backprop Equations」を参照することをお勧めします。

最終結果

適応化アルゴリズム、活性化関数、重みの初期化、L2正則化、およびバッチ正規化を実装した後、FNNClassifierはこの記事の 末尾の付録にあるコードのようになりました。これは、どのようなハイパーパラメータやデータセットにも非常に堅牢です。 ただし、各隠れ層の活性化関数、各隠れ層に使用する正則化および重みの初期化、およびバッチ正規化を使用する場所など、 一部カスタマイズできない部分があります。

これを改善するために、各層をLayerクラスとしてモジュール化し、層ごとの設定を行った上で、それらを組み合わせて一般的 なModelを作成し、選択したオプティマイザーとコスト関数を使用して逆伝播を行うことができます。人工ニューラルネットワーク のモジュール化およびモデルの各コンポーネントの最適化は、すでにTensorFlow、PyTorch、MXNetなどのいくつかの フレームワークに多様な方法で実装されています。したがって、この記事以降は、これらのライブラリを使用して ニューラルネットワークモデルを実装していきます。

リソース