このブログ記事では、ディープラーニングにおけるVAEについて紹介します。

変分オートエンコーダ(VAE)
前回の記事では、画像が潜在空間にどのように配置されているかが未知であることと、そしてそれが生成目的でベクトルを 選択しデコーダーに渡すことを妨げていることについて説明しました。この問題に対する解決策の一つは、 潜在空間を標準正規分布 () に圧縮することで、標準正規分布から任意のベクトルをサンプリングして画像を生成できるようにすることです。 このアプローチは 変分オートエンコーダ(VAE) と呼ばれ、以下の図のような構造を持ちます。

VAEでは、エンコーダが画像が属する潜在空間内の正規分布の平均と分散を出力します。 エンコードされた正規分布からランダムなサンプルが選ばれ、デコーダがそのサンプルを元に画像を再構築します。 逆伝播を使用して、デコーダが標準正規分布からの任意のサンプルから画像を生成できるように強化され、 同時にエンコーダが画像を標準正規分布内の正規分布にエンコードできるように訓練されます。 ここでは、目標分布(標準正規分布)は事前分布(prior)と呼ばれ、エンコーダによって生成される分布は事後分布(posterior)と呼ばれます。
再パラメータ化トリック
このアプローチは概念的には理解できますが、サンプリングのプロセスの勾配を計算することは実際には不可能です。 そのため、勾配が計算可能な式を使ってサンプリングのプロセスをシミュレートする方法が必要です。 それが 再パラメータ化トリック です。実際に正規分布からサンプルを選ぶ代わりに、次の式を使用します。
ここで、 はサンプル、 と はエンコーダが生成する平均と標準偏差、 は標準正規分布からのランダムサンプル () です。 の代わりに上記の関数 を使用することで、 "サンプリング" の勾配 と が と であることを決定できます。 これにより、損失に対する勾配がエンコーダに逆伝播されるようになります。
ログ分散トリック
エンコーダは現在、正規分布の分散を出力するように設定されていますが、分散は正の値である必要があります。 エンコーダが負の値も生成できるようにするために、分散の代わりに分散の対数を出力させます。これにより、 これまで行っていたように単に分散の平方根を取って標準偏差を得ることはできません。対数分散から標準偏差を得るには、 次のように解きます。
これにより、対数分散の式を使って で標準偏差を取得できることがわかります。 したがって、関数 を次のように書き換えることができます。
この分散の対数を使用するトリックを ログ分散トリック と呼びます。
損失関数
通常のオートエンコーダとは異なり、事後分布が事前分布にできるだけ近づくことを望みます。 そのためには、MSE損失とは別に、事前分布から遠く離れた事後分布に対してペナルティを課す追加の損失が必要です。 この記事MLエンジニアへの道 #2 - ロジスティック回帰の前提知識を思い出してみると、 2つの分布間の距離を測定するための1つのメトリックとして、KLダイバージェンスを使用できることがわかります。
ここで、 はデコーダによる画像の再構成に対するMSE損失、 は と との間のKLダイバージェンスです。 は、MSEとKLの相対的な重要性を指定するハイパーパラメータです。 と の間のKLダイバージェンスは、 として導き出せます。これは、 と がそれぞれ0と1のときにこの値が最小(ゼロ)になるため、理にかなっています。 ( が1のとき、対数分散が0になり、左側の1が打ち消されます。導出の詳細に興味がある場合は、このページを参照してください。)
コードの実装
VAEを構築するためのすべての要素が揃ったので、実際にコードで実装してみましょう。MNISTデータセットを使用し、 前回の記事でデータの準備と前処理のステップ1と2はすでに行ったので省略します。
ステップ3. モデル
以下は、PyTorchとTensorFlowを使ったVAEの実装例です。
ステップ 4. モデル評価
トレーニング後、VAEがどれだけうまく画像を再構築できるかを確認するため、テストデータセットと予測されたテストデータをプロットしてみましょう。 そのためには、VAEで予測を行い、テストデータと予測されたテストデータの両方を適切な形にする必要があります。
# TensorFlow
preds = vae.predict(X_test)
_, preds = tf.split(preds, [20, 784], 1)
preds = preds.numpy()
X_test = X_test.reshape(X_test.shape[0], 28, 28)
preds = preds.reshape(preds.shape[0], 28, 28)
# PyTorch
for X, y in test_loader:
Xs = X
_, _, preds = vae(X)
Xs = Xs.numpy()
preds = preds.detach().numpy()
Xs = Xs.reshape(Xs.shape[0], 28, 28)
preds = preds.reshape(preds.shape[0], 28, 28)
次に、以下の関数を使用して、それぞれの画像から10枚をプロットできます。
def plotImgs (X):
plt.figure(figsize=(10, 4))
for i in range(10):
plt.subplot(2, 5, i + 1)
plt.imshow(X[i], cmap='gray')
plt.axis('off')
plt.tight_layout()
plt.show()
plotImgs(X_test)
plotImgs(preds)
以下は、PyTorchで実装されたVAEの結果です。

VAEが画像をうまくエンコードし、デコードできるようになったことがわかります。では、標準正規分布からランダムにサンプルを取って、デコーダーが画像を生成できるかどうかを見てみましょう。
# TensorFlow
latent = np.random.normal(0, 1, (10, 10))
decoded = decoder.predict(latent)
decoded = decoded.reshape(decoded.shape[0], 28, 28)
# PyTorch
latent = torch.randn(10, 10)
decoded = vae.decoder(latent)
decoded = decoded.detach().numpy()
decoded = decoded.reshape(decoded.shape[0], 28, 28)
plotImgs(decoded)
以下は、TensorFlowで実装されたVAEデコーダーの結果です。

画像はまだ少しぼやけていますが、手書きの数字を認識することができ、前回の記事で通常のオートエンコーダーのデコーダーを使用して生成した画像とは異なることがわかります。
結論
エンコーダーが標準正規分布(事前分布)の正規分布(事後分布)を出力することで、オートエンコーダーを生成目的で使用できる変分オートエンコーダー(VAE)に変換することができました。 しかし、実際にモデルを訓練してみた方は、データセットが小さいにも関わらず、訓練にかなりの時間がかかり、不安定になることがあると気づいたかもしれません。 そのため、モデルにいくつかの改善が必要です。(ヒントとして、VAEは通常のDense層に限定されるわけではありません。)
さらに、生成目的で使用できるのはVAEだけではありません。生成モデルを設計するには、もっと様々な方法もありますので、 他の方法も考えてみることをお勧めします。(次の記事でそのうちの一つを紹介します。)
リソース
- Raschka, S. 2021. L17.3 The Log-Var Trick. YouTube.
- Raschka, S. 2021. L17.4 Variational Autoencoder Loss Function. YouTube.
- Raschka, S. 2021. L17.5 A Variational Autoencoder for Handwritten Digits in PyTorch -- Code Example. YouTube.