はじめに
MNISTデータセットを使って拡散モデルの学習と、学習したモデルを使った画像を生成をしていきます。
これを実現するコードがすでに
https://github.com/MarceloGennari/diffusion_mnist
で公開されていましたので、こちらを利用して試します。
なお、拡散モデルについての説明は検索するとたくさん出てきますので理論的背景については論文や解説記事を参照ください。
本記事では具体的に何をすれば拡散モデルを動かせるのかを見ていきます。
元論文はarXiv:2006.11239です。
まずは動かす
ソースコードを
git cloneで取得します。
ここでは、コミットハッシュが
df15ee746aのものを使用します。
README.mdを読んで、必要なモジュールをpipでインストールしておきます。
この時点では学習も生成もCPUで実行するように設定されているため、GPUで処理するように変更した後に実行します。
学習
学習は
main.pyで実行できるのですが、この中の
device = "cpu"
と書かれている行を
device = "cuda"
に書き換えます。そして
$ python main.py
を実行すると、学習が始まります。しばらく待っていると完了し、モデルのパラメータが
unet_mnist.pthに記録されます。
生成
こちらもGPUで動くように変更します。また、学習はConditionalUNetで行われるものの生成はUNetになっているため、その点も修正します。
修正するコードはinference_unet.pyです。
先ほどと同じように
- device = "cpu"
+ device = "cuda"
に書き換えます。-が書き換え対象の行で、+が書き換えたあとの行の内容です。さらに、
- from models import UNet
+ from models import UNet, ConditionalUNet
と書き換え、
- model = UNet().to(device)
+ model = ConditionalUNet().to(device)
と書き換えます。ConditionalUNetにすると、生成時にどの数字を生成するかを指定する必要があるため、
model.eval()
+ labels = torch.randint(0, 10, (batch_size,))
+ print(labels)
+ labels_cpu = labels
+ labels = labels.to(device)
with torch.no_grad():
for t in trange(999, -1, -1):
time = torch.ones(batch_size) * t
- et = model(xt.to(device), time.to(device)) # predict noise
+ et = model(xt.to(device), time.to(device), labels) # predict noise
xt = process.inverse(xt, et.cpu(), t)
のように、書き換えます。ランダムに0〜9の値をラベルとして指定するようにしています。
出力部分を少し書き換えて
labels = ["Generated Images"] * 9
-
- for i in range(9):
- plt.subplot(3, 3, i + 1)
- plt.tight_layout()
- plt.imshow(xt[i][0], cmap="gray", interpolation="none")
- plt.title(labels[i])
- plt.show()
+ if t % 10 == 0:
+ plt.figure(figsize=(10, 10))
+ for i in range(25):
+ plt.subplot(5, 5, i + 1)
+ plt.tight_layout()
+ plt.imshow(xt[i][0], cmap="gray", interpolation="none")
+ plt.title(f"{labels_cpu[i]}")
+ plt.savefig(f"images/generated_t{t}.png")
+ #plt.show()
とすると、逆拡散過程によって少しずつ数字が画像として浮かび上がるところを見ることができます。
ただし、このようにすると、生成した数字のpyplotでの描画処理のために生成時間が延びます。GPU使用率もあきらかに低下します。
途中結果を見る必要がなければ最初のコードのほうが良いでしょう。
生成結果
結果は次のようになります。ただし、乱数がいろいろなところで使われており、シードも固定されていないので、毎回結果は異なります。
[t=990] ノイズだらけで何も読み取れません。
[t=700] そこはかとなく数字があるように見えるような見えないような。

[t=500] 遠くからみれば(ぼかしてみれば)、数字が簡単に読み取れます。

[t=300] ざらざらしていますが、十分に読めるようになりました。

[t=0] 完全にノイズが取り除かれました。アンチエイリアスは残ったままです。

学習時の処理
学習時の処理がどのように実装されているのかを見ていきます。
上位ループ
モデルの学習をする
main.pyの主要な処理である学習のループ部分を抜き出して、疑似コードとして書き換えると、
for epoch in range(100):
for image, label in 学習用画像とそのラベルの集合:
# 画像に加えるノイズの強さをランダムに選択
t = torch.randint(0, 1000, (image.shape[0],))
# 画像に加えるノイズを作成
epsilon = torch.randn(image.shape)
# tとepsilonに基づいてノイズをimageに加える
diffused_image = process.forward(image, t, epsilon)
# modelを使って加えられたノイズを予測
optimizer.zero_grad()
output = model(diffused_image, t, label)
# 予測したノイズがどれだけ正しいかを評価して、モデルの重みを更新
loss = criterion(epsilon, output) # criterion = torch.nn.MSELoss()
loss.backward()
optimizer.step()
となります。ノイズが加えられた画像から、加えられたノイズを推定するモデルを学習しているだけです。
それ以外は通常のニューラルネットワークの学習と変わりがありません。
そしてこれは元論文のAlgorithm 1の通りの実装です。5行目の
∇θ‖ϵ−ϵθ(√¯αtx0+√1−ˉαtϵ,t)‖2
と見比べると、
ϵはコード上の
epsilonに、
ϵθはコード上の
modelに、
√¯αtx0+√1−ˉαtϵは
diffused_imageに相当していることが分かります。
Algorithm 1とは異なり、modelの引数にlabelが余計についていますが、これは生成する数字が0〜9のどれであるかをニューラルネットワークに指示するための値となります。
ノイズ付加処理
上位ループ内の
diffused_image = process.forward(image, t, epsilon)
の部分について、詳細を見ていきます。
まず、diffusion_model.pyのDiffusionProcessの初期化処理部分である__init__にてαtの値が計算されています。
最初にβt (コード上ではvariance_schedule) を
self.variance_schedule = torch.linspace(1e-4, 0.01, steps=1000)
のように計算しています。具体的には
self.variance_schedule = [1.0000e-04, 1.0991e-04, 1.1982e-04, ... , 0.009980, 0.009990, 0.01]
となっています。元論文では
β1=104、
βT=0.02、
T=1000と書かれているので
βTの値のみ異なっています。
αt=1−βtであるので、
self.alpha = 1 - self.variance_schedule
にて計算され、具体的な値は
self.alpha = [0.999900, 0.999890, 0.999880 , ... , 0.990020, 0.990010, 0.990000]
となります。さらに、
ˉαt=∏ts=1αsであるので、
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
にて計算され、具体的な値は
self.alpha_bar = [0.999900, 0.999790, 0.999670, ... , 0.006430, 0.006365, 0.006302]
となります。
これらの値を使って、forwardにてノイズを画像に付加します。
√¯αtx0+√1−ˉαtϵが計算できれば良いので、まず√1−ˉαtを
std_dev = torch.sqrt(1 - self.alpha_bar[time_step])
で計算します。次に
√¯αtを
mean_multiplier = torch.sqrt(self.alpha_bar[time_step])
で計算し、最後に
diffused_images = mean_multiplier * x_0 + std_dev * noise
のように足し合わせます。
x_0が
x0で、
noiseが
ϵです。
ノイズを予測するモデル
ノイズを予測するモデルには小さめのUNetが使われています。コードからグラフに書き起こすと下図のようになります。
背景が水色のボックス内にはクラス名とインスタンス化時の引数を記載しています(画像をクリックすると拡大できます)。
降りる方向と登る方向では条件付けの場所が異なっていますが、このあたりはおおまかであっても十分に動くのでしょう。
以下では、Pytorchで提供されているクラス以外のものを見ていきます。
ResConvGroupNorm
このクラスは以下のような構成になっています。残差接続のある、よく見かける畳み込み演算を使ったブロックです。
LabelEmbedding
このクラスは、0〜9のラベルをベクトルに変換し、入力値
xに埋め込む処理をします。画像のチャネル方向に埋め込むので、各チャネルの画像特徴が埋め込みによって全体的に明るくなったり暗くなったりすることになります。
TemporalEmbedding
このクラスは、加えるノイズの強さを表す時刻
tを入力値
xに埋め込む処理をします。こちらも画像のチャネル方向に埋め込むので、各チャネルの画像特徴が埋め込みによって全体的に明るくなったり暗くなったりすることになります。LabelEmbeddingと干渉しそうですが、モデルの学習時にうまく棲み分けしているのでしょう。
SinusoidalPositionEmbeddingsは
arXiv:1706.03762の
PE(pos,2i)=sin(pos/100002i/dmodel)
PE(pos,2i+1)=cos(pos/100002i/dmodel)
とほぼ同じ計算をしています。
sinと
cosの引数部分のみ取り出すと、
pos/100002i/dmodel
となり、
posを時刻
t、
dmodelを単に
dと書き換えると、
t/100002i/d
となります。iは次元方向のインデックスです。もう少し書き換えて、
t/10000i/(d/2)
i/(d/2)の範囲が[0,1]になるように、少し値は変わりますが、
t/10000i/(d/2−1)
とします。さらに、
texp(log10000i/(d/2−1))=t⋅exp(−log10000i/(d/2−1))=t⋅exp(−id2−1log10000)
と変形します。ここまで変形するとコードとの対応がとれ、
half_dim = self.dim // 2
で
d/2を計算し、
embeddings = math.log(10000) / (half_dim - 1)
で、
1d2−1log10000を計算し、続いて
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
で、
exp(−id2−1log10000)を計算していることが分かります。
なお、変形前の1/10000i/(d/2−1)にしたがって
1/torch.pow(10000, torch.arange(half_dim, device=device)/(half_dim-1))
で計算しても、結果はほとんど変わりません。具体的には、最初の5つの値を出力すると、
式変形前 [1.0000000000, 0.9821373820, 0.9645937681, 0.9473634958, 0.9304410219]
式変形後 [1.0000000000, 0.9821373224, 0.9645937085, 0.9473634958, 0.9304410219]
となるので、どちらで計算してもよさそうです。
さて、時刻tをかけると
embeddings = time[:, None] * embeddings[None, :]
となり、最後に、
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
のように
sinと
cosを通し、それらを連結することで、埋め込みベクトルが作成できます。
LinearAttention
attention1のクラスであるLinearAttentionの実装のアイデアの元と思われる論文は
のあたりのようですが、ぴったり当てはまるものを見つけることはできませんでした。
注意機構部分のみをコードから計算式に戻すと、
dt′f′=1√3217×7∑fexp(qt′f)∑f″exp(qt′f″)∑texp(ktf)∑t″exp(kt″f)vtf′
となっているようです。普通の注意機構とは異なり、最初にKの特徴次元数×Vの特徴次元数(今の場合は32×32)の行列を計算しているようです。
1√32の32はKの特徴次元数で、arXiv:1706.03762の式(1)である
Attention(Q,K,V)=softmax(QKT√dk)V
のdkに相当するようです。
生成時の処理
生成時の処理がどのように実装されているのかを見ていきます。
上位ループ
生成処理をしている
inference_unet.pyの主要部分を疑似コードとして書き換えると、
for t in trange(999, -1, -1):
# 時刻tとラベルlabelの条件の下、ノイズを推定する
et = model(xt, t, label)
# 推定したノイズを使って画像からノイズを除去する
xt = process.inverse(xt, et, t)
# xtを画像として描画する
draw(xt)
となります。モデルの学習時は時刻
tをランダムに選んでいましたが、
生成するときはノイズから徐々にノイズを取り除いていきます。t=999ではノイズのみ、t=0ではノイズがなくなった画像になります。
ノイズ除去処理
上位ループ内の
xt = process.inverse(xt, et, t)
の部分について詳細を見ていきます。
元論文arXiv:2006.11239のAlgorithm 2の4行目の処理
xt−1=1√αt(xt−1−αt√1−ˉαtϵθ(xt,t))+σtz
が実装されています。DiffusionProcessのinverseを見ていくと、
scale = 1 / torch.sqrt(self.alpha[t])
では
1√αtの部分が計算されています。
noise_scale = (1 - self.alpha[t]) / torch.sqrt(1 - self.alpha_bar[t])
では
1−αt√1−ˉαtの部分が計算されています。
std_dev = torch.sqrt(self.variance_schedule[t])
では
σtが計算されています。
z = torch.randn(xt.shape) if t > 1 else torch.Tensor([0])
では
z∼N(0,I)が計算されています。ガウス分布に従うノイズを作っているだけです。
最後にすべてをつなげて
mu_t = scale * (xt - noise_scale * et)
xt = mu_t + std_dev * z # remove noise from image
を計算することで、Algorithm 2の4行目の処理を実現しています。