MLエンジニアへの道 #34 - MPNNs & GTs

Last Edited: 1/15/2025

このブログ記事では、ディープラーニングにおけるMPNNsやGTs他について紹介します。

ML

これまでの2つの記事では、モデルの複雑さをGCNからGATへと高めてきました。この記事では、 さらに複雑なモデルであるメッセージパッシングニューラルネットワーク(MPNN)やグラフトランスフォーマーについて説明します。

メッセージパッシングニューラルネットワーク(MPNN)

これまで取り上げてきたモデルは、ノード特徴量XXを活用し、GATにおける重みWWa\overrightharpoon{a}を学習してきました。 しかし、グラフによっては、エッジ特徴量eeが存在する場合があります。このエッジ特徴量を活用することで、 モデルをより表現力豊かにすることが可能です。Gilmer, J.ら(2017)は、ノード特徴量とエッジ特徴量を利用して潜在ノード表現を生成する メッセージパッシングニューラルネットワーク(MPNN) を導入しました。 このモデルは、化学結合や空間距離をエンコードするエッジを含むグラフのタスクに対応するために設計されています。

mit+1=j{i,Ni}Mt(hit,hjt,eij)hit+1=Ut(hit,mit+1) m_i^{t+1} = \sum_{j \in \{i, N_i\}} M_t(h_i^t, h_j^t, e_{ij}) \\ h_i^{t+1} = U_t(h_i^t, m_i^{t+1})

上記はMPNNのメッセージパッシングフェーズを示しています。MtM_tはメッセージパッシング関数、UtU_tはアップデート関数であり、 これらはタスクに応じてカスタマイズ可能です。メッセージパッシング関数として、Gilmer, J.ら(2017)は、 エッジネットワークσ(Weij)hj\sigma(We_{ij})h_jや、シンプルなフィードフォワード層σ(W[hihjeij])\sigma(W[h_i || h_j || e_{ij}])を提案しています。 アップデート関数には、別のフィードフォワード層σ(W[himit+1])\sigma(W[h_i || m_i^{t+1}])を使用できます。 メッセージパッシングフェーズに加えて、オリジナルのMPNNにはリードアウトフェーズが含まれており、 潜在ノード表現に順列不変関数RRを適用して、グラフ全体の潜在表現を生成します。

コードの実装

以下は、MPNNレイヤーのTensorFlow実装例であり、MMおよびUUが結合特徴量に対して動作する単純なフィードフォワード層となっています。

class MPNNLayer(layers.Layer):
    def __init__(self, d_in, d_edge, d_latent, d_out):
        super(MPNNLayer, self).__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.W_m = self.add_weight(shape=(2 * d_in + d_edge, d_latent),
                                    initializer='glorot_uniform',
                                    trainable=True)
        self.W_u = self.add_weight(shape=(d_in + d_latent, d_out),
                                    initializer='glorot_uniform',
                                    trainable=True)
 
    def call(self, x, A):
        # x: (batch_size, n, d), A: (batch_size, n, n, d_edge)
        # Message Passsing
        x_i = tf.tile(tf.expand_dims(x, 1), [1, tf.shape(x)[1], 1, 1])
        x_j = tf.transpose(x_i, [0, 2, 1, 3])
        x_ij = tf.concat([x_i, x_j, A], -1) # => (batch_size, n, n, 2d_in + d_edge)
        m = tf.matmul(x_ij, self.W_m) # => (batch_size, n, n, d_latent)
        m = tf.nn.relu(m)
        m = tf.reduce_sum(m, 2) # => (batch_size, n, d_latent)
        
        # Updating
        x = tf.concat([x, m], -1) # => (batch_size, n, d_in + d_latent)
        x = tf.matmul(x, self.W_u) # => (batch_size, n, d_out)
        x = tf.nn.relu(x)
        return x

練習として、さまざまなメッセージパッシング関数やアップデート関数をPyTorchで実装してみるとよいでしょう。

MPNNの複雑さ

MPNNの定義やコード実装を見ると、GCNやGATもMPNNの特定のインスタンスであることに気づくかもしれません。 例えば、以下はGCNのメッセージパッシング関数とアップデート関数を示しています。

M(hi,hj,eij)=j{Ni}1didjWhjUt(hit,mit+1)=1diWhi+mit+1 M(h_i, h_j, e_{ij}) = \sum_{j \in \{N_i\}} \frac{1}{\sqrt{d_id_j}} W h_j \\ U_t(h_i^t, m_i^{t+1}) = \frac{1}{d_i} W h_i + m_i^{t+1}

これらのモデルはエッジ特徴量を使用せず、ノード特徴量をノードペアに応じて異なる、 または同じスカラーで乗算するだけであるため、上記のMPNNよりも表現力が低いです。しかし、 より複雑なモデルが必ずしもより良い結果をもたらすわけではありません。 複雑なモデルは、スケーラビリティや学習可能性の問題に直面することが多いためです。 そのため、タスクに応じて適切な複雑さと表現力を持つモデルを選択することが重要です。

MPNNはどれほど強力か?

Xu, K. ら(2019)は、離散的なノード特徴(エッジ特徴なし)を用いたMPNNが、グラフを識別する際に1-WLテスト(この記事では扱いません)と同等の能力を持つことを証明しました。 そして、最もシンプルかつ強力なメッセージパッシング関数および更新関数を以下のように定義しました。これらはインジェクティブアグリゲータまたはサム(合計)を利用しています。

M(hi,hj,eij)=j{Ni}hjUt(hit,mit+1)=σ(W((1+ϵ)hit+mit+1)) M(h_i, h_j, e_{ij}) = \sum_{j \in \{N_i\}} h_j \\ U_t(h_i^t, m_i^{t+1}) = \sigma(W((1+\epsilon)h_i^t+m_i^{t+1}))

これらの関数を使用するモデルは Graph Isomorphism Network (GIN) と呼ばれ、ϵ\epsilonは定数または学習可能なスカラーです。 一方で、連続的なノード特徴の場合、Corso, G. ら(2020)は単一のアグリゲータではWLテストに匹敵する性能を持つことができないことを証明し、 一般的な目的で強力なアグリゲータの組み合わせとして、 Principal Neighborhood Aggregation (PNA) を提案しました。

=[IS(D,α=1)S(D,α=1)][μσmaxmin] \bigoplus = \begin{bmatrix} I\\ S(D, \alpha=1) \\ S(D, \alpha=-1) \\ \end{bmatrix} \otimes \begin{bmatrix} \mu \\ \sigma \\ max \\ min \\ \end{bmatrix}

ここで、SSは次数に依存するスケーラー、\otimesはテンソル積を表します。アグリゲータは、ニューラルネットワークの出力を集約するためにメッセージパッシング関数内で使用され、 その結果を更新関数である別のニューラルネットワークに渡すことができます。これらの操作の詳細については、記事末尾に挙げた原論文をぜひご確認ください。

過平滑化問題

最も強力なアグリゲータを使用した場合でも、過平滑化問題(ノードの潜在表現が一連の集約後にすべてのノードで同一に収束してしまう現象)は依然として発生します。 この問題に対処するためには、ネットワークを浅くする、層の複雑性を高める、スキップ接続を追加するなどの方法があります。 しかし、これらの手法はモデルのハイパーパラメータを慎重に選定する必要があり、モデルのスケーラビリティを制限するため、 代替アプローチの研究が進んでいます。

Graph Transformers

MPNNの過平滑化問題やスケーラビリティ、表現力の課題に対する代替アプローチとして、Graph Transformersが新たな研究分野として注目されています。 ノードおよびエッジ特徴をトークンとして扱い、適切な位置および構造エンコーディングを生成して、これらのトークンを変換器に渡すことで潜在ノードおよびエッジ表現を生成します。

TokenGT

Kim, J. ら(2022)は Tokenized Graph Transformer (TokenGT) を提案しました。 このモデルでは、ノード識別子として各ノードに対する直交正規ベクトルPPを利用し、トークンの構造情報をエンコードします。 ノードvvの場合、PvP_vをノード埋め込みに2回追加して、[XvPvPv][X_v || P_v || P_v]を作成します。 一方、エッジXuvX_{uv}の場合、PuP_uPvP_vをエッジ埋め込みに追加して、[XuvPuPv][X_{uv} || P_u || P_v]を作成します。 また、TokenGTでは、ノード用の学習可能な型ベクトルEVE^Vとエッジ用の型ベクトルEEE^Eも追加します。 直交正規ベクトルはランダムに生成するか、正規化グラフラプラシアン行列の固有分解により得ることができます (これはグラフの接続性やクラスタリングを反映する行列ですが、この記事では詳細は扱いません。今後の別の記事で解説する可能性があります)。

class GTTokenizer(layers.Layer):
    def __init__(self, n, d, latent):
        super(GTTokenizer, self).__init__()
        self.n = n
        self.d = d
        self.latent = latent
        # Learnable Type Identifiers
        self.e_v = self.add_weight(shape=(latent,),
                                    initializer='glorot_uniform',
                                    trainable=True)
        self.e_e = self.add_weight(shape=(latent,),
                                    initializer='glorot_uniform',
                                    trainable=True)
        # graph token
        self.graph_token = self.add_weight(
            shape=(1, 1, d+2*n+latent),
            initializer='glorot_uniform',
            name="graph_token",
        )
 
    def call(self, x, e, e_id, A):
        # x: (batch_size, n, d), e: (batch_size, m, d), e_id: (batch_size, m, 2), A: (batch_size, n, n)
        # Type Identifiers
        e_v = tf.tile(tf.expand_dims(tf.expand_dims(self.e_v, 0), 0), [tf.shape(x)[0], tf.shape(x)[1], 1])
        e_e = tf.tile(tf.expand_dims(tf.expand_dims(self.e_e, 0), 0), [tf.shape(e)[0], tf.shape(e)[1], 1])
 
        # Node Identifiers
        I = tf.eye(self.n)
        D = tf.reduce_sum(A, axis=-1)
        D = tf.linalg.diag(D)
        D_inv_sqrt = tf.linalg.inv(tf.sqrt(D))
        L = I - tf.matmul(tf.matmul(D_inv_sqrt, A), D_inv_sqrt) # normalized Laplacian
        _,v = tf.linalg.eigh(L) # Laplacian eigenvectors
 
        # For nodes
        x = tf.concat([x, v, v, e_v], -1)
 
        # For edges
        e_id_1, e_id_2 = tf.unstack(e_id, axis=-1)
        e_id_1 = tf.expand_dims(e_id_1, -1)
        e_id_2 = tf.expand_dims(e_id_2, -1)
        v_1 = tf.gather_nd(v, e_id_1, batch_dims=1)
        v_2 = tf.gather_nd(v, e_id_1, batch_dims=1)
        e = tf.concat([e, v_1, v_2, e_e], -1)
 
        # Token Generation
        graph_tokens = tf.repeat(self.graph_token, tf.shape(x)[0], axis=0)
        tokens = tf.concat([x, e, graph_tokens], 1) #=> (batch_size, (n+m+1), d+2*n+latent)
        return tokens

上記は、ノード識別子としてラプラシアン固有ベクトルを使用するTokenGTのトークナイザーをTensorFlowで実装したものです。 トークンはトランスフォーマーエンコーダーに渡され、その出力が下流タスクのモデルに利用されます。 Rampasek, L. ら(2022)はGraphGPS(General, Powerful, Scalableの略)を提案しました。 これは、さまざまな位置・構造エンコーディング(上記を含む)、MPNN、およびGraph Transformersをモジュール化して統合した、 スケーラブルかつ表現力豊かなアーキテクチャです。必要な基礎知識はこの記事でほぼカバーされているので、 興味があればぜひ原論文をご確認ください。

結論

この記事では、MPNNの定義、より複雑なMPNNのコード実装、離散および連続ノード特徴に対する最大限に強力なMPNN、 、過平滑化問題、およびGraph Transformersを取り上げました。スペクトルグラフ理論(グラフラプラシアン行列が導入される分野)、 WLテスト、さらにGIN、PNA、グラフトランスフォーマーについて今後議論するかもしれませんが、 この記事を通じて分野の概要をつかんでいただけたら幸いです。

リソース