はじめに
MNISTデータセットを使って拡散モデルの学習と、学習したモデルを使った画像を生成をしていきます。 これを実現するコードがすでに
https://github.com/MarceloGennari/diffusion_mnist
で公開されていましたので、こちらを利用して試します。 なお、拡散モデルについての説明は検索するとたくさん出てきますので理論的背景については論文や解説記事を参照ください。 本記事では具体的に何をすれば拡散モデルを動かせるのかを見ていきます。 元論文はarXiv:2006.11239です。
まずは動かす
ソースコードをgit cloneで取得します。 ここでは、コミットハッシュがdf15ee746aのものを使用します。 README.mdを読んで、必要なモジュールをpipでインストールしておきます。 この時点では学習も生成もCPUで実行するように設定されているため、GPUで処理するように変更した後に実行します。
学習
学習はmain.pyで実行できるのですが、この中の
device = "cpu"と書かれている行を
device = "cuda"に書き換えます。そして
$ python main.pyを実行すると、学習が始まります。しばらく待っていると完了し、モデルのパラメータがunet_mnist.pthに記録されます。
生成
こちらもGPUで動くように変更します。また、学習はConditionalUNetで行われるものの生成はUNetになっているため、その点も修正します。 修正するコードはinference_unet.pyです。 先ほどと同じように
- device = "cpu" + device = "cuda"に書き換えます。-が書き換え対象の行で、+が書き換えたあとの行の内容です。さらに、
- from models import UNet + from models import UNet, ConditionalUNetと書き換え、
- model = UNet().to(device) + model = ConditionalUNet().to(device)と書き換えます。ConditionalUNetにすると、生成時にどの数字を生成するかを指定する必要があるため、
model.eval() + labels = torch.randint(0, 10, (batch_size,)) + print(labels) + labels_cpu = labels + labels = labels.to(device) with torch.no_grad(): for t in trange(999, -1, -1): time = torch.ones(batch_size) * t - et = model(xt.to(device), time.to(device)) # predict noise + et = model(xt.to(device), time.to(device), labels) # predict noise xt = process.inverse(xt, et.cpu(), t)のように、書き換えます。ランダムに0〜9の値をラベルとして指定するようにしています。 出力部分を少し書き換えて
labels = ["Generated Images"] * 9 - - for i in range(9): - plt.subplot(3, 3, i + 1) - plt.tight_layout() - plt.imshow(xt[i][0], cmap="gray", interpolation="none") - plt.title(labels[i]) - plt.show() + if t % 10 == 0: + plt.figure(figsize=(10, 10)) + for i in range(25): + plt.subplot(5, 5, i + 1) + plt.tight_layout() + plt.imshow(xt[i][0], cmap="gray", interpolation="none") + plt.title(f"{labels_cpu[i]}") + plt.savefig(f"images/generated_t{t}.png") + #plt.show()とすると、逆拡散過程によって少しずつ数字が画像として浮かび上がるところを見ることができます。 ただし、このようにすると、生成した数字のpyplotでの描画処理のために生成時間が延びます。GPU使用率もあきらかに低下します。 途中結果を見る必要がなければ最初のコードのほうが良いでしょう。
生成結果
結果は次のようになります。ただし、乱数がいろいろなところで使われており、シードも固定されていないので、毎回結果は異なります。
学習時の処理
学習時の処理がどのように実装されているのかを見ていきます。
上位ループ
モデルの学習をするmain.pyの主要な処理である学習のループ部分を抜き出して、疑似コードとして書き換えると、
for epoch in range(100): for image, label in 学習用画像とそのラベルの集合: # 画像に加えるノイズの強さをランダムに選択 t = torch.randint(0, 1000, (image.shape[0],)) # 画像に加えるノイズを作成 epsilon = torch.randn(image.shape) # tとepsilonに基づいてノイズをimageに加える diffused_image = process.forward(image, t, epsilon) # modelを使って加えられたノイズを予測 optimizer.zero_grad() output = model(diffused_image, t, label) # 予測したノイズがどれだけ正しいかを評価して、モデルの重みを更新 loss = criterion(epsilon, output) # criterion = torch.nn.MSELoss() loss.backward() optimizer.step()となります。ノイズが加えられた画像から、加えられたノイズを推定するモデルを学習しているだけです。 それ以外は通常のニューラルネットワークの学習と変わりがありません。 そしてこれは元論文のAlgorithm 1の通りの実装です。5行目の
と見比べると、\(\bm{\epsilon}\)はコード上のepsilonに、 \(\bm{\epsilon_{\theta}}\)はコード上のmodelに、 \(\sqrt{\bar{\alpha_t}}\bm{x}_0 + \sqrt{1-\bar{\alpha}_t}\bm{\epsilon}\)はdiffused_imageに相当していることが分かります。 Algorithm 1とは異なり、modelの引数にlabelが余計についていますが、これは生成する数字が0〜9のどれであるかをニューラルネットワークに指示するための値となります。
\( \nabla_\theta \| \bm{\epsilon} - \bm{\epsilon}_\theta ( \sqrt{\bar{\alpha_t}}\bm{x}_0 + \sqrt{1-\bar{\alpha}_t}\bm{\epsilon}, t) \|^2 \)
ノイズ付加処理
上位ループ内の
diffused_image = process.forward(image, t, epsilon)の部分について、詳細を見ていきます。 まず、diffusion_model.pyのDiffusionProcessの初期化処理部分である__init__にて\(\alpha_t\)の値が計算されています。 最初に\(\beta_t\) (コード上ではvariance_schedule) を
self.variance_schedule = torch.linspace(1e-4, 0.01, steps=1000)のように計算しています。具体的には
self.variance_schedule = [1.0000e-04, 1.0991e-04, 1.1982e-04, ... , 0.009980, 0.009990, 0.01]となっています。元論文では\(\beta_1=10^4\)、\(\beta_T=0.02\)、\(T=1000\)と書かれているので\(\beta_T\)の値のみ異なっています。 \(\alpha_t = 1-\beta_t\)であるので、
self.alpha = 1 - self.variance_scheduleにて計算され、具体的な値は
self.alpha = [0.999900, 0.999890, 0.999880 , ... , 0.990020, 0.990010, 0.990000]となります。さらに、\(\bar{\alpha}_t=\prod_{s=1}^t \alpha_s\)であるので、
self.alpha_bar = torch.cumprod(self.alpha, dim=0)にて計算され、具体的な値は
self.alpha_bar = [0.999900, 0.999790, 0.999670, ... , 0.006430, 0.006365, 0.006302]となります。 これらの値を使って、forwardにてノイズを画像に付加します。 \(\sqrt{\bar{\alpha_t}}\bm{x}_0 + \sqrt{1-\bar{\alpha}_t}\bm{\epsilon}\)が計算できれば良いので、まず\(\sqrt{1-\bar{\alpha}_t}\)を
std_dev = torch.sqrt(1 - self.alpha_bar[time_step])で計算します。次に\(\sqrt{\bar{\alpha_t}}\)を
mean_multiplier = torch.sqrt(self.alpha_bar[time_step])で計算し、最後に
diffused_images = mean_multiplier * x_0 + std_dev * noiseのように足し合わせます。x_0が\(\bm{x}_0\)で、noiseが\(\bm{\epsilon}\)です。
ノイズを予測するモデル
ノイズを予測するモデルには小さめのUNetが使われています。コードからグラフに書き起こすと下図のようになります。
背景が水色のボックス内にはクラス名とインスタンス化時の引数を記載しています(画像をクリックすると拡大できます)。 降りる方向と登る方向では条件付けの場所が異なっていますが、このあたりはおおまかであっても十分に動くのでしょう。 以下では、Pytorchで提供されているクラス以外のものを見ていきます。
ResConvGroupNorm
このクラスは以下のような構成になっています。残差接続のある、よく見かける畳み込み演算を使ったブロックです。
LabelEmbedding
このクラスは、0〜9のラベルをベクトルに変換し、入力値\(x\)に埋め込む処理をします。画像のチャネル方向に埋め込むので、各チャネルの画像特徴が埋め込みによって全体的に明るくなったり暗くなったりすることになります。
TemporalEmbedding
このクラスは、加えるノイズの強さを表す時刻\(t\)を入力値\(x\)に埋め込む処理をします。こちらも画像のチャネル方向に埋め込むので、各チャネルの画像特徴が埋め込みによって全体的に明るくなったり暗くなったりすることになります。LabelEmbeddingと干渉しそうですが、モデルの学習時にうまく棲み分けしているのでしょう。 SinusoidalPositionEmbeddingsはarXiv:1706.03762の \[ PE_{(pos,2i)} = sin(pos/10000^{2i/d_{model}}) \] \[ PE_{(pos,2i+1)} = cos(pos/10000^{2i/d_{model}}) \] とほぼ同じ計算をしています。\(sin\)と\(cos\)の引数部分のみ取り出すと、 \[ pos/10000^{2i/d_{model}} \] となり、\(pos\)を時刻\(t\)、\(d_{model}\)を単に\(d\)と書き換えると、 \[ t/10000^{2i/d} \] となります。iは次元方向のインデックスです。もう少し書き換えて、 \[ t/10000^{i/(d/2)} \] \(i/(d/2)\)の範囲が[0,1]になるように、少し値は変わりますが、 \[ t/10000^{i/(d/2-1)} \] とします。さらに、 \[ \begin{aligned} \frac{t}{\exp\left(\log 10000^{i/(d/2-1)}\right)} &= t \cdot \exp \left(-\log 10000^{i/(d/2-1)} \right) \\ &= t \cdot \exp \left(-\frac{i}{\frac{d}{2}-1} \log 10000 \right) \end{aligned} \] と変形します。ここまで変形するとコードとの対応がとれ、
half_dim = self.dim // 2で\(d/2\)を計算し、
embeddings = math.log(10000) / (half_dim - 1)で、\(\frac{1}{\frac{d}{2}-1} \log 10000\)を計算し、続いて
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)で、\(\exp \left(-\frac{i}{\frac{d}{2}-1} \log 10000 \right)\)を計算していることが分かります。 なお、変形前の\(1/10000^{i/(d/2-1)}\)にしたがって
1/torch.pow(10000, torch.arange(half_dim, device=device)/(half_dim-1))で計算しても、結果はほとんど変わりません。具体的には、最初の5つの値を出力すると、
式変形前 [1.0000000000, 0.9821373820, 0.9645937681, 0.9473634958, 0.9304410219] 式変形後 [1.0000000000, 0.9821373224, 0.9645937085, 0.9473634958, 0.9304410219]となるので、どちらで計算してもよさそうです。 さて、時刻\(t\)をかけると
embeddings = time[:, None] * embeddings[None, :]となり、最後に、
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)のように\(sin\)と\(cos\)を通し、それらを連結することで、埋め込みベクトルが作成できます。
LinearAttention
attention1のクラスであるLinearAttentionの実装のアイデアの元と思われる論文は のあたりのようですが、ぴったり当てはまるものを見つけることはできませんでした。 注意機構部分のみをコードから計算式に戻すと、 \[ d_{t'f'} = \frac{1}{\sqrt{32}}\frac{1}{7 \times 7} \sum_f{\frac{\exp(q_{t'f})}{\sum_{f''}{\exp(q_{t'f''})}} \sum_t{\frac{\exp(k_{tf})}{\sum_{t''}\exp(k_{t''f})} v_{tf'}}} \] となっているようです。普通の注意機構とは異なり、最初に\(K\)の特徴次元数×\(V\)の特徴次元数(今の場合は32×32)の行列を計算しているようです。 \(\frac{1}{\sqrt{32}}\)の32は\(K\)の特徴次元数で、arXiv:1706.03762の式(1)である \[ Attention(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \] の\(d_k\)に相当するようです。
生成時の処理
生成時の処理がどのように実装されているのかを見ていきます。
上位ループ
生成処理をしているinference_unet.pyの主要部分を疑似コードとして書き換えると、
for t in trange(999, -1, -1): # 時刻tとラベルlabelの条件の下、ノイズを推定する et = model(xt, t, label) # 推定したノイズを使って画像からノイズを除去する xt = process.inverse(xt, et, t) # xtを画像として描画する draw(xt)となります。モデルの学習時は時刻tをランダムに選んでいましたが、 生成するときはノイズから徐々にノイズを取り除いていきます。t=999ではノイズのみ、t=0ではノイズがなくなった画像になります。
ノイズ除去処理
上位ループ内の
xt = process.inverse(xt, et, t)の部分について詳細を見ていきます。 元論文arXiv:2006.11239のAlgorithm 2の4行目の処理 \[ \bm{x}_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(\bm{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(\bm{x}_t,t)\right)+\sigma_t\bm{z} \] が実装されています。DiffusionProcessのinverseを見ていくと、
scale = 1 / torch.sqrt(self.alpha[t])では\(\frac{1}{\sqrt{\alpha_t}}\)の部分が計算されています。
noise_scale = (1 - self.alpha[t]) / torch.sqrt(1 - self.alpha_bar[t])では\(\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\)の部分が計算されています。
std_dev = torch.sqrt(self.variance_schedule[t])では\(\sigma_t\)が計算されています。
z = torch.randn(xt.shape) if t > 1 else torch.Tensor([0])では\(\bm{z} \sim N(0,\bm{I})\)が計算されています。ガウス分布に従うノイズを作っているだけです。 最後にすべてをつなげて
mu_t = scale * (xt - noise_scale * et) xt = mu_t + std_dev * z # remove noise from imageを計算することで、Algorithm 2の4行目の処理を実現しています。
0 件のコメント :
コメントを投稿