2017/12/03

WGAN

はじめに

Wasserstein GAN (WGAN) [1]を試してみました。 詳細は色々なサイト[2, 3, 4]で解説されているので省略します。

理論的背景はともかく、実装の手間は、普通のGANと前回試したCoulomb GANを足して2で割った程度なので、試すこと自体は難しくありません。

比較

普通のGANとWGAN、Coulomb GANには識別器\(D\)のLossの計算について以下のような違いがあります。

  • 普通のGAN(Standard GANとかRegular GAN、Vanilla GANとか呼ばれています)

    真の点について1を、偽の点について0を出力するように学習します。Lossは交差エントロピーです。\(D\)が出力する値の分布についての制限はなにもありません。

    f r サンプル点の位置 f r 1 0 の出力 D
  • WGAN

    真の点の値を小さく、偽の点の値を大きくするようにLossを計算します。つまり、真の点の値から偽の点の値を引いた値(下図では\(a-b\))をLossとします。\(D\)の重みに対してクリッピングを行うので、その出力の分布には制限がつきます。

    f r サンプル点の位置 f r a の出力 D b
  • Coulomb GAN

    真の点を正の電荷、偽の点を負の電荷として電位を計算し、電位と\(D\)の出力との二乗誤差をLossとします。

    電荷(サンプル点)の位置 マイナス電荷に対するポテンシャル

実験

失敗例

WGANでもハイパーパラメータを間違えると分布を模擬できないという例です。

ParameterValueImages
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
\(n_{\rm critic}\)
clip of weights of \(D\)
512
0.5
0.0001
0.0001
0.5
0.0001
0.0001
1
0.1
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
\(n_{\rm critic}\)
clip of weights of \(D\)
512
0.5
0.0001
0.0001
0.5
0.0001
0.0001
1
0.01
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
\(n_{\rm critic}\)
clip of weights of \(D\)
512
0.5
0.0001
0.01
0.5
0.0001
0.01
1
0.01
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
\(n_{\rm critic}\)
clip of weights of \(D\)
512
0.5
0.0001
0.01
0.5
0.0001
0.01
1
0.1
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
\(n_{\rm critic}\)
clip of weights of \(D\)
512
0.5
0.0001
0.001
0.5
0.0001
0.001
1
0.1
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
\(n_{\rm critic}\)
clip of weights of \(D\)
512
0.5
0.0001
0.001
0.5
0.0001
0.001
5
0.1
このあと、画面外に\(G\)の分布が移動してしまいました。

成功例

どうにか成功したパラメータが以下です。

ParameterValue
batch_size
disc \(\beta_1\)
disc decay
disc lr
gen \(\beta_1\)
gen decay
gen lr
\(n_{\rm critic}\)
clip of weights of \(D\)
512
0.5
0.00001
0.0001
0.5
0.00001
0.0001
1
0.1
結果が下図です。左側は100エポック目から100エポックごとの\(G\)の分布を表しています。右側は\(D\)の分布を表しています。最後だけ、999エポック目です。
\(G\)については、綺麗に当てはまっているとはいえませんが、再現したい25個のガウシアンそれぞれの中心付近を通過するようには分布しています。また、\(D\)については、中央から放射状に等値線が伸びていることがわかります。おおよそ谷の部分に\(G\)が分布しています。 \(G\)の発展の仕方がCoulomb GANとは異なるところが面白いところです。

クリッピングの省略

識別器\(D\)の重みのクリッピングを省略してみました。省略するとWGANではなくなっていますが、なくてもGANとしては機能しそうに思えます。クリッピング以外のハイパーパラメータを上記の成功例と同じ値に設定しています。 1回目は失敗し、

のようになりました。

2回目は成功し、

となりました。\(D\)の分布を見てみると、クリッピングがないので、学習が進むにつれて値が大きくなっていきます。画像の数字はつぶれて読みにくいので、等値線の最小値と最大値を下表にまとめておきました。
EpochMinMax
1000600
200-40002000
300-1600012000
400-3000040000
500-6000060000
600-10000050000
700-12000090000
800-160000120000
900-200000100000
999-180000120000

まとめ

WGANでも25個のガウス分布を再現する\(G\)をある程度は学習できることを確認できました。また、クリッピングがなくても、学習できることがあることも確認できました。

コード

実験に使ったコードは以下の通りです。importしているコードは、ここに置いています。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
# -*- coding: utf-8 -*-
# Wasserstein GAN
import sys, os, math
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../keras-examples/src')
import numpy as np
from sklearn import mixture
from util.history import ExperimentHistory
from plot.gmm_plot import (plot_points_and_gmm, plot_potential,
                           plot_real_or_fake, plot_generator_distribution)
from gan.models_2dpoints import gan_model_relu, gan_model_elu
from gan.gaussian_mixture.datagen import gen_2D_samples_from_5x5gm, RandomSampler
from keras.models import Sequential
from keras.optimizers import Adam
import keras.backend as K

GENERATED_IMAGE_PATH = 'wgan_images/'
SAMPLES_PER_GAUSSIAN = 4000
BATCH_SIZE = 128*4
NUM_EPOCH = 1000

def mult_loss(y_true, y_pred):
    return K.mean(y_pred*y_true, axis=-1)

def raw_loss(y_true, y_pred):
    return K.mean(y_pred, axis=-1)

class DiscriminatorLabelW:
    def __call__(self, points_real, points_fake):
        assert len(points_real) == BATCH_SIZE
        assert len(points_fake) == BATCH_SIZE
        return [1]*len(points_real) + [-1]*len(points_fake)

    def get_eps(self):
        return 0

def train(eh):
    # Generate training data and draw points
    X_train = gen_2D_samples_from_5x5gm(SAMPLES_PER_GAUSSIAN)
    gmm = mixture.GaussianMixture(n_components=25, covariance_type='full').fit(X_train)
    plot_points_and_gmm(X_train, gmm.predict(X_train), gmm.means_, gmm.covariances_,
                        'Gaussian Mixture', GENERATED_IMAGE_PATH+"gmm.png")

    # Random sampler for each batch and for plotting points of G
    rs = RandomSampler(4, "normal" if eh.random_with_normal_dist else "uniform")

    # Make an object for making correct labels of D
    dl = DiscriminatorLabelW()

    # Make generator G and discriminator D
    generator, discriminator = gan_model_elu()
    d_opt = Adam(lr=eh.disc_Adam_lr, beta_1=eh.disc_Adam_beta_1, decay=eh.disc_Adam_decay)
    g_opt = Adam(lr=eh.gen_Adam_lr,  beta_1=eh.gen_Adam_beta_1,  decay=eh.gen_Adam_decay)
    discriminator.compile(loss=mult_loss, optimizer=d_opt)
    discriminator.trainable = False
    cgan = Sequential([generator, discriminator]) # G+D with fixed weights of D
    cgan.compile(loss=raw_loss, optimizer=g_opt)

    num_batches = int(X_train.shape[0] / BATCH_SIZE)
    print('Number of batches:', num_batches)
    eh.write(GENERATED_IMAGE_PATH+"history.log", {"Generator":generator, "Discriminator":discriminator},
             {"Generator opt":g_opt, "Discriminator opt":d_opt}, None)
    for epoch in range(NUM_EPOCH):
        if eh.X_train_is_shuffled:
            np.random.shuffle(X_train)
        for index in range(num_batches):
            for dindex in range(eh.n_repeat_update_D):
                noise = np.array(rs(BATCH_SIZE))
                points_real = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
                points_fake = generator.predict(noise, verbose=0)
                X = np.concatenate((points_real, points_fake))
                Y = dl(points_real, points_fake)
                d_loss = discriminator.train_on_batch(X, Y)
                discriminator.set_weights([np.clip(w, -eh.weight_clip, eh.weight_clip)
                    for w in discriminator.get_weights()])

            # Update generator
            noise = np.array(rs(BATCH_SIZE))
            g_loss = cgan.train_on_batch(noise, [1]*BATCH_SIZE) # labels are ignored
            print("epoch: %d, batch: %d, g_loss: %e, d_loss: %e, eps: %e" %
                  (epoch, index, g_loss, d_loss, dl.get_eps()))

        plot_generator_distribution(generator, epoch, GENERATED_IMAGE_PATH, rs)
        plot_potential(discriminator, "Epoch={0:d}".format(epoch),
                        GENERATED_IMAGE_PATH+"p{0:03d}".format(epoch), "auto")

if __name__ == '__main__':
    eh = ExperimentHistory()
    eh.batch_size = BATCH_SIZE
    eh.samples_per_gaussian = SAMPLES_PER_GAUSSIAN
    eh.random_with_normal_dist = False
    eh.X_train_is_shuffled = True
    eh.disc_Adam_decay = 1e-5
    eh.disc_Adam_lr = 1e-4
    eh.disc_Adam_beta_1 = 0.5
    eh.gen_Adam_decay = eh.disc_Adam_decay
    eh.gen_Adam_lr = eh.disc_Adam_lr
    eh.gen_Adam_beta_1 = eh.disc_Adam_beta_1
    eh.weight_clip = 0.1 #float("inf")
    eh.n_repeat_update_D = 1
    if not os.path.exists(GENERATED_IMAGE_PATH):
        os.mkdir(GENERATED_IMAGE_PATH)
    train(eh)
    print("finish")

参考文献

[1] https://arxiv.org/abs/1701.07875
[2] http://musyoku.github.io/2017/02/06/Wasserstein-GAN/
[3] http://yusuke-ujitoko.hatenablog.com/entry/2017/05/20/145924
[4] http://www.monthly-hack.com/entry/2017/02/28/200546

0 件のコメント :