はじめに
Fréchet Inception Distance (FID)と呼ばれる、Generative Adversarial Network (GAN)が生成する画像の品質を評価する指標を試してみます[1]。この指標は、画像の集合間の距離を表します。前回試したInception Scoreは画像の集合そのものの良さを表すスコアでしたので1つ画像の集合を与えるだけで計算できましたが、FIDはそのようには計算できません。GANで再現したい真の分布から生成された画像の集合と、GANで再現した分布から生成した画像の集合との距離を計算することになります。距離が近ければ近いほど良い画像であると判断します。FIDは、Google Brainが実施したGANの大規模評価の評価指標にも用いられています[2]。
計算方法
FIDは、Inceptionモデルの途中の層の出力から得られるベクトルhを使ってFréchet Distance [3, 4]を計算することで求められます。Fréchet Distanceは曲線同士の距離のため、hのままでは距離を計算できません。そこで、画像から得られるベクトルhの分布が多変量正規分布(Multivariate normal distribution)に従うと仮定します。多変量正規分布は曲線なので、2つの多変量正規分布を求めると、その分布間のFréchet Distanceを計算できます。平均ベクトルと共分散行列が分かっている多変量正規分布間のFréchet Distanceは[5]で計算できます。
具体的な計算方法は[1]の著者らが[6]で公開しています。TensorFlowベースです。画像の集合をA、その要素をa∈Aとし、Inceptionモデルの途中の層まで計算する関数をfinception:A→Hとします。Hはhの集合です。まず、平均ベクトルμと共分散行列Σを計算します。Hの各要素はfinceptionで計算済みであるとしています。 μ=1|A|∑h∈Hh Σ=1|A|−1∑h∈H(h−μ)(h−μ)T Hから推定した分布間の距離を計算するので、[6]に倣って、Σを不偏共分散行列として計算しています。少々距離が短くなりますが、標本共分散行列で計算してもGANで生成した画像の評価指標として使う分には特に問題ないでしょう。
2つの画像集合A1とA2の距離を計算したいので、それぞれの平均ベクトルをμ1,μ2、共分散行列をΣ1,Σ2とすると、Fréchet Distanceは d2=|μ1−μ2|2+tr(Σ1+Σ2−2(Σ1Σ2)12) で計算できます。
[1]のp.7 L.13-14によると、hには最後のプーリング層を使っているとのことです。これは、Inception-v3の論文[7]のTable 1の下から3行目のpoolのことを指しています。この層の出力は1x1x2048なので、hは2048次元のベクトルということになります。[8]にも書いてありますが、画像のサンプル数が2048個より多くないとd2が計算できないので、注意が必要です(ちょっと試すだけでも2000枚強の画像が必要とは、なんて面倒な指標なんだ!)。
実験結果
MNISTとImageNetの一部の画像を使ってFIDを計算します。MNISTは、訓練用と検証用ともに先頭3000枚を利用します。ImageNetは、10クラスからランダムに2956枚選んだ画像と、6クラスから順に2956枚に達するまで選んだ画像を利用します。
10クラスは、具体的には{n02066245, n02096294, n02100735, n02119789, n02123394, n02124075, n02125311, n02417914, n02423022, n02509815}です。6クラスは{n02066245, n02096294, n02100735, n02119789, n02123394, n02124075}です。
計算結果は下表の通りです。
A1 | A2 | FID (d2) |
---|---|---|
MNIST train | MNIST train | 6.8959054997e-11 |
MNIST train | MNIST val | 7.05136659744 |
MNIST val | MNIST train | 7.0513665973 |
MNIST train | ImageNet 10 classes | 338.586062646 |
MNIST train | ImageNet 6 classes | 346.218188602 |
ImageNet 10 classes | ImageNet 6 classes | 67.1025959331 |
距離の大きさはともかく、その大小関係は期待通りになっています。具体的には、
- 同じ集合間の距離は0。
- クラスが同じ画像集合間の距離は、クラスが多少異なる画像集合間より近い。
- 全く異なる画像集合間の距離は、クラスが多少異なる画像集合間の距離より遠い。
コード
今回の実験に使ったコードは以下です。Kerasを使っています。実験の都合上、計算したHは一旦ファイルに保存するようにしています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | # -*- coding: utf-8 -*-
import os, glob
import glob
import numpy as np
from keras.applications.inception_v3 import InceptionV3, preprocess_input
from keras.applications.imagenet_utils import decode_predictions
from keras.preprocessing import image
from keras.datasets import mnist
from keras.models import Model
from PIL import Image as pil_image
from scipy.linalg import sqrtm
model = InceptionV3() # Load a model and its weights
model4fid = Model(inputs=model.input, outputs=model.get_layer("avg_pool").output)
def resize_mnist(x):
x_list = []
for i in range(x.shape[0]):
img = image.array_to_img(x[i, :, :, :].reshape(28, 28, -1))
#img.save("mnist-{0:03d}.png".format(i))
img = img.resize(size=(299, 299), resample=pil_image.LANCZOS)
x_list.append(image.img_to_array(img))
return np.array(x_list)
def resize_do_nothing(x):
return x
def frechet_distance(m1, c1, m2, c2):
return np.sum((m1 - m2)**2) + np.trace(c1 + c2 - 2*(sqrtm(np.dot(c1, c2))))
def mean_cov(x):
mean = np.mean(x, axis=0)
sigma = np.cov(x, rowvar=False)
return mean, sigma
def fid(h1, h2):
m1, c1 = mean_cov(h1)
m2, c2 = mean_cov(h2)
return frechet_distance(m1, c1, m2, c2)
def calc_h(x, resizer, batch_size=8):
r = None
n_batch = (x.shape[0]+batch_size-1) // batch_size
for j in range(n_batch):
x_batch = resizer(x[j*batch_size:(j+1)*batch_size, :, :, :])
r_batch = model4fid.predict(preprocess_input(x_batch))
r = r_batch if r is None else np.concatenate([r, r_batch], axis=0)
if j % 10 == 0:
print("i =", j)
return r
def mnist_h(n_train, n_val):
x = [0, 0]; h = [0, 0]; n = [n_train, n_val]
(x[0], _), (x[1], _) = mnist.load_data()
for i in range(2):
x[i] = np.expand_dims(x[i], axis=3) # shape=(60000, 28, 28) --> (60000, 28, 28, 1)
x[i] = np.tile(x[i], (1, 1, 1, 3)) # shape=(60000, 28, 28, 1) --> (60000, 28, 28, 3)
h[i] = calc_h(x[i][0:n[i], :, :, :], resize_mnist)
return h[0], h[1]
def imagenet_h(files, batch_size=8):
xs = []; hs = []
for f in files:
img = image.load_img(f, target_size=(299, 299))
x = image.img_to_array(img) # x.shape=(299, 299, 3)
xs.append(x)
if len(xs) == batch_size:
hs.append(calc_h(np.array(xs), resize_do_nothing))
xs = []
if len(xs) > 0:
hs.append(calc_h(np.array(xs), resize_do_nothing))
return np.concatenate(hs, axis=0)
# Calculate and save H of MNIST
h_train, h_val = mnist_h(3000, 3000)
np.save("mnist_h_train.npy", h_train)
np.save("mnist_h_val.npy", h_val)
# Calculate and save H of the part of Imagenet
h_imagenet = imagenet_h(glob.glob("from_imagenet/*.jpg")) # 10 classes
h_imagenet_seq = imagenet_h(sorted(glob.glob("from_imagenet_seq/*.jpg"))[0:2956]) # 6 classes
np.save("imagenet_h.npy", h_imagenet)
np.save("imagenet_h_seq.npy", h_imagenet_seq)
# Load H and calculate FID
h_train = np.load("mnist_h_train.npy")
h_val = np.load("mnist_h_val.npy")
h_imagenet = np.load("imagenet_h.npy")
h_imagenet_seq = np.load("imagenet_h_seq.npy")
print("FID between MNIST train and val :", fid(h_train, h_val))
print("FID between MNIST val and train :", fid(h_val, h_train))
print("FID between MNIST train and train :", fid(h_train, h_train))
print("FID between MNIST train and imagenet :", fid(h_train, h_imagenet))
print("FID between MNIST train and imagenet_seq :", fid(h_train, h_imagenet_seq))
print("FID between imagenet and imagenet_seq :", fid(h_imagenet, h_imagenet_seq))
|
参考
[1] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Sepp Hochreiter, "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium," arXiv:1706.08500, 2017, https://arxiv.org/abs/1706.08500
[2] Mario Lucic, Karol Kurach, Marcin Michalski, Sylvain Gelly, Olivier Bousquet, "Are GANs Created Equal? A Large-Scale Study," arXiv:1711.10337, 2017, https://arxiv.org/abs/1711.10337
[3] Fréchet, M. "Sur la distance de deux lois de probabilité," C. R. Acad. Sci. Paris, 244, 689-692, 1957 (内容を確認したわけではない)
[4] http://www.thothchildren.com/chapter/59b4f81975704408bd430061 (Fréchet Distanceの解説記事)
[5] D. C. Dowson and B. V. Landau, "The Fréchet Distance between Multivariate Normal Distributions," Journal of multivariate analysis, 12, 450-455, 1982, http://www.sciencedirect.com/science/article/pii/0047259X8290077X
[6] https://github.com/bioinf-jku/TTUR/blob/master/fid.py ([1]の著者らが作成したコード)
[7] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna, "Rethinking the Inception Architecture for Computer Vision," arXiv:1512.00567, 2015, https://arxiv.org/abs/1512.00567