2017年9月2日土曜日

数字画像の水増し

ただ単にフォントで描いた画像で学習したモデルではMNISTの手書き数字はうまく認識できなかったので、フォントで描いた画像を加工して水増しする方法(Data augmentation)を試してみます。

画像の加工には、KerasのImageDataGeneratorを使います。前回のコードとほぼ同じですが、加工部分が異なっています。

# This script trains a neural network model from font-based number images.
# Image data are augmented by changing their shape.
# The trained model is evaluated by validation data of MNIST.
import gc, keras, math, os, re
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation, Reshape
from keras.layers.convolutional import Conv2D
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import Flatten, Dropout
from keras.datasets import mnist
from PIL import Image, ImageDraw, ImageFont # For drawing font images

# Directory that is scanned for TrueType font files
FONT_DIR = "C:/Windows/Fonts"

# Training schedule
TRAIN_EPOCH = 30   # number of epochs passed to fit_generator
BATCH_SIZE = 32    # mini-batch size for the augmentation generator

# Rendered digit images use the same 28x28 resolution as MNIST
IMAGE_WIDTH, IMAGE_HEIGHT = 28, 28

def _text_extent(font, text):
    """Return the (width, height) of *text* measured from the drawing origin.

    ``FreeTypeFont.getsize`` was removed in Pillow 10. On newer Pillow the
    right/bottom edges of ``getbbox`` are used instead; they correspond to the
    extent ``getsize`` used to report (ink size plus the font offset), so the
    centring computed by the caller stays the same.
    """
    try:
        return font.getsize(text)
    except AttributeError:  # Pillow >= 10
        _, _, right, bottom = font.getbbox(text)
        return right, bottom


# Make images and labels of 0-9 by using the font specified by font_name
def make_image(font_name):
    """Render the digits 0-9 with a single TrueType font.

    Each digit is drawn in white on a black IMAGE_WIDTH x IMAGE_HEIGHT canvas
    at font size 25, centred using the measured text extent.

    Args:
        font_name: File name of a font located in FONT_DIR.

    Returns:
        (images, labels) where ``images`` is a float64 array of shape
        (10, IMAGE_HEIGHT, IMAGE_WIDTH) with pixel values in [0, 255] and
        ``labels`` is a float64 array of shape (10, 1) holding 0..9.
    """
    font = ImageFont.truetype(os.path.join(FONT_DIR, font_name), 25)
    images = []
    labels = []
    for digit in range(10):  # Draw 0-9 as image
        text = str(digit)
        image = Image.new('RGB', (IMAGE_WIDTH, IMAGE_HEIGHT), (0, 0, 0))
        draw = ImageDraw.Draw(image)
        text_width, text_height = _text_extent(font, text)
        draw.text(((IMAGE_WIDTH - text_width) / 2, (IMAGE_HEIGHT - text_height) / 2),
                  text, font=font, fill=(255, 255, 255))
        # np.asarray on an RGB image gives (height, width, 3). The text is
        # white, so the red channel alone carries the full intensity.
        images.append(np.asarray(image)[:, :, 0])
        labels.append(digit)
    # Stack once at the end instead of np.append per digit (which copies).
    return (np.asarray(images, dtype=np.float64),
            np.asarray(labels, dtype=np.float64).reshape(-1, 1))

# Collect images of numbers for all fonts
def make_images():
    """Render the digits 0-9 with every usable TrueType font in FONT_DIR.

    Only files whose names end in ``.ttf`` are used. Fonts listed in
    ``excluded`` are skipped because they do not contain number glyphs.
    Files are processed in sorted name order so the dataset is reproducible.

    Returns:
        (images, labels) concatenated over all fonts, with shapes
        (10 * n_fonts, IMAGE_HEIGHT, IMAGE_WIDTH) and (10 * n_fonts, 1).
        Both arrays are empty (first dimension 0) when no font qualifies.
    """
    # Symbol fonts without digits (exact, case-sensitive file names)
    excluded = {"webdings.ttf", "wingding.ttf", "marlett.ttf",
                "opens___.ttf", "symbol.ttf"}
    # endswith() replaces the unanchored regex ".*ttf", which accepted any
    # name merely containing "ttf" (e.g. "font.ttf.bak").
    font_list = sorted(f for f in os.listdir(FONT_DIR)
                       if f.endswith(".ttf") and f not in excluded)
    # Seed with empty arrays so the result has the right shape with no fonts.
    images = [np.empty((0, IMAGE_HEIGHT, IMAGE_WIDTH))]
    labels = [np.empty((0, 1))]
    for font_name in font_list:
        fimages, flabels = make_image(font_name)
        images.append(fimages)
        labels.append(flabels)
    # Concatenate once instead of np.append per font (quadratic copying).
    return np.concatenate(images, axis=0), np.concatenate(labels, axis=0)

# Define a discriminator model for numbers
def num_discriminator_model():
    """Build a small CNN that classifies a one-channel digit image into 10 classes.

    Architecture:
    - Conv2D(64, 5x5, stride 2, 'same') followed by LeakyReLU(0.2);
    - Conv2D(128, 5x5, stride 2) followed by LeakyReLU(0.2);
    - Flatten, then Dense(256) with LeakyReLU(0.2) and Dropout(0.5);
    - Dense(10) with a softmax output.

    The model summary is printed to stdout. The returned model is not yet
    compiled.
    """
    model = Sequential()
    # channels_last expects (rows, cols, channels) = (height, width, 1)
    model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                     input_shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 1),
                     data_format='channels_last'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(128, (5, 5), strides=(2, 2), data_format='channels_last'))
    model.add(LeakyReLU(0.2))
    model.add(Flatten())
    model.add(Dense(256))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.5))
    model.add(Dense(10))
    model.add(Activation('softmax'))
    # summary() prints by itself and returns None; wrapping it in print()
    # used to emit a stray "None".
    model.summary()
    return model

if __name__ == '__main__':
    # Make images of numbers from fonts
    x_train, y_train = make_images()
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5  # [0,255] --> [-1,1]
    x_train = x_train.reshape(x_train.shape + (1,))  # append channel axis

    # Load MNIST data for validation (the MNIST test split is used here)
    (_, _), (x_val, y_val) = mnist.load_data()
    x_val = (x_val.astype(np.float32) - 127.5) / 127.5  # [0,255] --> [-1,1]
    x_val = x_val.reshape(x_val.shape + (1,))  # append channel axis

    # Encode labels into 1-hot vectors
    y_train = keras.utils.to_categorical(y_train, num_classes=10)
    y_val = keras.utils.to_categorical(y_val, num_classes=10)

    # For data augmentation: random rotation, shift, shear and zoom.
    # New pixels are filled with -1, which is black after normalisation.
    datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=math.pi/4,  # 45 degree
        zoom_range=0.3,
        fill_mode="constant",
        cval=-1,  # constant value for fill_mode
        )

    # Make, compile and train a model
    model = num_discriminator_model()
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    # steps_per_epoch must be an integer number of batches; round up so the
    # final partial batch is included in each epoch.
    steps_per_epoch = int(math.ceil(len(x_train) / float(BATCH_SIZE)))
    hist = model.fit_generator(datagen.flow(x_train, y_train, batch_size=BATCH_SIZE),
                               steps_per_epoch=steps_per_epoch, epochs=TRAIN_EPOCH,
                               validation_data=(x_val, y_val)).history

    # Evaluate the model by train and validation data; evaluate() returns
    # [loss, accuracy] for the metrics compiled above.
    score = model.evaluate(x_train, y_train, batch_size=BATCH_SIZE)
    print("\ntrain loss, acc : ", score)
    score = model.evaluate(x_val, y_val, batch_size=BATCH_SIZE)
    print("\nmnist val loss, acc : ", score)

    # Older Keras records accuracy as "acc", newer versions as "accuracy".
    acc_key = "acc" if "acc" in hist else "accuracy"

    # Write training history into a file (one line per recorded epoch)
    with open("history.dat", mode="w") as f:
        f.write("#epoch train-loss train-acc val-loss val-acc\n")
        for v in range(len(hist["loss"])):
            f.write("{0} {1:10.6f} {2:10.6f} {3:10.6f} {4:10.6f}\n"
                    .format(v, hist["loss"][v], hist[acc_key][v],
                            hist["val_loss"][v], hist["val_" + acc_key][v]))
    gc.collect()  # To suppress error messages of TensorFlow
30エポック学習させたモデルを使ったときのLossとAccuracyは次のようになりました。

| | Augmentation | Loss | Accuracy |
|---|---|---|---|
| Train | Off | 0.00267 | 0.99852 |
| Train | On | 0.09111 | 0.97811 |
| MNIST | Off | 6.10193 | 0.52960 |
| MNIST | On | 1.01909 | 0.71590 |
学習データを加工しない場合に比べて、ずいぶんMNISTの結果がよくなりました。一方で、学習データに対する性能は低下しています。

では、学習中のLossの変化を見てみましょう。

学習データの加工なしの場合に比べて、MNISTのLossもある程度は低下しています。

学習中のAccuracyの変化は次のようになりました。

MNISTに関しては、ざっくり70%くらいの精度でしょうか。データ加工なしの場合が約50%だったので、結構改善しました。

0 件のコメント :