2017年12月24日日曜日

Fréchet Inception Distance

はじめに

Fréchet Inception Distance (FID)と呼ばれる、Generative Adversarial Network (GAN)が生成する画像の品質を評価する指標を試してみます[1]。この指標は、画像の集合間の距離を表します。前回試したInception Scoreは画像の集合そのものの良さを表すスコアでしたので1つ画像の集合を与えるだけで計算できましたが、FIDはそのようには計算できません。GANで再現したい真の分布から生成された画像の集合と、GANで再現した分布から生成した画像の集合との距離を計算することになります。距離が近ければ近いほど良い画像であると判断します。FIDは、Google Brainが実施したGANの大規模評価の評価指標にも用いられています[2]

計算方法

FIDは、Inceptionモデルの途中の層の出力から得られるベクトル\(h\)を使ってFréchet Distance [3, 4]を計算することで求められます。Fréchet Distanceは曲線同士の距離のため、\(h\)のままでは距離を計算できません。そこで、画像から得られるベクトル\(h\)の分布が多変量正規分布(Multivariate normal distribution)に従うと仮定します。多変量正規分布は曲線なので、2つの多変量正規分布を求めると、その分布間のFréchet Distanceを計算できます。平均ベクトルと共分散行列が分かっている多変量正規分布間のFréchet Distanceは[5]で計算できます。

具体的な計算方法は[1]の著者らが[6]で公開しています。TensorFlowベースです。画像の集合を\(A\)、その要素を\(a \in A\)とし、Inceptionモデルの途中の層まで計算する関数を\(f_{\rm inception} : A \rightarrow H\)とします。\(H\)は\(h\)の集合です。まず、平均ベクトル\(\mu\)と共分散行列\(\Sigma\)を計算します。\(H\)の各要素は\(f_{\rm inception}\)で計算済みであるとしています。 \[ \mu = \frac{1}{|A|} \sum_{h \in H} h \] \[ \Sigma = \frac{1}{|A|-1} \sum_{h \in H} (h-\mu)(h-\mu)^{T} \] \(H\)から推定した分布間の距離を計算するので、[6]に倣って、\(\Sigma\)を不偏共分散行列として計算しています。少々距離が短くなりますが、標本共分散行列で計算してもGANで生成した画像の評価指標として使う分には特に問題ないでしょう。

2つの画像集合\(A_1\)と\(A_2\)の距離を計算したいので、それぞれの平均ベクトルを\(\mu_1, \mu_2\)、共分散行列を\(\Sigma_1, \Sigma_2\)とすると、Fréchet Distanceは \[ d^2 = |\mu_1-\mu_2|^2 + {\rm tr}\left (\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{\frac{1}{2}}\right )\] で計算できます。

[1]のp.7 L.13-14によると、\(h\)には最後のプーリング層を使っているとのことです。これは、Inception-v3の論文[7]のTable 1の下から3行目のpoolのことを指しています。この層の出力は1x1x2048なので、\(h\)は2048次元のベクトルということになります。[8]にも書いてありますが、画像のサンプル数が2048個より多くないと\(d^2\)が計算できないので、注意が必要です(ちょっと試すだけでも2000枚強の画像が必要とは、なんて面倒な指標なんだ!)。

実験結果

MNISTとImageNetの一部の画像を使ってFIDを計算します。MNISTは、訓練用と検証用ともに先頭3000枚を利用します。ImageNetは、10クラスからランダムに2956枚選んだ画像と、6クラスから順に2956枚に達するまで選んだ画像を利用します。

10クラスは、具体的には{n02066245, n02096294, n02100735, n02119789, n02123394, n02124075, n02125311, n02417914, n02423022, n02509815}です。6クラスは{n02066245, n02096294, n02100735, n02119789, n02123394, n02124075}です。

計算結果は下表の通りです。

\(A_1\)\(A_2\)FID (\(d^2\))
MNIST trainMNIST train6.8959054997e-11
MNIST trainMNIST val7.05136659744
MNIST valMNIST train7.0513665973
MNIST trainImageNet 10 classes338.586062646
MNIST trainImageNet 6 classes346.218188602
ImageNet 10 classesImageNet 6 classes67.1025959331
同じ画像集合間の距離は計算誤差を無視すると0になっています(1行目)。MNISTのtrainとvalの距離は7.05です(2, 3行目)。MNISTとImageNetの距離はMNIST間の距離より遠く、約340です(4, 5行目)。クラス数が異なるImageNetの画像集合間の距離は67.1で、MNISTのtrainとvalの間の距離より遠くなっています(6行目)。

距離の大きさはともかく、その大小関係は期待通りになっています。具体的には、

  • 同じ集合間の距離は0。
  • クラスが同じ画像集合間の距離は、クラスが多少異なる画像集合間より近い。
  • 全く異なる画像集合間の距離は、クラスが多少異なる画像集合間の距離より遠い。
ということです。

コード

今回の実験に使ったコードは以下です。Kerasを使っています。実験の都合上、計算した\(H\)は一旦ファイルに保存するようにしています。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# -*- coding: utf-8 -*-
import os, glob
import glob
import numpy as np
from keras.applications.inception_v3 import InceptionV3, preprocess_input
from keras.applications.imagenet_utils import decode_predictions
from keras.preprocessing import image
from keras.datasets import mnist
from keras.models import Model
from PIL import Image as pil_image
from scipy.linalg import sqrtm

model = InceptionV3() # Load a model and its weights
model4fid = Model(inputs=model.input, outputs=model.get_layer("avg_pool").output)
def resize_mnist(x):
    x_list = []
    for i in range(x.shape[0]):
        img = image.array_to_img(x[i, :, :, :].reshape(28, 28, -1))
        #img.save("mnist-{0:03d}.png".format(i))
        img = img.resize(size=(299, 299), resample=pil_image.LANCZOS)
        x_list.append(image.img_to_array(img))
    return np.array(x_list)

def resize_do_nothing(x):
    return x

def frechet_distance(m1, c1, m2, c2):
    return np.sum((m1 - m2)**2) + np.trace(c1 + c2 - 2*(sqrtm(np.dot(c1, c2))))

def mean_cov(x):
    mean = np.mean(x, axis=0)
    sigma = np.cov(x, rowvar=False)
    return mean, sigma

def fid(h1, h2):
    m1, c1 = mean_cov(h1)
    m2, c2 = mean_cov(h2)
    return frechet_distance(m1, c1, m2, c2)

def calc_h(x, resizer, batch_size=8):
    r = None
    n_batch = (x.shape[0]+batch_size-1) // batch_size
    for j in range(n_batch):
        x_batch = resizer(x[j*batch_size:(j+1)*batch_size, :, :, :])
        r_batch = model4fid.predict(preprocess_input(x_batch))
        r = r_batch if r is None else np.concatenate([r, r_batch], axis=0)
        if j % 10 == 0:
            print("i =", j)
    return r

def mnist_h(n_train, n_val):
    x = [0, 0]; h = [0, 0]; n = [n_train, n_val]
    (x[0], _), (x[1], _) = mnist.load_data()
    for i in range(2):
        x[i] = np.expand_dims(x[i], axis=3) # shape=(60000, 28, 28) --> (60000, 28, 28, 1)
        x[i] = np.tile(x[i], (1, 1, 1, 3)) # shape=(60000, 28, 28, 1) --> (60000, 28, 28, 3)
        h[i] = calc_h(x[i][0:n[i], :, :, :], resize_mnist)
    return h[0], h[1]

def imagenet_h(files, batch_size=8):
    xs = []; hs = []
    for f in files:
        img = image.load_img(f, target_size=(299, 299))
        x = image.img_to_array(img) # x.shape=(299, 299, 3)
        xs.append(x)
        if len(xs) == batch_size:
            hs.append(calc_h(np.array(xs), resize_do_nothing))
            xs = []
    if len(xs) > 0:
        hs.append(calc_h(np.array(xs), resize_do_nothing))
    return np.concatenate(hs, axis=0)

# Calculate and save H of MNIST
h_train, h_val = mnist_h(3000, 3000)
np.save("mnist_h_train.npy", h_train)
np.save("mnist_h_val.npy", h_val)

# Calculate and save H of the part of Imagenet 
h_imagenet = imagenet_h(glob.glob("from_imagenet/*.jpg")) # 10 classes
h_imagenet_seq = imagenet_h(sorted(glob.glob("from_imagenet_seq/*.jpg"))[0:2956]) # 6 classes
np.save("imagenet_h.npy", h_imagenet)
np.save("imagenet_h_seq.npy", h_imagenet_seq)

# Load H and calculate FID
h_train = np.load("mnist_h_train.npy")
h_val = np.load("mnist_h_val.npy")
h_imagenet = np.load("imagenet_h.npy")
h_imagenet_seq = np.load("imagenet_h_seq.npy")
print("FID between MNIST train and val :", fid(h_train, h_val))
print("FID between MNIST val and train :", fid(h_val, h_train))
print("FID between MNIST train and train :", fid(h_train, h_train))
print("FID between MNIST train and imagenet :", fid(h_train, h_imagenet))
print("FID between MNIST train and imagenet_seq :", fid(h_train, h_imagenet_seq))
print("FID between imagenet and imagenet_seq :", fid(h_imagenet, h_imagenet_seq))

参考

[1] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Sepp Hochreiter, "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium," arXiv:1706.08500, 2017, https://arxiv.org/abs/1706.08500
[2] Mario Lucic, Karol Kurach, Marcin Michalski, Sylvain Gelly, Olivier Bousquet, "Are GANs Created Equal? A Large-Scale Study," arXiv:1711.10337, 2017, https://arxiv.org/abs/1711.10337
[3] Fréchet, M. "Sur la distance de deux lois de probabilité," C. R. Acad. Sci. Paris, 244, 689-692, 1957 (内容を確認したわけではない)
[4] http://www.thothchildren.com/chapter/59b4f81975704408bd430061 (Fréchet Distanceの解説記事)
[5] D. C. Dowson and B. V. Landau, "The Fréchet Distance between Multivariate Normal Distributions," Journal of multivariate analysis, 12, 450-455, 1982, http://www.sciencedirect.com/science/article/pii/0047259X8290077X
[6] https://github.com/bioinf-jku/TTUR/blob/master/fid.py ([1]の著者らが作成したコード)
[7] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna, "Rethinking the Inception Architecture for Computer Vision," arXiv:1512.00567, 2015, https://arxiv.org/abs/1512.00567

0 件のコメント :