2023/10/26

Diffusion MNIST その3

はじめに


その1で試したDiffusion MNISTについて、 ノイズの強さを表す時刻\(t\)をニューラルネットワークに伝えないとどうなるのかを見ていきます。

方法


https://github.com/MarceloGennari/diffusion_mnist のConditionalUNetの\(t\)が関連する行、つまり、TemporalEmbedding部分をコメントアウトします。具体的には
class ConditionalUNet(UNet):
    (省略)
    def forward(self, x: Tensor, t: Tensor, label: Tensor) -> Tensor:
        x0 = x #self.embedding1(x, t)
        x1 = self.block1(x0)
        x1 = self.label_emb1(x1, label)
        #x1 = self.embedding2(x1, t)
        x2 = self.block2(self.down1(x1))
        x2 = self.label_emb2(x2, label)
        #x2 = self.embedding3(x2, t)
        crossed = self.label_emb3(self.block3(self.down2(x2)), label)
        x3 = self.up1(self.attention1(crossed))
        x4 = torch.cat([x2, x3], dim=1)
        #x4 = self.embedding4(x4, t)
        x5 = self.up2(self.label_emb4(self.block4(x4), label))
        x6 = torch.cat([x5, x1], dim=1)
        x6 = self.label_emb5(x6, label)
        #x6 = self.embedding5(x6, t)
        out = self.out(self.block5(x6))
        return out
とします。

結果


その2で試した時間刻みを100にしたバージョンをベースに比較します。左側がTemporalEmbeddingありで、右側がなしに対応します。
t=50 TemporalEmbeddingあり
t=50 TemporalEmbeddingなし

t=0 TemporalEmbeddingあり
t=0 TemporalEmbeddingなし

TemporalEmbeddingなしの場合はノイズが多いように見えるので、画像の明るさをGIMPを使って上げたものが下図です。TemporalEmbeddingありではノイズが見えませんが、TemporalEmbeddingなしではノイズがはっきり見えるケースが多くなっています。
t=0 TemporalEmbeddingあり
t=0 TemporalEmbeddingなし

まとめ


時刻の埋め込みは効果があるということを確認できました。

2023/10/22

Diffusion MNIST その2

はじめに


その1で試したDiffusion MNISTについて、 ノイズを乗せるステップの細かさを粗くするとどうなるのかを見てみます。

方法


https://github.com/MarceloGennari/diffusion_mnist をいくつか変更することで粗さを変えていきます。

スケジュール変更


DiffusionProcessの引数に渡すvariance_scheduleを変えていきます。 デフォルトでは、
variance_schedule = torch.linspace(1e-4, 0.01, steps=1000)
となっています。これをパターンAでは
variance_schedule = torch.linspace(1e-4, 0.1, steps=100)
と、パターンBでは
variance_schedule = torch.linspace(1e-4, 0.999, steps=10)
とします。

それぞれのスケジュールを使ったときのalphaは

[パターン デフォルト]
[0.99990, 0.99989, 0.99988, ... , 0.99001, 0.99000]

[パターン A]
[0.99990, 0.99889, 0.99788, ... , 0.90101, 0.90000]

[パターン B]
[0.99990, 0.88891, 0.77792, ... , 0.11199, 0.00100]
となります。

alpha_barは

[パターン デフォルト]
[0.9999, 0.9998, 0.9997, ... , 0.0064, 0.0063]

[パターン A]
[0.9999, 0.9988, 0.9967, ... , 0.0062, 0.0056]

[パターン B]
[9.9990e-01, 8.8882e-01, 6.9143e-01, ... , 9.5131e-04, 9.5130e-07]
となります。ここで重要なことは、最初の時刻(ノイズが乗っていない)をt=0、最後の時刻(完全にノイズ)をt=1とするとき、alpha_barはt=0では1に近く、t=1では0に近くなるようにvariance_scheduleを決める必要があるということです。 各時刻tにおけるノイズの強さがalpha_barで決まり、t=1のときに完全にノイズになっていないと拡散プロセスの前提が崩れてしまうためです。

実際、パターンAの

variance_schedule = torch.linspace(1e-4, 0.1, steps=100)
variance_schedule = torch.linspace(1e-4, 0.01, steps=100)
に変えると、alpha_bar の値は
0.99990, 0.99970, 0.99940, ... , 0.60857, 0.60248
となりますが、この場合、数字の画像をうまく生成できません。

学習時のtの値


デフォルトではmain.py
t = torch.randint(0, 1000, (image.shape[0],))
の1000のところを、パターンAでは100に、パターンBでは10にします。

生成時のtの値


デフォルトではinference_unet.py
for t in trange(999, -1, -1):
の999のところを、パターンAでは99に、パターンBでは9にします。刻む数が少なくなると(ステップの細かさを粗くすると)、その分だけ生成時間を短くできます。

結果


デフォルトの設定ではこのようになります(その1の再掲)
t=500
t=0

時刻を100個に刻んだパターンAでも特に変わりなく生成できています。
t=50
t=0

時刻を10個に刻んだパターンBだと、多少ノイズが残ってしまいますが、生成できないというほどではありません。
t=5
t=0

まとめ


デフォルトの1000ステップではなくても、MNIST程度なら生成できることが分かりました。

2023/10/15

Diffusion MNIST

はじめに


MNISTデータセットを使って拡散モデルの学習と、学習したモデルを使った画像を生成をしていきます。

これを実現するコードがすでに
https://github.com/MarceloGennari/diffusion_mnist
で公開されていましたので、こちらを利用して試します。

なお、拡散モデルについての説明は検索するとたくさん出てきますので理論的背景については論文や解説記事を参照ください。 本記事では具体的に何をすれば拡散モデルを動かせるのかを見ていきます。

元論文はarXiv:2006.11239です。

まずは動かす


ソースコードをgit cloneで取得します。 ここでは、コミットハッシュがdf15ee746aのものを使用します。 README.mdを読んで、必要なモジュールをpipでインストールしておきます。

この時点では学習も生成もCPUで実行するように設定されているため、GPUで処理するように変更した後に実行します。

学習


学習はmain.pyで実行できるのですが、この中の
device = "cpu"
と書かれている行を
device = "cuda"
に書き換えます。そして
$ python main.py
を実行すると、学習が始まります。しばらく待っていると完了し、モデルのパラメータがunet_mnist.pthに記録されます。

生成


こちらもGPUで動くように変更します。また、学習はConditionalUNetで行われるものの生成はUNetになっているため、その点も修正します。

修正するコードはinference_unet.pyです。 先ほどと同じように

- device = "cpu"
+ device = "cuda"
に書き換えます。-が書き換え対象の行で、+が書き換えたあとの行の内容です。さらに、
- from models import UNet
+ from models import UNet, ConditionalUNet
と書き換え、
- model = UNet().to(device)
+ model = ConditionalUNet().to(device)
と書き換えます。ConditionalUNetにすると、生成時にどの数字を生成するかを指定する必要があるため、
     model.eval()
+    labels = torch.randint(0, 10, (batch_size,))
+    print(labels)
+    labels_cpu = labels
+    labels = labels.to(device)
     with torch.no_grad():
         for t in trange(999, -1, -1):
             time = torch.ones(batch_size) * t
-            et = model(xt.to(device), time.to(device))  # predict noise
+            et = model(xt.to(device), time.to(device), labels)  # predict noise
             xt = process.inverse(xt, et.cpu(), t)
のように、書き換えます。ランダムに0〜9の値をラベルとして指定するようにしています。

出力部分を少し書き換えて

    labels = ["Generated Images"] * 9
-
-    for i in range(9):
-        plt.subplot(3, 3, i + 1)
-        plt.tight_layout()
-        plt.imshow(xt[i][0], cmap="gray", interpolation="none")
-        plt.title(labels[i])
-    plt.show()
+            if t % 10 == 0:
+                plt.figure(figsize=(10, 10))
+                for i in range(25):
+                    plt.subplot(5, 5, i + 1)
+                    plt.tight_layout()
+                    plt.imshow(xt[i][0], cmap="gray", interpolation="none")
+                    plt.title(f"{labels_cpu[i]}")
+                plt.savefig(f"images/generated_t{t}.png")
+    #plt.show()
とすると、逆拡散過程によって少しずつ数字が画像として浮かび上がるところを見ることができます。 ただし、このようにすると、生成した数字のpyplotでの描画処理のために生成時間が延びます。GPU使用率もあきらかに低下します。 途中結果を見る必要がなければ最初のコードのほうが良いでしょう。

生成結果


結果は次のようになります。ただし、乱数がいろいろなところで使われており、シードも固定されていないので、毎回結果は異なります。

[t=990] ノイズだらけで何も読み取れません。

[t=700] そこはかとなく数字があるように見えるような見えないような。

[t=500] 遠くからみれば(ぼかしてみれば)、数字が簡単に読み取れます。

[t=300] ざらざらしていますが、十分に読めるようになりました。

[t=0] 完全にノイズが取り除かれました。アンチエイリアスは残ったままです。

学習時の処理


学習時の処理がどのように実装されているのかを見ていきます。

上位ループ


モデルの学習をするmain.pyの主要な処理である学習のループ部分を抜き出して、疑似コードとして書き換えると、
for epoch in range(100):
    for image, label in 学習用画像とそのラベルの集合:
	# 画像に加えるノイズの強さをランダムに選択
        t = torch.randint(0, 1000, (image.shape[0],))
        
        # 画像に加えるノイズを作成
        epsilon = torch.randn(image.shape)
        
        # tとepsilonに基づいてノイズをimageに加える
        diffused_image = process.forward(image, t, epsilon)

        # modelを使って加えられたノイズを予測
        optimizer.zero_grad()
        output = model(diffused_image, t, label)
        
        # 予測したノイズがどれだけ正しいかを評価して、モデルの重みを更新
        loss = criterion(epsilon, output) # criterion = torch.nn.MSELoss()
        loss.backward()
        optimizer.step()
となります。ノイズが加えられた画像から、加えられたノイズを推定するモデルを学習しているだけです。 それ以外は通常のニューラルネットワークの学習と変わりがありません。

そしてこれは元論文のAlgorithm 1の通りの実装です。5行目の


\( \nabla_\theta \| \bm{\epsilon} - \bm{\epsilon}_\theta ( \sqrt{\bar{\alpha_t}}\bm{x}_0 + \sqrt{1-\bar{\alpha}_t}\bm{\epsilon}, t) \|^2 \)

と見比べると、\(\bm{\epsilon}\)はコード上のepsilonに、 \(\bm{\epsilon_{\theta}}\)はコード上のmodelに、 \(\sqrt{\bar{\alpha_t}}\bm{x}_0 + \sqrt{1-\bar{\alpha}_t}\bm{\epsilon}\)はdiffused_imageに相当していることが分かります。

Algorithm 1とは異なり、modelの引数にlabelが余計についていますが、これは生成する数字が0〜9のどれであるかをニューラルネットワークに指示するための値となります。

ノイズ付加処理


上位ループ内の
diffused_image = process.forward(image, t, epsilon)
の部分について、詳細を見ていきます。

まず、diffusion_model.pyDiffusionProcessの初期化処理部分である__init__にて\(\alpha_t\)の値が計算されています。

最初に\(\beta_t\) (コード上ではvariance_schedule) を

self.variance_schedule = torch.linspace(1e-4, 0.01, steps=1000)
のように計算しています。具体的には
self.variance_schedule = [1.0000e-04, 1.0991e-04, 1.1982e-04, ... , 0.009980, 0.009990, 0.01]
となっています。元論文では\(\beta_1=10^4\)、\(\beta_T=0.02\)、\(T=1000\)と書かれているので\(\beta_T\)の値のみ異なっています。

\(\alpha_t = 1-\beta_t\)であるので、

self.alpha = 1 - self.variance_schedule
にて計算され、具体的な値は
self.alpha = [0.999900, 0.999890, 0.999880 , ... , 0.990020, 0.990010, 0.990000]
となります。さらに、\(\bar{\alpha}_t=\prod_{s=1}^t \alpha_s\)であるので、
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
にて計算され、具体的な値は
self.alpha_bar = [0.999900, 0.999790, 0.999670, ... , 0.006430, 0.006365, 0.006302]
となります。

これらの値を使って、forwardにてノイズを画像に付加します。 \(\sqrt{\bar{\alpha_t}}\bm{x}_0 + \sqrt{1-\bar{\alpha}_t}\bm{\epsilon}\)が計算できれば良いので、まず\(\sqrt{1-\bar{\alpha}_t}\)を

std_dev = torch.sqrt(1 - self.alpha_bar[time_step])
で計算します。次に\(\sqrt{\bar{\alpha_t}}\)を
mean_multiplier = torch.sqrt(self.alpha_bar[time_step])
で計算し、最後に
diffused_images = mean_multiplier * x_0 + std_dev * noise
のように足し合わせます。x_0が\(\bm{x}_0\)で、noiseが\(\bm{\epsilon}\)です。

ノイズを予測するモデル


ノイズを予測するモデルには小さめのUNetが使われています。コードからグラフに書き起こすと下図のようになります。
背景が水色のボックス内にはクラス名とインスタンス化時の引数を記載しています(画像をクリックすると拡大できます)。

降りる方向と登る方向では条件付けの場所が異なっていますが、このあたりはおおまかであっても十分に動くのでしょう。

以下では、Pytorchで提供されているクラス以外のものを見ていきます。

ResConvGroupNorm


このクラスは以下のような構成になっています。残差接続のある、よく見かける畳み込み演算を使ったブロックです。

LabelEmbedding


このクラスは、0〜9のラベルをベクトルに変換し、入力値\(x\)に埋め込む処理をします。画像のチャネル方向に埋め込むので、各チャネルの画像特徴が埋め込みによって全体的に明るくなったり暗くなったりすることになります。

TemporalEmbedding


このクラスは、加えるノイズの強さを表す時刻\(t\)を入力値\(x\)に埋め込む処理をします。こちらも画像のチャネル方向に埋め込むので、各チャネルの画像特徴が埋め込みによって全体的に明るくなったり暗くなったりすることになります。LabelEmbeddingと干渉しそうですが、モデルの学習時にうまく棲み分けしているのでしょう。
SinusoidalPositionEmbeddingsはarXiv:1706.03762の \[ PE_{(pos,2i)} = sin(pos/10000^{2i/d_{model}}) \] \[ PE_{(pos,2i+1)} = cos(pos/10000^{2i/d_{model}}) \] とほぼ同じ計算をしています。\(sin\)と\(cos\)の引数部分のみ取り出すと、 \[ pos/10000^{2i/d_{model}} \] となり、\(pos\)を時刻\(t\)、\(d_{model}\)を単に\(d\)と書き換えると、 \[ t/10000^{2i/d} \] となります。iは次元方向のインデックスです。もう少し書き換えて、 \[ t/10000^{i/(d/2)} \] \(i/(d/2)\)の範囲が[0,1]になるように、少し値は変わりますが、 \[ t/10000^{i/(d/2-1)} \] とします。さらに、 \[ \begin{aligned} \frac{t}{\exp\left(\log 10000^{i/(d/2-1)}\right)} &= t \cdot \exp \left(-\log 10000^{i/(d/2-1)} \right) \\ &= t \cdot \exp \left(-\frac{i}{\frac{d}{2}-1} \log 10000 \right) \end{aligned} \] と変形します。ここまで変形するとコードとの対応がとれ、
half_dim = self.dim // 2
で\(d/2\)を計算し、
embeddings = math.log(10000) / (half_dim - 1)
で、\(\frac{1}{\frac{d}{2}-1} \log 10000\)を計算し、続いて
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
で、\(\exp \left(-\frac{i}{\frac{d}{2}-1} \log 10000 \right)\)を計算していることが分かります。

なお、変形前の\(1/10000^{i/(d/2-1)}\)にしたがって

1/torch.pow(10000, torch.arange(half_dim, device=device)/(half_dim-1))
で計算しても、結果はほとんど変わりません。具体的には、最初の5つの値を出力すると、
式変形前 [1.0000000000, 0.9821373820, 0.9645937681, 0.9473634958, 0.9304410219]
式変形後 [1.0000000000, 0.9821373224, 0.9645937085, 0.9473634958, 0.9304410219]
となるので、どちらで計算してもよさそうです。

さて、時刻\(t\)をかけると

embeddings = time[:, None] * embeddings[None, :]
となり、最後に、
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
のように\(sin\)と\(cos\)を通し、それらを連結することで、埋め込みベクトルが作成できます。

LinearAttention


attention1のクラスであるLinearAttentionの実装のアイデアの元と思われる論文は のあたりのようですが、ぴったり当てはまるものを見つけることはできませんでした。

注意機構部分のみをコードから計算式に戻すと、 \[ d_{t'f'} = \frac{1}{\sqrt{32}}\frac{1}{7 \times 7} \sum_f{\frac{\exp(q_{t'f})}{\sum_{f''}{\exp(q_{t'f''})}} \sum_t{\frac{\exp(k_{tf})}{\sum_{t''}\exp(k_{t''f})} v_{tf'}}} \] となっているようです。普通の注意機構とは異なり、最初に\(K\)の特徴次元数×\(V\)の特徴次元数(今の場合は32×32)の行列を計算しているようです。 \(\frac{1}{\sqrt{32}}\)の32は\(K\)の特徴次元数で、arXiv:1706.03762の式(1)である \[ Attention(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \] の\(d_k\)に相当するようです。

生成時の処理


生成時の処理がどのように実装されているのかを見ていきます。

上位ループ


生成処理をしているinference_unet.pyの主要部分を疑似コードとして書き換えると、
for t in trange(999, -1, -1):
    # 時刻tとラベルlabelの条件の下、ノイズを推定する
    et = model(xt, t, label)

    # 推定したノイズを使って画像からノイズを除去する
    xt = process.inverse(xt, et, t)

# xtを画像として描画する
draw(xt)
となります。モデルの学習時は時刻tをランダムに選んでいましたが、 生成するときはノイズから徐々にノイズを取り除いていきます。t=999ではノイズのみ、t=0ではノイズがなくなった画像になります。

ノイズ除去処理


上位ループ内の
xt = process.inverse(xt, et, t)
の部分について詳細を見ていきます。

元論文arXiv:2006.11239のAlgorithm 2の4行目の処理 \[ \bm{x}_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(\bm{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(\bm{x}_t,t)\right)+\sigma_t\bm{z} \] が実装されています。DiffusionProcessinverseを見ていくと、

scale = 1 / torch.sqrt(self.alpha[t])
では\(\frac{1}{\sqrt{\alpha_t}}\)の部分が計算されています。
noise_scale = (1 - self.alpha[t]) / torch.sqrt(1 - self.alpha_bar[t])
では\(\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\)の部分が計算されています。
std_dev = torch.sqrt(self.variance_schedule[t])
では\(\sigma_t\)が計算されています。
z = torch.randn(xt.shape) if t > 1 else torch.Tensor([0])
では\(\bm{z} \sim N(0,\bm{I})\)が計算されています。ガウス分布に従うノイズを作っているだけです。

最後にすべてをつなげて

mu_t = scale * (xt - noise_scale * et)
xt = mu_t + std_dev * z  # remove noise from image  
を計算することで、Algorithm 2の4行目の処理を実現しています。