2017/11/26

Coulomb GAN

25個のガウス分布で生成した2次元の点の分布の学習が普通のGANでは困難だったので、Coulomb GAN [1]を使ってみました。

学習したい分布

Geometric GAN [2]の実験に使われている混合ガウス分布を使います。 25個のガウス分布を縦横それぞれに-21, -10.5, 0, 10.5, 21の位置の組合せで配置します。分散は0.1です。 下図のような分布です。

25000点サンプリングして描画しています。 点が集まっている部分の周りの楕円は、25個の混合ガウスモデル(GMM)でフィットさせることで求めています。

この分布に従ったサンプルを生成する生成器\(G\)を学習します。

普通のGAN

まずは、普通のGANで試します。生成器\(G\)のネットワークはKerasのコードで示すと、

def generator_model_elu():
    model = Sequential()
    model.add(Dense(128, input_dim=4))
    model.add(ELU())
    model.add(Dense(128))
    model.add(ELU())
    model.add(Dense(128))
    model.add(ELU())
    model.add(Dense(2))
    return model
です。4つの乱数を入力すると、2次元空間での点の座標を出力します。識別器\(D\)は
def discriminator_model_elu():
    model = Sequential()
    model.add(Dense(128, input_dim=2))
    model.add(ELU())
    model.add(Dense(128))
    model.add(ELU())
    model.add(Dense(128))
    model.add(ELU())
    model.add(Dense(1))
    return model
です。2次元空間での点の座標を入力として、真の分布から生成された点なら1、\(G\)が生成した点なら0を出力します。

どちらのオプティマイザにもAdamを使います。また、Lossはbinary_crossentropyで計算します。学習させた結果が以下です。 失敗集です。網羅的に試しているわけではありません。

ParameterValueImages
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
512
0.5
0.0001
0.001
0.5
0.0001
0.001
129エポックまで実行しましたが、同じ状況で変わりません。下の図は13エポック目の図です。
同パラメータで再実行してみましたが、下の図のように形状は変わるものの進みません。
さらに、同パラメータで再実行してみましたが、下の図のように形状は変わるもののやはり進みません。
さらに、同パラメータで再実行してみましたが、下の図のように形状は変わるものの進みません。
さらに、同パラメータで再実行してみましたが、下の図のように形状は変わるものの進みません。
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
128
0.5
0.000025
0.001
0.5
0.000025
0.001
広がったものの、形状に変化がみられないので中断。中断したエポックでの結果が以下です。
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
128
0.5
0.001
0.0005
0.5
0.001
0.0005
下の図は5エポック目のものですが、296エポックまで実行しても状況は変わりません。
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
128
0.5
0.000025
0.001
0.5
0.000025
0.001
変化しないので、中断。中断したエポックでの結果が以下です。
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
512
0.5
0.00001
0.01
0.5
0.00001
0.01
これも変化しないので、中断。中断したエポックでの結果が以下です。
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
512
0.5
0.00001
0.0001
0.5
0.00001
0.0001
20エポック目くらいまでは分布が広がるものの、
最終的には、中心に集まってしまいます。残念。

全く同じパラメータで再実行してみましたが、おかしな所に点が集まり学習が進まなくなりました。

batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
512
0.5
0.000001
0.00001
0.5
0.000001
0.00001
棒がくるくる回っているのですが、時間がかかりそうなので中断。中断したエポックでの結果が以下です。
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
512
0.1
0.00001
0.00001
0.5
0.00001
0.0002
途中までは比較的順調に進んだ結果が以下です。途中までは順調でしたが、途中から\(G\)が踊り始めます。クリックするとAPNGによるアニメーションが表示されます(32MB)。999エポックまで実行しました。
discは識別器のAdamのパラメータ、genは生成器のAdamのパラメータを表します。

この他にも色々な値で試しましたが、残念ながら再現したい25個のガウス分布っぽいものに行き着くことができませんでした。

Coulomb GAN

というわけで、Coulomb GAN [1]です。 各点を電荷に見立てることで、間違った分布を生成器が学習することを防ぎます。 真の分布から得られた点を正の電荷(+)、生成器から得られた点を負の電荷(-)として、 負の電荷に対するポテンシャルを考えると、下図のようになります。

電荷(サンプル点)の位置 マイナス電荷に対するポテンシャル

このポテンシャルをそのままLossとして生成器を学習します。 そうすると、ポテンシャルの低いほうへ負の電荷が移動する方向に生成器が表す分布が変形していきます。 言い換えると、負の電荷、すなわち、生成器が生成した偽の点\(\bm{a}^{\rm fake}\)同士は離れる方向に、かつ、真の分布から得られた真の点\(\bm{a}^{\rm real}\)には近づくように学習することになります。識別器\(D\)はCoulomb GANにおいては偽の点と真の点を識別するように学習するわけではなく、代わりにポテンシャルの分布を学習していきます。したがって、\(D\)のLossは \[ \mathcal{L}_D \propto \sum_i \left( D(\bm{a}_i) - \Phi(\bm{a}_i)\right)^2 \] となります。\(\Phi(\bm{a}_i)\)は点\(i\)の位置\(\bm{a}_i\)でのポテンシャルです。\(\bm{a}_i\)は偽の点と真の点の両方の位置を表しますが、計算するポテンシャルは偽の点としてのポテンシャルです。ポテンシャルの分布を直接学習できるわけではありませんが、サンプル点におけるポテンシャルの値を模擬できるように\(D\)を学習することで任意の点のポテンシャルの値を予測できるようになります。このようにポテンシャルの形状に近い形を学習した\(D\)を使うと、任意の点の勾配を計算できるようになります。

\(D\)を使うと、生成器\(G\)のLossを \[ \mathcal{L}_G \propto \sum_i D(G(z_i)) \] と計算できます。ポテンシャルが低くなる方向に偽の点が生成されやすくなるように\(G\)が学習されます。

真の点が作り出す偽の点に対するポテンシャルは次の式で計算します。 \[ \phi(r) = \frac{-1}{(r^2+\varepsilon^2)^{(d-2)/2}} \] 偽の点が作り出す偽の点に対するポテンシャルは符号が逆になるだけです。ここで、\(r\)は、ある点\(a_i\)からのユークリッド距離、\(d\)は次元数で、今回は3のみを使います。\(\varepsilon\)はソフトニングのためのパラメータです(Plummer radiusとも言う)。グラフの形状は下図のようになります。

Gnuplot Produced by GNUPLOT 5.0 patchlevel 7 -4 -3.5 -3 -2.5 -2 -1.5 -1 -0.5 0 -30 -20 -10 0 10 20 30 Potential x ε=0.25 ε=0.25 ε=0.5 ε=0.5 ε=1 ε=1 ε=2 ε=2 ε=4 ε=4

バッチサイズを\(N_B\)、真の点を\(\bm{a}^{\rm real}_i\)、偽の点を\(\bm{a}^{\rm fake}_i\)とすると、あるミニバッチで計算される、位置\(\bm{x}\)でのポテンシャルは \[ \Phi(\bm{x}) = \sum_i^{N_B} \phi(|\bm{a}_i^{\rm real}-\bm{x}|) - \sum_i^{N_B} \phi(|\bm{a}_i^{\rm fake}-\bm{x}|)\] となります。

[1]では、\(\varepsilon\)を指数関数的に減少させているようなので、それに倣って減少させるようにしました。減少する速さを想像しやすいように、半減期(Half life)の式で計算することとしました。単位はイテレーションです。半減期が5000なら、5000イテレーション後の\(\varepsilon\)が半分になります。

また、上記の\(\Phi(\bm{x})\)の定義では、ポテンシャルを発生させている点自身のポテンシャル(\(r=0\)のときの値)を計算してしまいます。その結果、\(\varepsilon\)が小さくなるにつれ、各点でのポテンシャルの絶対値が \[ \phi(t) = \frac{-1}{((\varepsilon_0 \times 2^{\frac{-t}{t_{\rm hl}}})^2)^{(d-2)/2}} \] となります。d=3とすれば、 \[ \phi(t) = \frac{-1}{\varepsilon_0} \times 2^{\frac{t}{t_{\rm hl}}} \] です。半減期\(t_{\rm hl}\)=1000、\(\varepsilon\)の初期値\(\varepsilon_0\)=1、1エポックが200イテレーション、100エポック処理すると仮定すると、\(\phi(20000)\sim -1\times 10^6\)となります。\(D\)がポテンシャルの形状を全く追随できていないとすると、MSEは\(\sim 1\times 10^{12}\)になります。floatの上限は\(3.402823\times 10^{+38}\)なので、まだ余裕がありますが、最終的には、\(D\)のLossはinfやNaNになります。

これを避けるため、以下の実験では、自分自身のポテンシャルを計算から除くようにしました。つまり、 \[ \Phi(\bm{x}) = \sum_{i,a_i^{\rm real} \neq x}^{N_B} \phi(|\bm{a}_i^{\rm real}-\bm{x}|) - \sum_{i,a_i^{\rm fake} \neq x}^{N_B} \phi(|\bm{a}_i^{\rm fake}-\bm{x}|) \] で計算します。

実験 (バッチサイズ512)

それでは、Coulomb GANで25個のガウス分布を学習してみます。

各種ハイパーパラメータの値は下表の通りです。ニューラルネットワークは普通のGANのときと同じものを使います。\(D\)の出力は、\(D\)に入力された位置のポテンシャルで、それは1次元の値なので使いまわせます。

Parameter nameValue
Samples per gaussian4000 (Total 100k)
Batch size512
\(d\)3
Adam lr for \(D\) and \(G\)0.001
Adam \(\beta_1\) for \(D\) and \(G\)0.5
Adam decay for \(D\) and \(G\)0.0001
\(\varepsilon\)3.0
Half life of \(\varepsilon\)5000 iteration
[1]ではバッチサイズを128にして実験してうまく学習できているようですが、今回の実験では学習に失敗することが多かったため、512に増やしています。ポテンシャルの計算をするときに使う真の点と偽の点の数はバッチサイズと等しいため、バッチサイズが小さいとポテンシャルの計算が雑になり、うまく\(D\)を学習できないのでしょう。ガウス分布が25個あるので、128だと真の点が平均で~5個しかなく、運悪く、ある分布に対する点が0個になってしまうことが起きていそうです。

500エポック後の結果は以下の通りです。まず、500エポック目の\(G\)で25000点を生成したときの分布が下の図です。

さらに、上の図にガウシアン25個のGMM(Gaussian Mixture Model)をフィットさせた結果が下の図です。全体にフィットしているガウス分布が1個あるため、背景が灰色になっています。ちょうど(-10,-10)の位置に集まっている点にフィットできていません。見た目ではまあまあ元の分布に近くなっています。
次の図は、500エポック目の\(D\)が推定するポテンシャル分布を表しています。真の点と偽の点がほぼ同じ位置に分布しているため、互いのポテンシャルが相殺し、大きな値は見られません。

一方、170エポック目では下の図のようにまだ(0, 0)と(10, 0)の位置に点が集まっていません。

このときの\(D\)が推定するポテンシャル分布(下図)を見てみると、その位置のポテンシャルの値が小さく(赤く)なっていることが分かります。真の点しかないので、ポテンシャルがその位置で小さくなります。他に小さくなっている部分をみてみると、(-20, -10)と(0, 20), (20, 0)があります。よーく上の図の分布を見てみると、その位置には点があるものの他の部分と違い点が集中していません。

このように、偽の点がない部分のポテンシャルが小さくなるので、ポテンシャルの小さい位置に偽の点が生成されるように\(G\)が学習されていきます。

下のAPNGによるアニメーションは、\(G\)の分布の発展の様子です。クリックすると再生されます(約34MB)。

少しずつ、抜けている部分に点が集まっていく様子が見て取れます。

実験 (バッチサイズ128)

バッチサイズ128、\(\varepsilon\)の半減期20000、Adam decay \(2.5\times 10^{-5}\)で試してみたときの、500エポック目の\(G\)が生成した点が下の図になります。イテレーションの回数が4倍に増えるので、半減期とAdam decayを調整していますが、それ以外は同じです。

中央(0, 0)の分布を再現できていないことが分かります。\(D\)が作るポテンシャルの分布は下の図のようになっていて、
見て分かるとおり、中央のポテンシャルは十分低いものの、その周辺にポテンシャルの壁がそびえ立っているため、中に点が入っていけない状況に陥っています。このケースでは70エポック目あたりから中央の周囲にポテンシャルの壁が出来始めていました。

というわけで、

今回の実験では、何かの違いにより[1]に書かれているハイパーパラメータではCoulomb GANによる学習を進めることができず、色々ハイパーパラメータを探すことになりましたが、見つかってしまえば何度実行しても学習できていました。ガウス分布の学習に関しては普通のGANより学習しやすいことは確かなようです。普通のGANよりバリエーションが豊かであることを確認してみたいのですが、まだまだ先は長そうです。

コード

今回の実験で使ったコードは以下の通りです。 Coulomb GANの著者らのTensorFlowベースのコード[1]を参考にしつつ実装しています。

importしているモジュールはここに置いています。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# -*- coding: utf-8 -*-
# Coulomb GAN
import sys, os, math
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../keras-examples/src')
import numpy as np
from sklearn import mixture
from util.history import ExperimentHistory
from plot.gmm_plot import (plot_points_and_gmm, plot_potential,
                           plot_real_or_fake, plot_generator_distribution)
from gan.models_2dpoints import gan_model_relu, gan_model_elu
from gan.coulomb import CoulombPotentials
from gan.gaussian_mixture.datagen import gen_2D_samples_from_5x5gm, RandomSampler
from keras.models import Sequential
from keras.optimizers import Adam
import keras.backend as K

GENERATED_IMAGE_PATH = 'coulomb_images/'
SAMPLES_PER_GAUSSIAN = 4000
BATCH_SIZE = 128*4
NUM_EPOCH = 1000

def raw_loss(y_true, y_pred):
    return K.mean(y_pred, axis=-1)

class DiscriminatorLabelNormal:
    def __call__(self, points_real, points_fake):
        assert len(points_real) == BATCH_SIZE
        assert len(points_fake) == BATCH_SIZE
        return [1]*len(points_real) + [0]*len(points_fake)

    def get_eps(self):
        return 0

class DiscriminatorLableCoulomb:
    def __init__(self, eh):
        self.total_num_batches = 0
        self.cp = CoulombPotentials(eh.plummer_kernel_dim, eh.plummer_kernel_eps,
                                    eh.plummer_kernel_ignore_self_potential)
        self.eps_half_life = eh.plummer_kernel_eps_half_life

    def __call__(self, points_real, points_fake):
        self.cp.eps = eh.plummer_kernel_eps * math.pow(2, -self.total_num_batches/self.eps_half_life)
        self.total_num_batches += 1
        potential_real, potential_fake = self.cp(points_real, points_fake)
        return np.concatenate((potential_real, potential_fake))

    def get_eps(self):
        return self.cp.eps

def train(eh):
    # Generate training data and draw points
    X_train = gen_2D_samples_from_5x5gm(SAMPLES_PER_GAUSSIAN)
    gmm = mixture.GaussianMixture(n_components=25, covariance_type='full').fit(X_train)
    plot_points_and_gmm(X_train, gmm.predict(X_train), gmm.means_, gmm.covariances_,
                        'Gaussian Mixture', GENERATED_IMAGE_PATH+"gmm.png")

    # Random sampler for each batch and for plotting points of G
    rs = RandomSampler(4, "normal" if eh.random_with_normal_dist else "uniform")

    # Make an object for making correct labels of D
    dl = DiscriminatorLableCoulomb(eh) if eh.coulomb_gan else DiscriminatorLabelNormal()

    # Make generator G and discriminator D
    generator, discriminator = gan_model_elu()
    d_opt = Adam(lr=eh.disc_Adam_lr, beta_1=eh.disc_Adam_beta_1, decay=eh.disc_Adam_decay)
    g_opt = Adam(lr=eh.gen_Adam_lr,  beta_1=eh.gen_Adam_beta_1,  decay=eh.gen_Adam_decay)
    discriminator.compile(loss='mse' if eh.coulomb_gan else 'binary_crossentropy', optimizer=d_opt)
    discriminator.trainable = False
    cgan = Sequential([generator, discriminator]) # G+D with fixed weights of D
    cgan.compile(loss=raw_loss if eh.coulomb_gan else 'binary_crossentropy', optimizer=g_opt)

    num_batches = int(X_train.shape[0] / BATCH_SIZE)
    print('Number of batches:', num_batches)
    eh.write(GENERATED_IMAGE_PATH+"history.log", {"Generator":generator, "Discriminator":discriminator},
             {"Generator opt":g_opt, "Discriminator opt":d_opt}, None)
    for epoch in range(NUM_EPOCH):
        if eh.X_train_is_shuffled:
            np.random.shuffle(X_train)
        for index in range(num_batches):
            noise = np.array(rs(BATCH_SIZE))
            points_real = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            points_fake = generator.predict(noise, verbose=0)

            # Update discriminator
            X = np.concatenate((points_real, points_fake))
            Y = dl(points_real, points_fake)
            d_loss = discriminator.train_on_batch(X, Y)

            # Update generator
            noise = np.array(rs(BATCH_SIZE))
            g_loss = cgan.train_on_batch(noise, [1]*BATCH_SIZE) # labels are ignored if Coulomb GAN
            print("epoch: %d, batch: %d, g_loss: %e, d_loss: %e, eps: %e" %
                  (epoch, index, g_loss, d_loss, dl.get_eps()))

        plot_generator_distribution(generator, epoch, GENERATED_IMAGE_PATH, rs)
        if eh.coulomb_gan:
            plot_potential(discriminator, "Epoch={0:d}".format(epoch),
                           GENERATED_IMAGE_PATH+"p{0:03d}".format(epoch))
        else:
            plot_real_or_fake(discriminator, "Epoch={0:d}".format(epoch),
                              GENERATED_IMAGE_PATH+"p{0:03d}".format(epoch))

if __name__ == '__main__':
    eh = ExperimentHistory()
    eh.batch_size = BATCH_SIZE
    eh.samples_per_gaussian = SAMPLES_PER_GAUSSIAN
    eh.random_with_normal_dist = False
    eh.X_train_is_shuffled = True
    eh.plummer_kernel_dim = 3.0
    eh.plummer_kernel_eps = 3.0
    eh.plummer_kernel_eps_half_life = 5000.0
    eh.plummer_kernel_ignore_self_potential = True
    eh.disc_Adam_decay = 1e-4
    eh.disc_Adam_lr = 1e-3
    eh.disc_Adam_beta_1 = 0.5
    eh.gen_Adam_decay = eh.disc_Adam_decay
    eh.gen_Adam_lr = eh.disc_Adam_lr
    eh.gen_Adam_beta_1 = eh.disc_Adam_beta_1
    eh.coulomb_gan = True
    if not eh.coulomb_gan:
        GENERATED_IMAGE_PATH = 'generated_images/'
    if not os.path.exists(GENERATED_IMAGE_PATH):
        os.mkdir(GENERATED_IMAGE_PATH)
    train(eh)
    print("finish")