MLエンジニアへの道 #33 - GAT

Last Edited: 1/10/2025

このブログ記事では、ディープラーニングにおけるGATについて紹介します。

ML

前回の記事では、グラフ畳み込みがどのように動作するか、そしてホモフィラスグラフに対してどれほどシンプルで効果的かを説明しました。 今回の記事では、より複雑なモデルであり、非ホモフィラスグラフにも対応可能な グラフアテンションネットワーク について解説します。

グラフアテンション

グラフ畳み込みのシンプルさは、各ノードペアに対して同じ重みを適用することに基づいています。これは、置換同変性を維持するためには必要不可欠でした。 しかし、グラフ畳み込みで変更できる要素が1つあります。それは正規化因子です。GCNでは、正規化に D1D^{-1}(または D12D^{\frac{1}{2}})を使用しますが、 これらの値は一定です。代わりに、学習可能な重み、アテンションを各ノードペアに割り当てることで、より複雑な関係を学習できるモデルを構築できます。

eij=LeakyReLU(aT[WhiWhj])αij=softmaxj(eij)=exp(eij)kNiexp(eik) e_{ij} = \text{LeakyReLU}(\overrightharpoon{a}^T[Wh_i \Vert Wh_j]) \\ \alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\text{exp}(e_{ij})}{\sum_{k \in N_i}\text{exp}(e_{ik})}

ノードペア eije_{ij} に対するアテンション係数の計算は、WhiWh_iWhjWh_j を連結(\Vert)し、 それを重みベクトル a\overrightharpoon{a} と LeakyReLU 活性化関数を持つフィードフォワード層に通すことで実現します。 ノード ii のすべての隣接ノード(ii 自身も含む)の係数はソフトマックス関数に通され、アテンション αij\alpha_{ij} の値が 0 から 1 の範囲内に調整されます。 以下は、この計算に対応するテンソル操作を(無理矢理)表した表記です。

E=LeakyReLU([[HW]×n([HW]×n)T]a)A=softmax(A~E) E = \text{LeakyReLU}([[HW]_{\times n} \Vert ([HW]_{\times n})^T]\overrightharpoon{a}) \\ \Alpha = \text{softmax}(\tilde{A}E)

HW×nHW_{\times n}HWHWnn 回繰り返した 3 次元テンソルであり、そのサイズは (n,n,dout)(n, n, d_{out}) です。 このテンソルとその転置(第 1 軸と第 2 軸が入れ替わり、サイズ (n,n,dout)(n, n, d_{out}) のテンソルになる)は、第 3 軸に沿って連結され、 サイズ (n,n,2dout)(n, n, 2d_{out}) のテンソルになります。このテンソルにサイズ (2dout,1)(2d_{out}, 1)a\overrightharpoon{a} を掛け、 LeakyReLU を通すことで、アテンション係数行列 EE(サイズ (n,n)(n, n))を計算します(第 3 軸を平坦化した後)。

最後に、EEA~(=A+I)\tilde{A} (=A+I) を掛けてマスクを適用し、ソフトマックス関数に通すことで、 マスクされたアテンション行列 A\Alpha(サイズ (n,n)(n, n))を得ます。この注意の計算は、 行と列の順序を維持するため、置換同変性を保ちます。

グラフアテンションネットワーク

前のセクションで、置換同変性を保ちながら異なるノードペア間で値が異なる学習可能なグラフアテンションを計算する方法を設定しました。 グラフアテンションネットワーク(GAT) は、GCN における D1D^{-1} を得られた A\Alpha に置き換えることで、 非ホモフィラスグラフにおいても下流タスクの適切な埋め込みを生成できるようにします。

hi=g(xi,XNi)=σ(j{i,Ni}αijWxj)H=F(X,A)=[g(x1,XN1)g(x2,XN2)g(xv,XNv)]=σ(AA~XW) h_i = g(x_i, X_{N_i}) = \sigma(\sum_{j \in \{i, N_i\}} \alpha_{ij}Wx_j) \\ H = F(X, A) = \begin{bmatrix} - g(x_1, X_{N_1}) -\\ - g(x_2, X_{N_2}) -\\ \vdots \\ - g(x_v, X_{N_v}) -\\ \end{bmatrix} = \sigma(\Alpha\tilde{A}XW)

さらに、GAT はトランスフォーマーと同様にマルチヘッドアテンションを利用して、より複雑な関係を持つグラフに対応します。 各ヘッドはそれぞれ独自の aaww を持ちます。すべてのヘッドから得られる出力の潜在埋め込みは、最後の層を除き連結されます。 最後の層では、埋め込みの平均が取られます。

hi=k=1Kσ(j{i,Ni}αijkWkxj)hi=σ(1Kk=1Kj{i,Ni}αijkWkxj) h_i = \Big\Vert_{k=1}^{K}\sigma(\sum_{j \in \{i, N_i\}} \alpha_{ij}^kW^kx_j) \\ h_i = \sigma(\frac{1}{K}\sum_{k=1}^K\sum_{j \in \{i, N_i\}} \alpha_{ij}^kW^kx_j)

この計算がトランスフォーマーのアテンションに非常に似ていることに気付くかもしれません。主な違いは、 グラフアテンションでのアテンション係数の計算において連結とフィードフォワード層を使用する点であり、 トランスフォーマーでは正規化された内積を使用します。(隣接行列の使用はグラフ注意に特有のものですが、 トランスフォーマーはトークンの完全グラフを仮定しているとも解釈できます。)

コードの実装

まず、シングルヘッドアテンションを設定し、その後でマルチヘッドアテンションに組み合わせることができます。以下は、シングルヘッドアテンションの TensorFlow 実装です。

class GraphAttention(layers.Layer):
    def __init__(self, d_in, d_out):
        super(GraphAttention, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.W = self.add_weight(shape=(d_in, d_out),
                                       initializer='glorot_uniform',
                                       trainable=True)
        self.a = self.add_weight(shape=(2*d_out, 1),
                                       initializer='glorot_uniform',
                                       trainable=True)
        self.leaky_relu = layers.LeakyReLU(0.2)
 
    def call(self, x, A):
        # x: (batch_size, n, d_in), A: (batch_size, n, n)
        x = tf.matmul(x, self.W)
        x_i = tf.tile(tf.expand_dims(x, 1), [1, tf.shape(x)[1], 1, 1])
        x_j = tf.transpose(x_i, [0, 2, 1, 3])
        x_ij = tf.concat([x_i, x_j], -1)
 
        e = tf.squeeze(tf.matmul(x_ij, self.a), [-1])
        e = self.leaky_relu(e)
 
        attn = tf.nn.softmax(tf.matmul(A, e))
        x = tf.matmul(A, x)
        x = tf.matmul(attn, x)
        return x

フィードフォワード層の重み行列やベクトルを使用する代わりに、適切なフィルター数を持つ Conv1D を適用して同じ効果を得ることができます(これはオリジナル論文の著者も採用しています)。 具体的には、a\overrightharpoon{a}[a1a2][a_1||a_2] に分割し、aT[WhiWhj]\overrightharpoon{a}^T[Wh_i || Wh_j]a1TWhi+a2TWhja_1^TWh_i + a_2^TWh_j に変換できます。 これは、カーネルサイズが 1 の Conv1D レイヤーを用いて実装できます。

さらに、入力やアテンション係数に対するドロップアウトレイヤーや残差接続を組み込むことで、モデルをより堅牢にすることができます。 非線形性 σ\sigma はシングルヘッドアテンションには適用されません。これは、マルチヘッドアテンションレイヤーで適用されるためです。

class MultiHeadGraphAttention(layers.Layer):
    def __init__(self, num_heads, d_in, d_out, predict=False):
        super(MultiHeadGraphAttention, self).__init__()
        self.num_heads = num_heads
        self.attention_heads = []
        self.predict = predict
        for _ in range(num_heads):
            self.attention_heads.append(GraphAttention(d_in, d_out))
 
    def call(self, x, A):
        head_outputs = [head(x, A) for head in self.attention_heads]
        if self.predict:
          head_outputs = tf.reduce_mean(head_outputs, axis=0)
        else:
          head_outputs = tf.concat(head_outputs, [-1])
        return tf.nn.relu(head_outputs)

上は、TensorFlow によるマルチヘッドグラフアテンションの実装例です。ここでは、活性化関数 σ\sigma として ReLU を使用しています。 他の活性化関数やパラメータを試してみたり、PyTorch を用いて実装してみたりすることをおすすめします。このマルチヘッドアテンションをスタックすることで、 下流タスク用の潜在埋め込みを生成する GAT を構築できます。

静的アテンションと動的アテンション

上記セクションの注意計算を見て何か気づいた場合は、この記事を読むのをやめて、すぐに機械学習研究者を目指した方が良いかもしれません。 上記のGATは 静的アテンション を計算しており、高い注意から低い注意へのノードの順序がクエリノードによって変わることはありません。 これは、重みa\overrightharpoon{a}WWが1つの線形操作にまとめられるためであり、さらにLeakyReLUとsoftmaxが単調関数であるためです。 そのため、アテンションはhjh_jに対して単調になります。アテンションの値はhih_iによって変化することもありますが、 注意の順序はすべてのクエリノードで同じままであり、鋭さだけが変わります。

Static Attention

Brody, S. ら(2022)は、この問題を静的アテンションをよく説明する上記の図で指摘しました。この図は、 完全二部グラフで構成された9つのクエリノードとキーのノードを用いてGATで計算した注意を示しています。 すべてのクエリノードにおいて、線の形が同じであり、k8が最も高い注意値を持つことがわかります。 本来であれば、 動的アテンション を求めるべきです。動的アテンションでは、アテンションの順序がクエリノードによって変わり、 より複雑な関係を考慮できます。(ただし、単純なグラフでは、動的アテンションよりも静的アテンションのほうが実際にはうまく機能する場合もあります。)

グラフアテンションネットワークv2

最大の問題は、重みa\overrightharpoon{a}WWがまとめられる点であり、これが非線形関係の学習を妨げています。 Brody, S. ら(2022)は、以下のように操作の順序を変更してこの問題を解決しました。

eij=aTLeakyReLU(W[hihj])αij=softmaxj(eij)=exp(eij)kNiexp(eik) e_{ij} = \overrightharpoon{a}^T\text{LeakyReLU}(W[h_i \Vert h_j]) \\ \alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\text{exp}(e_{ij})}{\sum_{k \in N_i}\text{exp}(e_{ik})}

改良された方法では、最初にhih_ihjh_jを連結し、WWを適用し、その後に非線形性を適用してからaT\overrightharpoon{a}^Tを掛けます。 この変更により、重みがまとめられず、非線形関係を学習することが可能になります。Brody, S. ら(2022)は、この変更により動的アテンションを達成できることを確認しました。

class GraphAttentionV2(layers.Layer):
    def __init__(self, d_in, d_out):
        super(GraphAttentionV2, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.W = self.add_weight(shape=(d_in, d_out),
                                       initializer='glorot_uniform',
                                       trainable=True)
        self.a = self.add_weight(shape=(d_out, 1),
                                       initializer='glorot_uniform',
                                       trainable=True)
        self.leaky_relu = layers.LeakyReLU(0.2)
 
    def call(self, x, A):
        # x: (batch_size, n, d), A: (batch_size, n, n)
        x_i = tf.tile(tf.expand_dims(x, 1), [1, tf.shape(x)[1], 1, 1])
        x_j = tf.transpose(x_i, [0, 2, 1, 3])
        x_ij = tf.concat([x_i, x_j], -1)
        w = tf.concat([self.W, self.W], 0)
        x_ij = tf.matmul(x_ij, w)
        x_ij = self.leaky_relu(x_ij)
        
        e = tf.squeeze(tf.matmul(x_ij, self.a), [-1])
 
        attn = tf.nn.softmax(tf.matmul(A, e))
        x = tf.matmul(A, x)
        x = tf.matmul(attn, x)
        x = tf.matmul(x, self.W)
        return x

上記はGATv2で使用されるシングルヘッドグラフアテンションのコード実装です。ご覧のとおり、元のGATとの唯一の違いは操作の順序、Wの積み重ね、 およびaの次元の変更です。マルチヘッドアテンションのコードは、シングルヘッドアテンションを置き換えることを除いて変更されません。

結論

この記事では、GCNの正規化因子をアテンションに置き換え、シンプルでスケーラブルでありながら非ホモフィラスグラフにも対応できるGATを作成しました。 また、元のGATが静的アテンションを使用していることを発見し、動的アテンションを達成するGATv2を紹介しました。これにより、さらに複雑な関係にも対応できるようになりました。 次回の記事では、さらに複雑で表現力豊かなグラフモデルについて見ていきます。

リソース