2017年9月9日土曜日

分類に失敗した画像の出力

MNISTの画像の識別器を学習させていると、精度(accuracy)だけではなく、誤って識別した画像の一覧を見たい場合があります。

以下のコードを使うと、学習済みのKerasのモデルをロードし、識別し、誤った画像を.png形式で出力できます。"mnist-tranied.hdf5" の部分は読み込みたいモデルのファイルパスに書き換えてください。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# This script loads a trained model and outputs error images.
import gc, keras, os
import numpy as np
from keras.models import load_model
from keras.datasets import mnist
from PIL import Image, ImageDraw

IMAGE_PATH = 'error-images'
if not os.path.exists(IMAGE_PATH):
    os.mkdir(IMAGE_PATH)

# Load MNIST data for validation
(_, _), (x_val, y_val) = mnist.load_data()
x_val = (x_val.astype(np.float32) - 127.5)/127.5 # [0,255] --> [-1,1]
x_val = x_val.reshape((x_val.shape[0], x_val.shape[1], x_val.shape[2], 1))

# Predict classes for MNIST validation data
model = load_model("mnist-tranied.hdf5")
pred = model.predict_classes(x_val)

# Write error images
for i, (p, y) in enumerate(zip(pred, y_val)):
    if p != y:
        image = x_val[i]*127.5 + 127.5 # [-1,1] --> [0,255]
        img = Image.fromarray(image.reshape((image.shape[0], image.shape[1])).astype(np.uint8))
        img.save(IMAGE_PATH+"/{0}-c{1}-p{2}.png".format(i, y, p)) # c=correct, p=predict

数字フォントでモデル学習で使ったニューラルネットワークにMNISTの訓練用データを入力して学習したモデルを使って、上記のコードを走らせると、どうして認識できないのだろうという画像と、これは無理だろうという画像などが目で見て分かるようになります。

例えば、

正解は1
正解は7
正解は6

は人間でも難しいでしょう。

0 件のコメント :