MLエンジニアへの道 #47 - DVAEs

Last Edited: 3/13/2025

このブログ記事では、深層学習における幾つかの離散的なVAE(DVAE)を紹介しています。

ML

前回の記事では、パフォーマンス向上のために離散的な潜在状態表現を生成するVAEを使用するDreamerV2を紹介しました。 しかし、勾配を停止させずにそのようなVAEをどのように実装できるでしょうか?この記事では、 研究者たちが離散的な潜在表現を実現するために考案したいくつかの巧妙な手法と、離散化がパフォーマンスをどのように向上させる可能性があるかについて説明します。

注: VAEについてよく分かっていないという場合は、MLエンジニアへの道 #18 - VAEを事前にチェックすることをお勧めします。

ストレートスルー勾配

離散的な潜在表現を得る最も簡単な方法は、カテゴリカル分布からサンプリングすることです。 しかし、サンプリングの問題点は、それが微分可能ではなく、エンコーダーへの逆伝播を停止させることです。 この問題を回避するために、DreamerV2はストレートスルー勾配推定器と呼ばれる最も単純な解決策を使用しています。 これはサンプリングやその他の微分不可能な操作を無視し、真の勾配の推定値としてデコーダーの最初の層から直接勾配を逆伝播させます。

class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        dist=torch.distributions.categorical.Categorical(probs=x)
        return F.one_hot(dist.sample())
 
    @staticmethod
    def backward(ctx, g):
        return F.hardtanh(g)
 
class StraightThroughEstimator(nn.Module):
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()
 
    def forward(self, x):
            x = STEFunction.apply(x)
            return x

ここで、DreamerV2のVAEは、カテゴリカル事後分布を隠れ状態から生成された事前分布に近づけるためにKLペナルティを適用します。 事前分布自体が学習可能であるため、DreamerV2はKLバランシングを使用し、事後分布の学習よりも事前分布の学習を速くするために高い学習率を適用します。 以下はDreamerV2のストレートスルー勾配推定器のPyTorchによる簡略化された実装です。 単純に勾配出力を渡すこともできますが、ここでは勾配を-1から1の間に再スケーリングして安定化するためにhardtanhを使用しています。

Gumbelソフトマックストリック

カテゴリカル分布からのサンプリングが微分不可能であるという問題に対する別の解決策は、ガウス分布に使用される再パラメータ化トリックと同様のトリックを適用することです。 ここでは、様々な異なる分布からのサンプルの最大値(または最小値)をモデル化するGumbel分布からノイズをサンプリングし、 そのノイズを対数確率に加え、対数確率とノイズの和のargmaxを取ることで、カテゴリカル分布からのサンプリングをシミュレートできます(対数関数は単調増加するため、対数確率を使用します)。 Bechtel(2022)によるこちらの動画では、確率密度が0の周りに集中するGumbel分布の形状について簡潔に説明しています。

class GumbelSoftmaxLayer(nn.Module):
    def __init__(self, temperature=10):
        super(CustomLayer, self).__init__()
        self.temperature = temperature # usually annealed
    
    def forward(self, x):
        # Sample Gumbel noise g ~ G(mu=0, beta=1)
        uniform = torch.rand(x.size())
        gumbel_noise = -torch.log(-torch.log(uniform))
 
        # Add to noise log prob
        log_prob = torch.log(x)
        log_prob += gumbel_noise
 
        # Use softmax when training and argmax or inference
        # assuming x of size (batch_size, num_latent, num_categories)
        if self.training:
            log_prob /= self.temperature
            return F.softmax(log_prob, dim=2)
        else
            return F.one_hot(torch.argmax(log_prob, dim=2)) 

Gumbel分布からサンプルノイズを追加することで、カテゴリが最大となる確率、つまりサンプルとして選ばれる確率を考慮に入れることができ、 効果的にカテゴリカル分布からのサンプリングをシミュレートします。この手法、Gumbelマックストリックは優れた再パラメータ化ですが、 それでも微分不可能なargmax操作があります。そこで、argmaxの代わりに温度パラメータτ\tauを持つソフトマックス関数を使用することができます(これまでは常にτ=1\tau = 1と仮定していました)。 τ\tauが0に近いほど、分布は鋭くなり、argmaxに近づきます。

τ\tauを高い値から0に徐々に調整(アニーリング)することで、最初は堅牢なトレーニングのために勾配を通過させ、 徐々に離散的に近い潜在表現に向かって作業することをモデルに学ばせることができます。 推論時には、離散的な潜在表現を生成するためにargmaxを使用できます。 トレーニング中にargmaxをソフトマックスと可変温度で置き換えるこのトリックは、Gumbelソフトマックストリックと呼ばれています。 上記はGumbelソフトマックストリックのPyTorch実装です。上記からわかるように、ノイズは0から1の間の一様分布からのサンプルに負の対数を2回適用することで、μ=0\mu = 0およびβ=1\beta = 1のGumbel分布からサンプリングできます。

VQ-VAE

もう一つの代替アプローチは、学習可能なコードブックまたは埋め込みを導入することです。コードブック内では、KK個の潜在ベクトルeie_iがあり、各ベクトルはDD個の要素を持ちます。 このアプローチでは、エンコーダーの出力をカテゴリカル分布ではなく最後の次元がDDのベクトルzez_eにして、ユークリッド距離に基づいて最も類似性の高い潜在ベクトルをコードブックから選択し、 それをデコーダーにzqz_qとして渡します。このプロセスはベクトル量子化(Vector Quantiation)と呼ばれ、このプロセスを通じて、選択されたeie_iのインデックスに対応するワンホットエンコードされたベクトルを、 離散的な潜在表現として得ることができます。

VQVAE

ここでは、ユークリッド距離に基づく決定論的な事後分布を使用し、(偏りのないトレーニングのため)トレーニング中の事前分布を離散一様分布と仮定します。 エンコーダ出力とコードブックが類似していると仮定でき、特にトレーニング後は両方に対してユークリッド距離を損失として使用するため、 デコーダーからエンコーダーへの勾配を自信を持って直接渡すことができます。

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost
 
    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        
        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)
        
        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
            
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        
        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss
        
        quantized = inputs + (quantized - inputs).detach()
        
        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous()

すべてのベクトルを選択する際一様な事前分布を仮定していますが、トレーニングデータを反映した各ベクトルを選択するための真の事前分布は一様ではなく、 順序依存性を持つと考えられます。(トレーニングデータに10個の数字がある場合、1つの潜在ベクトルを選択することは、デコーダーで生成する数字に影響を与え、 2番目の潜在ベクトルとその後のベクトルの選択分布に影響します。)

したがって、元の論文では、以前にサンプリングされたベクトルを条件とした真のカテゴリカル事前分布を出力する自己回帰モデル(RNNやTransformerなど)をトレーニングすることを提案しています(潜在表現のシーケンスにおける次のベクトル予測)。 このモデルにより、条件付き分布からシーケンスをサンプリングする祖先サンプリングを生成タスクに使用できます。 上記はVQ-VAEのベクトル量子化レイヤーのPyTorch実装例です。この実装では、関連するコンポーネントにアクセスできるモジュール内で量子化の損失を計算し、detachを使用して勾配を停止しています。

連続よりも離散を選ぶ理由

離散的な潜在表現を使用するための3つの技術を紹介しましたが、単純な連続的表現を使用できるのになぜ離散的表現を使用したいのかについては未だ議論していません。 離散表現を使用する明らかな利点はデータ圧縮です。連続表現では潜在次元ごとに1つの浮動小数点または32ビットを保存する必要がありますが、 離散表現では潜在次元ごとにlog2(K)log_2(K)ビット(KKはカテゴリの数)だけが必要で、これは多くの場合32ビットよりも小さいです(log2(1024)=10log_2(1024) = 10)。

このように離散的な潜在表現はより良いデータ圧縮を実現できますが、直感的には離散表現を使用すると潜在空間がKKに制限され、連続表現よりも性能が低下するように思えます。 しかし、DreamerV2は42のタスクでガウス潜在変数よりもカテゴリカル潜在変数でより良い結果を達成しています。元の論文では、 これはガウス潜在変数が:1.多峰性分布に適合できるカテゴリカル潜在変数とは異なり、単峰性分布を仮定しているため、 2.ストレートスルー推定器とは異なりϵ\epsilonによって不安定な勾配を悪化させる可能性があるため、 3.本質的に離散的な状態をモデル化するのに自然ではないためかもしれないと示唆しています。

2番目のポイントはGumbelソフトマックストリックには適用されませんが、 これら3つは表現力が限られていても離散潜在変数が一部のタスクで連続潜在変数よりも優れたパフォーマンスを発揮する理由として妥当な説明だと言えると私は考えます。 私たちは離散空間を持つ有限MDPや、本質的に離散的な自然言語やスピーチ(すでに量子化されトークン化されている)を扱うことが多いため、 それらを無理に連続空間にマッピングするよりも、離散的な潜在表現を使用する方が自然かもしれません。 さらに、VQ-VAEは下流タスクに適した対応する埋め込みを学習できるため、音声や画像などの連続データに対するトークン化のための自然なモデルとなり得ます(実際にトークン化のモデルとして利用されています)。

結論

この記事では、カテゴリカル潜在変数を生成できるストレートスルー勾配推定器、Gumbelソフトマックストリック、ベクトル量子化技術を紹介しました。 また、カテゴリカル潜在変数、離散的な潜在表現が場合によって非常に高いパフォーマンスを発揮し、好まれる理由についても議論しました。 これらの技術の詳細については、以下に引用されている元の論文や補足資料を確認することをお勧めします。

リソース