2018/01/27

MNIST with a Coulomb GAN

Introduction


In this post I generate MNIST handwritten digits with a Coulomb GAN. Since the quality of the generated images can, at least as a first approximation, be measured with the Fréchet Inception Distance (FID), I also compare the results against a standard GAN.
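
For reference, FID is the Fréchet (2-Wasserstein) distance between Gaussian fits to feature vectors extracted from the two sets of images. A minimal sketch of that computation, assuming the feature matrices feat_real and feat_fake (one row per sample) have already been obtained with a fixed extractor such as an Inception network (the function name and arguments are purely illustrative), could look like this:
import numpy as np
from scipy import linalg

def frechet_distance(feat_real, feat_fake):
    # Mean and covariance of each set of feature vectors
    mu1, mu2 = feat_real.mean(axis=0), feat_fake.mean(axis=0)
    sigma1 = np.cov(feat_real, rowvar=False)
    sigma2 = np.cov(feat_fake, rowvar=False)
    # Matrix square root of the covariance product; keep only the real part,
    # since numerical error can introduce tiny imaginary components
    covmean = linalg.sqrtm(sigma1.dot(sigma2)).real
    diff = mu1 - mu2
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0*covmean)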

Generator and Discriminator


The generator and discriminator are almost the same ones I used before. In Keras, the generator is:
g = Sequential()
g.add(Dense(1024, input_dim=100))
g.add(BatchNormalization())
g.add(Activation('relu'))
g.add(Dense(128*7*7))
g.add(BatchNormalization())
g.add(Activation('relu'))
g.add(Reshape((128, 7, 7), input_shape=(128*7*7,)))
g.add(UpSampling2D((2, 2), data_format='channels_first'))
g.add(Conv2D(64, (5, 5), padding='same', data_format='channels_first'))
g.add(BatchNormalization())
g.add(Activation('relu'))
g.add(UpSampling2D((2, 2), data_format='channels_first'))
g.add(Conv2D(1, (5, 5), padding='same', data_format='channels_first'))
g.add(Activation('tanh'))
and the discriminator is:
d = Sequential()
d.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=(1, 28, 28), 
                 data_format='channels_first'))
d.add(LeakyReLU(0.2))
d.add(Conv2D(128, (5, 5), strides=(2, 2), data_format='channels_first'))
d.add(LeakyReLU(0.2))
d.add(Flatten())
d.add(Dense(256))
d.add(LeakyReLU(0.2))
d.add(Dropout(0.5))
d.add(Dense(1))
The final sigmoid layer is omitted from the discriminator. Because the Coulomb GAN's discriminator has to approximate a potential, a sigmoid on the output would make that impossible. For a standard GAN, where real and fake are labeled 1 and 0, keeping the sigmoid seems preferable, but when I tried it without one it still trained without problems.
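
To see why the sigmoid has to go, note that in the Coulomb GAN the discriminator is trained to approximate the potential field created jointly by the real and generated samples. Roughly, following the Coulomb GAN paper (and up to the sign convention used in the implementation),

\[ \Phi(a) = \mathbb{E}_{y \sim p_{\mathrm{real}}}\big[k(a, y)\big] - \mathbb{E}_{x \sim p_{\mathrm{fake}}}\big[k(a, x)\big], \]

where \(k\) is the Plummer kernel with smoothing parameter \(\varepsilon\) and dimension parameter \(d\). This potential takes both positive and negative real values and is not confined to \((0, 1)\), so squashing the discriminator's output with a sigmoid would prevent it from representing the potential.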

Experimental Results


Now let's compare the Coulomb GAN with the standard GAN. In every case, FID is computed against the MNIST training data. The Coulomb GAN was run in the setting where each sample's own (self-)potential is ignored, the half-life of the Plummer kernel's \(\varepsilon\) was fixed at 5000 batches, its dimension parameter at 3, and FID was computed from 3000 generated samples.
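
To make the half-life concrete: with the decay schedule implemented in the code at the end of this post, the \(\varepsilon\) actually used in batch \(t\) (counting every batch since the start of training) is

\[ \varepsilon(t) = \varepsilon_0 \cdot 2^{-t/5000}, \]

where \(\varepsilon_0\) is the value listed under Parameters. The results are summarized in the table below.
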
Standard GAN
  Parameters: Batch size=128; \(D\): Adam \(\beta_1\)=0.5, decay=1e-4, lr=1e-5; \(G\): Adam \(\beta_1\)=0.5, decay=1e-4, lr=1e-4
  FID: 14.66
  Generated images: Epoch=99, Batch=400

Coulomb GAN No. 1
  Parameters: Batch size=128; \(D\): Adam \(\beta_1\)=0.5, decay=1e-3, lr=1e-3; \(G\): Adam \(\beta_1\)=0.5, decay=1e-3, lr=1e-3; Plummer kernel \(\varepsilon\)=1.0
  FID: 52.32
  Generated images: Epoch=99, Batch=400

Coulomb GAN No. 2
  Parameters: Batch size=128; \(D\): Adam \(\beta_1\)=0.5, decay=1e-3, lr=1e-4; \(G\): Adam \(\beta_1\)=0.5, decay=1e-3, lr=1e-3; Plummer kernel \(\varepsilon\)=1.0
  FID: 90.82
  Generated images: Epoch=99, Batch=400

Coulomb GAN No. 3
  Parameters: Batch size=128; \(D\): Adam \(\beta_1\)=0.5, decay=1e-4, lr=1e-3; \(G\): Adam \(\beta_1\)=0.5, decay=1e-3, lr=1e-4; Plummer kernel \(\varepsilon\)=1.0
  FID: 106.5
  Generated images: Epoch=99, Batch=400

Coulomb GAN No. 4
  Parameters: Batch size=128; \(D\): Adam \(\beta_1\)=0.5, decay=1e-3, lr=1e-3; \(G\): Adam \(\beta_1\)=0.5, decay=1e-3, lr=1e-3; Plummer kernel \(\varepsilon\)=0.1
  FID: 108.0
  Generated images: Epoch=99, Batch=400

Coulomb GAN No. 5
  Parameters: Batch size=128; \(D\): Adam \(\beta_1\)=0.5, decay=1e-2, lr=1e-3; \(G\): Adam \(\beta_1\)=0.5, decay=1e-2, lr=1e-3; Plummer kernel \(\varepsilon\)=10.0
  FID: 118.4
  Generated images: Epoch=70, Batch=0

Coulomb GAN No. 6
  Parameters: Batch size=128; \(D\): Adam \(\beta_1\)=0.5, decay=1e-4, lr=1e-3; \(G\): Adam \(\beta_1\)=0.5, decay=1e-4, lr=1e-3; Plummer kernel \(\varepsilon\)=1.0
  FID: 149.9
  Generated images: Epoch=99, Batch=400
Comparing by FID, the standard GAN scores lower than any of the Coulomb GAN runs and appears to reproduce the original distribution better. Indeed, looking closely at the images, even Coulomb GAN No. 1, the run with the lowest FID, produces digits that are somewhat broken and still contain isolated single pixels of high brightness. Coulomb GAN No. 4 has a fairly large FID of 108.0, yet it shows little background speckle (noise) and its strokes are crisp; the digits just look as if they had been distorted according to some fixed rule. The remaining runs are simply no good.

As for how quickly \(G\) learns: with the standard GAN, subjectively clean, digit-like images appear after about 15 epochs, whereas the Coulomb GAN, even in run No. 1, keeps producing broken or noisy images until around epoch 95. I also tried parameters other than those listed here, but could not find any that let \(G\) learn as fast as the standard GAN.

Summary


The Coulomb GAN gave stable, decent-looking results when learning point distributions, but for MNIST image generation the results were underwhelming. If you want to generate images with more variety than the real ones, the FID necessarily has to be somewhat larger, but these experiments did not tell me how large an FID is still acceptable.

Code


Here is the code used for these experiments. The parameter block at the end was rewritten for each run.
# -*- coding: utf-8 -*-
# Coulomb GAN
import sys, os, math
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../keras-examples/src')
import numpy as np
from util.history import ExperimentHistory
from gan.coulomb import CoulombPotentials
from gan.gaussian_mixture.datagen import RandomSampler
from keras.models import Sequential
from keras.optimizers import Adam
import keras.backend as K
from keras.layers import Dense, Activation, Reshape
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import Flatten, Dropout
from keras.datasets import mnist
from PIL import Image

GENERATED_IMAGE_PATH = 'coulomb_images/'
BATCH_SIZE = 32*4
NUM_EPOCH = 100

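# In the Coulomb GAN setting, the generator is trained with raw_loss below:
# it simply averages the discriminator's raw output (the estimated potential
# at the generated samples), so the labels passed to train_on_batch are ignored.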
def raw_loss(y_true, y_pred):
    return K.mean(y_pred, axis=-1)

def combine_images(generated_images):
    total = generated_images.shape[0]
    cols = int(math.sqrt(total))
    rows = math.ceil(float(total)/cols)
    width, height = generated_images.shape[2:]
    combined_image = np.zeros((height*rows, width*cols), dtype=generated_images.dtype)

    for index, image in enumerate(generated_images):
        i = int(index/cols)
        j = index % cols
        combined_image[width*i:width*(i+1), height*j:height*(j+1)] = image[0, :, :]
    return combined_image

def gan_model_mnist(initializer='glorot_uniform'):
    """
    return: (generator, discriminator)
    """
    g = Sequential()
    g.add(Dense(1024, input_dim=100))
    g.add(BatchNormalization())
    g.add(Activation('relu'))
    g.add(Dense(128*7*7))
    g.add(BatchNormalization())
    g.add(Activation('relu'))
    g.add(Reshape((128, 7, 7), input_shape=(128*7*7,)))
    g.add(UpSampling2D((2, 2), data_format='channels_first'))
    g.add(Conv2D(64, (5, 5), padding='same', data_format='channels_first'))
    g.add(BatchNormalization())
    g.add(Activation('relu'))
    g.add(UpSampling2D((2, 2), data_format='channels_first'))
    g.add(Conv2D(1, (5, 5), padding='same', data_format='channels_first'))
    g.add(Activation('tanh'))
    g.summary()

    d = Sequential()
    d.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=(1, 28, 28), 
                     data_format='channels_first'))
    d.add(LeakyReLU(0.2))
    d.add(Conv2D(128, (5, 5), strides=(2, 2), data_format='channels_first'))
    d.add(LeakyReLU(0.2))
    d.add(Flatten())
    d.add(Dense(256))
    d.add(LeakyReLU(0.2))
    d.add(Dropout(0.5))
    d.add(Dense(1))
    #d.add(Activation('sigmoid'))
    d.summary()
    return g, d

class DiscriminatorLabelNormal:
    def __call__(self, points_real, points_fake):
        assert len(points_real) == BATCH_SIZE
        assert len(points_fake) == BATCH_SIZE
        return [1]*len(points_real) + [0]*len(points_fake)

    def get_eps(self):
        return 0

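# In the Coulomb GAN setting, the training targets ("labels") for D are the
# potentials evaluated at the real and generated samples (computed by
# CoulombPotentials with the Plummer kernel); D is regressed onto them with
# the MSE loss chosen in train().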
class DiscriminatorLabelCoulomb:
    def __init__(self, eh):
        self.total_num_batches = 0
        self.initial_eps = eh.plummer_kernel_eps  # starting value of the Plummer eps
        self.cp = CoulombPotentials(eh.plummer_kernel_dim, eh.plummer_kernel_eps,
                                    eh.plummer_kernel_ignore_self_potential)
        self.eps_half_life = eh.plummer_kernel_eps_half_life

    def __call__(self, points_real, points_fake):
        # Halve eps every eps_half_life batches, counting from the start of training
        self.cp.eps = self.initial_eps * math.pow(2, -self.total_num_batches/self.eps_half_life)
        self.total_num_batches += 1
        potential_real, potential_fake = self.cp(points_real, points_fake)
        return np.concatenate((potential_real, potential_fake))

    def get_eps(self):
        return self.cp.eps

def train(eh):
    (X_train, y_train), (_, _) = mnist.load_data()
    X_train = (X_train.astype(np.float32) - 127.5)/127.5
    X_train = X_train.reshape(X_train.shape[0], 1, X_train.shape[1], X_train.shape[2])

    # Random sampler for each batch
    rs = RandomSampler(100, "normal" if eh.random_with_normal_dist else "uniform")

    # Make an object for making correct labels of D
    dl = DiscriminatorLabelCoulomb(eh) if eh.coulomb_gan else DiscriminatorLabelNormal()

    # Make a generator G and a discriminator D
    generator, discriminator = gan_model_mnist()
    d_opt = Adam(lr=eh.disc_Adam_lr, beta_1=eh.disc_Adam_beta_1, decay=eh.disc_Adam_decay)
    g_opt = Adam(lr=eh.gen_Adam_lr,  beta_1=eh.gen_Adam_beta_1,  decay=eh.gen_Adam_decay)
    discriminator.compile(loss='mse' if eh.coulomb_gan else 'binary_crossentropy', optimizer=d_opt)
    discriminator.trainable = False
    cgan = Sequential([generator, discriminator]) # G+D with fixed weights of D
    cgan.compile(loss=raw_loss if eh.coulomb_gan else 'binary_crossentropy', optimizer=g_opt)
    num_batches = int(X_train.shape[0] / BATCH_SIZE)
    print('Number of batches:', num_batches)
    eh.write(GENERATED_IMAGE_PATH+"history.log", {"Generator":generator, "Discriminator":discriminator},
             {"Generator opt":g_opt, "Discriminator opt":d_opt}, None)
    for epoch in range(NUM_EPOCH):
        if eh.X_train_is_shuffled:
            np.random.shuffle(X_train)
        for index in range(num_batches):
            noise = np.array(rs(BATCH_SIZE))
            points_real = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            points_fake = generator.predict(noise, verbose=0)

            # Output generated images
            if index % 200 == 0:
                image = combine_images(points_fake)
                image = image*127.5 + 127.5
                if not os.path.exists(GENERATED_IMAGE_PATH):
                    os.mkdir(GENERATED_IMAGE_PATH)
                Image.fromarray(image.astype(np.uint8))\
                    .save(GENERATED_IMAGE_PATH+"%04d_%04d.png" % (epoch, index))

                generator.save(GENERATED_IMAGE_PATH+"generator.model")
                discriminator.save(GENERATED_IMAGE_PATH+"discriminator.model")

            # Update a discriminator
            X = np.concatenate((points_real, points_fake))
            Y = dl(points_real, points_fake)
            d_loss = discriminator.train_on_batch(X, Y)

            # Update a generator
            noise = np.array(rs(BATCH_SIZE))
            g_loss = cgan.train_on_batch(noise, [1]*BATCH_SIZE) # labels are ignored if Coulomb GAN
            print("epoch: %d, batch: %d, g_loss: %e, d_loss: %e, eps: %e" %
                  (epoch, index, g_loss, d_loss, dl.get_eps()))

if __name__ == '__main__':
    GENERATED_IMAGE_PATH = 'coulomb_images/'
    eh = ExperimentHistory()
    eh.batch_size = BATCH_SIZE
    eh.random_with_normal_dist = False
    eh.X_train_is_shuffled = True
    eh.plummer_kernel_dim = 3.0
    eh.plummer_kernel_eps = 0.1
    eh.plummer_kernel_eps_half_life = 5000.0
    eh.plummer_kernel_ignore_self_potential = True
    eh.disc_Adam_decay = 1e-3
    eh.disc_Adam_lr = 1e-3
    eh.disc_Adam_beta_1 = 0.5
    eh.gen_Adam_decay = 1e-3
    eh.gen_Adam_lr = 1e-3
    eh.gen_Adam_beta_1 = 0.5
    eh.coulomb_gan = True
    if not eh.coulomb_gan:
        GENERATED_IMAGE_PATH = 'generated_images/'
    if not os.path.exists(GENERATED_IMAGE_PATH):
        os.mkdir(GENERATED_IMAGE_PATH)
    train(eh)
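
The CoulombPotentials class imported from gan.coulomb above is not shown. For completeness, here is a rough, hypothetical sketch of what such a class could compute, assuming the Plummer-kernel potentials described earlier; the class name PlummerPotentials, the exact kernel exponent, the sign convention, and the normalization below are my own guesses and may well differ from the real module:
import numpy as np

class PlummerPotentials:
    """Hypothetical sketch of a Plummer-kernel potential calculator."""
    def __init__(self, dim, eps, ignore_self_potential=True):
        self.dim = dim                            # kernel dimension parameter d
        self.eps = eps                            # Plummer smoothing parameter
        self.ignore_self_potential = ignore_self_potential

    def _kernel(self, a, b):
        # Kernel matrix K[i, j] = 1 / (||a_i - b_j||^2 + eps^2)^(d/2)
        a = a.reshape(len(a), -1).astype(np.float64)
        b = b.reshape(len(b), -1).astype(np.float64)
        sq_dists = np.sum((a[:, None, :] - b[None, :, :])**2, axis=-1)
        return 1.0 / np.power(sq_dists + self.eps**2, self.dim / 2.0)

    def __call__(self, points_real, points_fake):
        k_rr = self._kernel(points_real, points_real)
        k_rf = self._kernel(points_real, points_fake)
        k_ff = self._kernel(points_fake, points_fake)
        if self.ignore_self_potential:
            # Drop each sample's contribution to its own potential
            np.fill_diagonal(k_rr, 0.0)
            np.fill_diagonal(k_ff, 0.0)
        # Potential = field created by real samples minus field created by fakes
        potential_real = k_rr.mean(axis=1) - k_rf.mean(axis=1)
        potential_fake = k_rf.mean(axis=0) - k_ff.mean(axis=1)
        return potential_real, potential_fake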