はじめに
Wasserstein GAN (WGAN) [1]を試してみました。 詳細は色々なサイト[2, 3, 4]で解説されているので省略します。 理論的背景はともかく、実装の手間は、普通のGANと前回試したCoulomb GANを足して2で割った程度なので、試すこと自体は難しくありません。比較
普通のGANとWGAN、Coulomb GANには識別器\(D\)のLossの計算について以下のような違いがあります。- 普通のGAN(Standard GANとかRegular GAN、Vanilla GANとか呼ばれています)
真の点について1を、偽の点について0を出力するように学習します。Lossは交差エントロピーです。\(D\)が出力する値の分布についての制限はなにもありません。
- WGAN
真の点の値を小さく、偽の点の値を大きくするようにLossを計算します。つまり、真の点の値から偽の点の値を引いた値(下図では\(a-b\))をLossとします。\(D\)の重みに対してクリッピングを行うので、その出力の分布には制限がつきます。
- Coulomb GAN
真の点を正の電荷、偽の点を負の電荷として電位を計算し、電位と\(D\)の出力との二乗誤差をLossとします。
実験
失敗例
WGANでもハイパーパラメータを間違えると分布を模擬できないという例です。Parameter | Value | Images |
---|---|---|
batch_size disc \(\beta_1\) disc decay disc lr gen \(\beta_1\) gen decay gen lr \(n_{\rm critic}\) clip of weights of \(D\) |
512 0.5 0.0001 0.0001 0.5 0.0001 0.0001 1 0.1 | |
batch_size disc \(\beta_1\) disc decay disc lr gen \(\beta_1\) gen decay gen lr \(n_{\rm critic}\) clip of weights of \(D\) |
512 0.5 0.0001 0.0001 0.5 0.0001 0.0001 1 0.01 | |
batch_size disc \(\beta_1\) disc decay disc lr gen \(\beta_1\) gen decay gen lr \(n_{\rm critic}\) clip of weights of \(D\) |
512 0.5 0.0001 0.01 0.5 0.0001 0.01 1 0.01 | |
batch_size disc \(\beta_1\) disc decay disc lr gen \(\beta_1\) gen decay gen lr \(n_{\rm critic}\) clip of weights of \(D\) |
512 0.5 0.0001 0.01 0.5 0.0001 0.01 1 0.1 | |
batch_size disc \(\beta_1\) disc decay disc lr gen \(\beta_1\) gen decay gen lr \(n_{\rm critic}\) clip of weights of \(D\) |
512 0.5 0.0001 0.001 0.5 0.0001 0.001 1 0.1 | |
batch_size disc \(\beta_1\) disc decay disc lr gen \(\beta_1\) gen decay gen lr \(n_{\rm critic}\) clip of weights of \(D\) |
512 0.5 0.0001 0.001 0.5 0.0001 0.001 5 0.1 | このあと、画面外に\(G\)の分布が移動してしまいました。 |
成功例
どうにか成功したパラメータが以下です。Parameter | Value |
---|---|
batch_size disc \(\beta_1\) disc decay disc lr gen \(\beta_1\) gen decay gen lr \(n_{\rm critic}\) clip of weights of \(D\) |
512 0.5 0.00001 0.0001 0.5 0.00001 0.0001 1 0.1 |
クリッピングの省略
識別器\(D\)の重みのクリッピングを省略してみました。省略するとWGANではなくなっていますが、なくてもGANとしては機能しそうに思えます。クリッピング以外のハイパーパラメータを上記の成功例と同じ値に設定しています。 1回目は失敗し、 のようになりました。 2回目は成功し、 となりました。\(D\)の分布を見てみると、クリッピングがないので、学習が進むにつれて値が大きくなっていきます。画像の数字はつぶれて読みにくいので、等値線の最小値と最大値を下表にまとめておきました。Epoch | Min | Max |
---|---|---|
100 | 0 | 600 |
200 | -4000 | 2000 |
300 | -16000 | 12000 |
400 | -30000 | 40000 |
500 | -60000 | 60000 |
600 | -100000 | 50000 |
700 | -120000 | 90000 |
800 | -160000 | 120000 |
900 | -200000 | 100000 |
999 | -180000 | 120000 |
まとめ
WGANでも25個のガウス分布を再現する\(G\)をある程度は学習できることを確認できました。また、クリッピングがなくても、学習できることがあることも確認できました。コード
実験に使ったコードは以下の通りです。importしているコードは、ここに置いています。1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 | # -*- coding: utf-8 -*-
# Wasserstein GAN
import sys, os, math
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../keras-examples/src')
import numpy as np
from sklearn import mixture
from util.history import ExperimentHistory
from plot.gmm_plot import (plot_points_and_gmm, plot_potential,
plot_real_or_fake, plot_generator_distribution)
from gan.models_2dpoints import gan_model_relu, gan_model_elu
from gan.gaussian_mixture.datagen import gen_2D_samples_from_5x5gm, RandomSampler
from keras.models import Sequential
from keras.optimizers import Adam
import keras.backend as K
GENERATED_IMAGE_PATH = 'wgan_images/'
SAMPLES_PER_GAUSSIAN = 4000
BATCH_SIZE = 128*4
NUM_EPOCH = 1000
def mult_loss(y_true, y_pred):
return K.mean(y_pred*y_true, axis=-1)
def raw_loss(y_true, y_pred):
return K.mean(y_pred, axis=-1)
class DiscriminatorLabelW:
def __call__(self, points_real, points_fake):
assert len(points_real) == BATCH_SIZE
assert len(points_fake) == BATCH_SIZE
return [1]*len(points_real) + [-1]*len(points_fake)
def get_eps(self):
return 0
def train(eh):
# Generate training data and draw points
X_train = gen_2D_samples_from_5x5gm(SAMPLES_PER_GAUSSIAN)
gmm = mixture.GaussianMixture(n_components=25, covariance_type='full').fit(X_train)
plot_points_and_gmm(X_train, gmm.predict(X_train), gmm.means_, gmm.covariances_,
'Gaussian Mixture', GENERATED_IMAGE_PATH+"gmm.png")
# Random sampler for each batch and for plotting points of G
rs = RandomSampler(4, "normal" if eh.random_with_normal_dist else "uniform")
# Make an object for making correct labels of D
dl = DiscriminatorLabelW()
# Make generator G and discriminator D
generator, discriminator = gan_model_elu()
d_opt = Adam(lr=eh.disc_Adam_lr, beta_1=eh.disc_Adam_beta_1, decay=eh.disc_Adam_decay)
g_opt = Adam(lr=eh.gen_Adam_lr, beta_1=eh.gen_Adam_beta_1, decay=eh.gen_Adam_decay)
discriminator.compile(loss=mult_loss, optimizer=d_opt)
discriminator.trainable = False
cgan = Sequential([generator, discriminator]) # G+D with fixed weights of D
cgan.compile(loss=raw_loss, optimizer=g_opt)
num_batches = int(X_train.shape[0] / BATCH_SIZE)
print('Number of batches:', num_batches)
eh.write(GENERATED_IMAGE_PATH+"history.log", {"Generator":generator, "Discriminator":discriminator},
{"Generator opt":g_opt, "Discriminator opt":d_opt}, None)
for epoch in range(NUM_EPOCH):
if eh.X_train_is_shuffled:
np.random.shuffle(X_train)
for index in range(num_batches):
for dindex in range(eh.n_repeat_update_D):
noise = np.array(rs(BATCH_SIZE))
points_real = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
points_fake = generator.predict(noise, verbose=0)
X = np.concatenate((points_real, points_fake))
Y = dl(points_real, points_fake)
d_loss = discriminator.train_on_batch(X, Y)
discriminator.set_weights([np.clip(w, -eh.weight_clip, eh.weight_clip)
for w in discriminator.get_weights()])
# Update generator
noise = np.array(rs(BATCH_SIZE))
g_loss = cgan.train_on_batch(noise, [1]*BATCH_SIZE) # labels are ignored
print("epoch: %d, batch: %d, g_loss: %e, d_loss: %e, eps: %e" %
(epoch, index, g_loss, d_loss, dl.get_eps()))
plot_generator_distribution(generator, epoch, GENERATED_IMAGE_PATH, rs)
plot_potential(discriminator, "Epoch={0:d}".format(epoch),
GENERATED_IMAGE_PATH+"p{0:03d}".format(epoch), "auto")
if __name__ == '__main__':
eh = ExperimentHistory()
eh.batch_size = BATCH_SIZE
eh.samples_per_gaussian = SAMPLES_PER_GAUSSIAN
eh.random_with_normal_dist = False
eh.X_train_is_shuffled = True
eh.disc_Adam_decay = 1e-5
eh.disc_Adam_lr = 1e-4
eh.disc_Adam_beta_1 = 0.5
eh.gen_Adam_decay = eh.disc_Adam_decay
eh.gen_Adam_lr = eh.disc_Adam_lr
eh.gen_Adam_beta_1 = eh.disc_Adam_beta_1
eh.weight_clip = 0.1 #float("inf")
eh.n_repeat_update_D = 1
if not os.path.exists(GENERATED_IMAGE_PATH):
os.mkdir(GENERATED_IMAGE_PATH)
train(eh)
print("finish")
|
参考文献
[1] https://arxiv.org/abs/1701.07875[2] http://musyoku.github.io/2017/02/06/Wasserstein-GAN/
[3] http://yusuke-ujitoko.hatenablog.com/entry/2017/05/20/145924
[4] http://www.monthly-hack.com/entry/2017/02/28/200546
0 件のコメント :
コメントを投稿