MLエンジニアへの道 #19 - GAN

Last Edited: 9/14/2024

このブログ記事では、ディープラーニングにおけるGANについて紹介します。

ML

Generative Adversarial Networks

前回の記事では、オートエンコーダを修正して画像を生成できるようにした変分オートエンコーダ (VAE) について説明しました。 今回は、興味深い代替アプローチである生成的敵対ネットワーク (GANs) についてお話しします。その名前が示すように、 GANは以下のように敵対者(敵)を使って訓練される生成モデルです。

GAN

生成器(ジェネレーター)は、識別器(ディスクリミネーター)モデルを欺くほどリアルな画像を生成するように訓練されます。 一方、識別器は生成器が生成した画像かどうかを区別できるように訓練されます。トレーニングを通じてお互いに競い合った結果、 最終的に生成器は、人間の目で見ても本物と区別できないほどリアルな画像を生成することが期待されます。

コードの実装

GANsの利点の一つは、アーキテクチャが直感的で理解しやすく、数学的な知識があまり必要ないことです。そのため、 コードの実装にすぐ取り掛かることができます。また、データの前処理もあまり必要ありません。正規化された画像を用意すれば、 生成される画像や0と1でラベル付けされたテンソルは、その場で生成できるためです。 (MNISTを引き続き使用するのでステップ1と2は省略します。)

ステップ 3. モデル

以下は、PyTorchとTensorFlowを使ったGANの実装例です。

ステップ 4. モデル評価

モデルをトレーニングした後、GANから生成器を取り出し、適切なサイズのノイズを入力して新しい画像を生成できます。 PyTorch で実装された GAN を50エポック訓練した後、どのような画像が生成されるか見てみましょう。

# Generate
latent = torch.randn(10, 16)
generated = gan.generator(latent)
 
# Preprocess
generated = generated.detach().numpy()
generated = generated.reshape(generated.shape[0], 28, 28)
 
plt.figure(figsize=(10, 4))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.imshow(generated[i], cmap='gray')
    plt.axis('off')
plt.tight_layout()
plt.show()
ML

すでに、モデルが手書きの数字に似た線を描くことを学び始めているのがわかります。GANの真の可能性を理解するために、 ぜひ自分でモデルをよりトレーニングしてみることをお勧めします。元のGAN論文によれば、 GANはサンプリングなしでVAEよりも鮮明な画像を生成できる一方で、敵対するモデルを同期して訓練するのは一般的に難しいとされています。

結論

この2つの記事で、VAEとGANという2つの生成モデルについて説明し、それぞれの課題に触れました。共通の問題として浮かび上がったのが、 非常に小さな画像の小さなデータセットでトレーニングしているのに、訓練が非常に遅くなるという点です。 もし1024x1024ピクセルのリアルなRGB画像を使用するとなると、これらのモデルを訓練するのはほぼ不可能になります。 したがって、これらの大きな画像をより効率的に処理できる異なる層を探す必要があります。その詳細は次回の記事で取り上げます。

リソース