GANの原論文をきちんと読む(GAN解説)

深層学習の画像生成タスクにおいて現在の主流であるGANについてざっくりと理解していているが、元論文Generative Adversarial Netsをきちんと読む機会がなかったので今回読んでまとめた。

 

大まかにはGANは画像を生成する生成器G(Generative model) と生成された画像が本物なのか生成された画像なのか判断する識別器D(Discriminative model)を同時に訓練する。お互いに訓練を進めていくと識別機Dは本物の画像なのか生成器によって作られた偽物の画像なのかの判断精度が上昇していく。同時に識別器が偽物だと判断できないような画像を生成するように生成器Gも学習していく。例えとして偽札づくりの技術とそれを取り締まる技術のいたちごっごがよく挙げられる。

 

  • 敵対的ネット

GANの根幹的な技術はGANの文字が示すようにAdversarial Nets(敵対的ネット)である。先程述べた生成器Gと識別器DはどちらもMLP(Multiple Layer Perceptron)であることが前提条件としてあげられている。

論文中に出てくる変数の定義を以下に示す。

 

x:入力データ

pg:データxに対する生成器Gの分布

pz(z):入力ノイズ変数

G(z:θg):生成器。zに対するデータ空間への写像を表す、Gはパラメータθgを有するMLPによって表される微分可能関数

D(x;θd):識別器。出力として単スカラー値を返す。学習するときは本物の画像を1,偽物を0とラベル付けして学習する。

 

学習は以下の関数V(G,D)にミニマックス法を用いて行う。

f:id:kuzika:20190306175530p:plain

ちなみにX〜p_data(x)はxが確率分布関数p_data(x)に従って分布していることを表している。

まずはGを固定しV(G,D)の値を最大化するときを考える。第一項は生成器Gは関与していないためDに依存し、Dは訓練用の入力データであるxに依存する。logD(x)を最大化するにはD(x)を最大化、つまり本物の画像であることを表すラベル1になるようにする。第二項では生成画像G(z)の識別を偽物のラベルを表す0に近づけることでlog(1-D(G(z)))は最大化される。つまり第一項では本物が本物かどうか、第二項では偽物が偽物かどうか判断する役割を持つ。

 

次にDを固定し、V(G,D)の値を最小化することを考える。

第一項ではDが固定値なので無視する。第二項ではGが生成した画像を本物だと誤った判断をされるように学習すれば 

log(1-D(G(z)))は最小化される。

 

上記のイメージを表したのが下図である。

f:id:kuzika:20190306181953p:plain

黒点線:訓練データxの分布px

緑実線:生成器Gの分布pg(G)

青破線:識別器Dの分布

 (a)の図ではデータxと生成器Gによって生成された画像データの分布がずれていることがわかる。識別器は部分的に正確に分類できていることがわかる。(識別器が画像が本物らしいと判断したら、その本物度合いに応じて青破線の縦軸の値が大きくなっている)

(b)の図では識別器が最適化されている。このときの識別器Dを更新する計算式はD^*(x)=\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}で表される。つまり本物偽物2つの分布が完全にかぶるとDの値は1/2になる。見分けがつかないのでこの値は妥当である。

(c)の図では緑実線のGの分布が更新され、黒点線xの分布に近づいている。

(d)では以上の操作が何回か繰り返され、理想的には最終的に図のように完全に分布が一致する。識別器も見分けがつかないので1/2の一様分布になる。

 

全体のアルゴリズムの流れとしては、k回識別器Dを学習した後に1回だけ生成器Gを学習するというループを繰り返している。論文でもあまりGを訓練する回数を増やすと良い結果は出ないといっている。

 

次にこれまでに示したアルゴリズムと式を再構築してまとめている。

 

f:id:kuzika:20190307003630p:plain

このように先程の識別器の更新の式を最初のminimaxの式に代入している。

 

この再構築したC(G)の最小値は、p_dataとpgが一致したときで-log4を取ることを示している。

 

これは代入してみれば一目瞭然で

logの中身がどちらも1/2になるのでlog(1/2)=-log2なので2つを足して-log4になる。

この理論上の最低値をC(G)の式から引くとKLダイバージェンスを用いて次のように表せる。

 

f:id:kuzika:20190307004556p:plain

KLダイバージェンスはこの式の第二項を例にあげれば、pdataと(pdata+px)/2の一致度を示すもので、完全に値が一致している場合0になる。

 

この2つをJSダイバージェンスによって式を更に簡略化している。

f:id:kuzika:20190307004838p:plain

 

いろいろ理解が間違っているところもあるかもしれませんがその場合はコメントで教えてくれると幸いです。

 

次回あたりはGANの実装をまとめたらいいな。