せっかく数字フォントで画像をつくったので、それでニューラルネットワークモデルを学習させてみます。
そのモデルで手書き数字のMNISTをどれだけ学習できるか、についても見てみます。
# This script trains a neural network model from font-based number images.
# The trained model is evaluated by validation data of MNIST.
import gc, keras, os, re
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation, Reshape
from keras.layers.convolutional import Conv2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import Flatten, Dropout
from keras.datasets import mnist
from PIL import Image, ImageDraw, ImageFont # For drawing font images
FONT_DIR = "C:/Windows/Fonts"
TRAIN_EPOCH = 30
# Image size is the same as MNIST
IMAGE_WIDTH = 28
IMAGE_HEIGHT = 28
def _text_size(font, text):
    """Return (width, height) of *text* the way the legacy getsize() did.

    ``ImageFont.getsize`` was removed in Pillow 10.  The legacy call reported
    the ink width plus a height that included the vertical offset, which is
    ``right - left`` and ``bottom`` of the bounding box, so the centering
    computed from it stays unchanged.
    """
    if hasattr(font, "getbbox"):
        left, _top, right, bottom = font.getbbox(text)
        return right - left, bottom
    # Very old Pillow releases without getbbox().
    return font.getsize(text)


def make_image(font_name):
    """Make images and labels of 0-9 by using the font given by font_name.

    Args:
        font_name: File name of a TrueType font located in FONT_DIR.

    Returns:
        A tuple ``(images, labels)``.  ``images`` is a float64 array of shape
        (10, IMAGE_HEIGHT, IMAGE_WIDTH) with pixel values in [0, 255] (white
        digit on a black background); ``labels`` is a float64 array of shape
        (10, 1) holding the digit drawn in the corresponding image.
    """
    font = ImageFont.truetype(os.path.join(FONT_DIR, font_name), 25)
    glyphs = []
    for digit in range(10):  # Draw 0-9 as image
        text = str(digit)
        image = Image.new('RGB', (IMAGE_WIDTH, IMAGE_HEIGHT), (0, 0, 0))
        draw = ImageDraw.Draw(image)
        font_width, font_height = _text_size(font, text)
        # Center the glyph inside the canvas.
        draw.text(((IMAGE_WIDTH - font_width) / 2,
                   (IMAGE_HEIGHT - font_height) / 2),
                  text, font=font, fill=(255, 255, 255))
        # The drawing is gray-scale, so the red channel alone is enough.
        glyphs.append(np.asarray(image)[:, :, 0])
    # Stack once instead of growing the arrays with np.append on every
    # iteration; float64 matches what the np.append version produced.
    images = np.stack(glyphs).astype(np.float64)
    labels = np.arange(10, dtype=np.float64).reshape(10, 1)
    return images, labels
def make_images():
    """Collect images of the digits 0-9 for all usable fonts in FONT_DIR.

    Only files whose name ends in ".ttf" are used, and a fixed set of symbol
    fonts that contain no ordinary digits is skipped.

    Returns:
        A tuple ``(images, labels)`` of float64 arrays: the per-font results
        of make_image() concatenated along the first axis.  Both arrays have
        zero rows when no font is found.
    """
    # Fonts whose glyphs for "0"-"9" are symbols rather than numbers.
    excluded = {"webdings.ttf", "wingding.ttf", "marlett.ttf",
                "opens___.ttf", "symbol.ttf"}
    # A suffix test replaces the regex ".*ttf", which also accepted names
    # that merely contain "ttf" (e.g. "x.ttf.bak").  Sorting makes the
    # sample order independent of the directory listing order.
    font_list = sorted(f for f in os.listdir(FONT_DIR)
                       if f.endswith(".ttf") and f not in excluded)
    # Seed with empty arrays so an empty font list still yields well-shaped
    # float64 results, then concatenate once at the end.
    image_parts = [np.empty((0, IMAGE_WIDTH, IMAGE_HEIGHT))]
    label_parts = [np.empty((0, 1))]
    for font_name in font_list:
        fimages, flabels = make_image(font_name)
        image_parts.append(fimages)
        label_parts.append(flabels)
    images = np.concatenate(image_parts, axis=0)
    labels = np.concatenate(label_parts, axis=0)
    return images, labels
def num_discriminator_model():
    """Define a discriminator model for numbers.

    The network takes (IMAGE_WIDTH, IMAGE_HEIGHT, 1) channels-last images,
    applies two strided 5x5 convolutions and a 256-unit dense layer (each
    followed by LeakyReLU), dropout, and a 10-way softmax output.

    Returns:
        The uncompiled keras Sequential model.
    """
    model = Sequential()
    model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                     input_shape=(IMAGE_WIDTH, IMAGE_HEIGHT, 1),
                     data_format='channels_last'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(128, (5, 5), strides=(2, 2), data_format='channels_last'))
    model.add(LeakyReLU(0.2))
    model.add(Flatten())
    model.add(Dense(256))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.5))
    model.add(Dense(10))
    model.add(Activation('softmax'))
    # summary() prints by itself and returns None; wrapping it in print()
    # only added a stray "None" line to the output.
    model.summary()
    return model
if __name__ == '__main__':
    # Make images of numbers from fonts
    x_train, y_train = make_images()
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5  # [0,255] --> [-1,1]
    x_train = x_train.reshape((x_train.shape[0], x_train.shape[1],
                               x_train.shape[2], 1))

    # Load MNIST data for validation
    (_, _), (x_val, y_val) = mnist.load_data()
    x_val = (x_val.astype(np.float32) - 127.5) / 127.5  # [0,255] --> [-1,1]
    x_val = x_val.reshape((x_val.shape[0], x_val.shape[1], x_val.shape[2], 1))

    # Encode labels into 1-hot vectors
    y_train = keras.utils.to_categorical(y_train, num_classes=10)
    y_val = keras.utils.to_categorical(y_val, num_classes=10)

    # Make, compile and train a model
    model = num_discriminator_model()
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    hist = model.fit(x_train, y_train, validation_data=(x_val, y_val),
                     epochs=TRAIN_EPOCH, batch_size=32).history

    # Evaluate the model by train and validation data.
    # Each score is the list [loss, accuracy].
    score = model.evaluate(x_train, y_train, batch_size=32)
    print("\ntrain acc : ", score)
    score = model.evaluate(x_val, y_val, batch_size=32)
    print("\nmnist val acc : ", score)

    # Older Keras records the metric as "acc", newer releases as "accuracy";
    # accept either so writing the history does not fail with a KeyError.
    acc_key = "acc" if "acc" in hist else "accuracy"
    val_acc_key = "val_" + acc_key

    # Write training history into a file; the context manager closes the
    # file even if formatting a row fails.
    with open("history.dat", mode="w") as f:
        f.write("#epoch train-loss train-acc val-loss val-acc\n")
        # Iterate over the recorded epochs rather than assuming TRAIN_EPOCH
        # entries exist.
        rows = zip(hist["loss"], hist[acc_key],
                   hist["val_loss"], hist[val_acc_key])
        for epoch, (loss, acc, val_loss, val_acc) in enumerate(rows):
            f.write("{0} {1:10.6f} {2:10.6f} {3:10.6f} {4:10.6f}\n"
                    .format(epoch, loss, acc, val_loss, val_acc))

    gc.collect()  # To suppress error messages of TensorFlow
|
乱数の影響で実行するたびに精度は変わりますが、一例として次のような結果になりました。

| | Loss | Accuracy |
|---|---|---|
| Train | 0.00267 | 0.99852 |
| MNIST | 6.10193 | 0.52960 |
0 件のコメント :
コメントを投稿