はじめに
その1で試したDiffusion MNISTについて、 ノイズの強さを表す時刻\(t\)をニューラルネットワークに伝えないとどうなるのかを見ていきます。
方法
https://github.com/MarceloGennari/diffusion_mnist のConditionalUNetの\(t\)が関連する行、つまり、TemporalEmbedding部分をコメントアウトします。具体的には
class ConditionalUNet(UNet): (省略) def forward(self, x: Tensor, t: Tensor, label: Tensor) -> Tensor: x0 = x #self.embedding1(x, t) x1 = self.block1(x0) x1 = self.label_emb1(x1, label) #x1 = self.embedding2(x1, t) x2 = self.block2(self.down1(x1)) x2 = self.label_emb2(x2, label) #x2 = self.embedding3(x2, t) crossed = self.label_emb3(self.block3(self.down2(x2)), label) x3 = self.up1(self.attention1(crossed)) x4 = torch.cat([x2, x3], dim=1) #x4 = self.embedding4(x4, t) x5 = self.up2(self.label_emb4(self.block4(x4), label)) x6 = torch.cat([x5, x1], dim=1) x6 = self.label_emb5(x6, label) #x6 = self.embedding5(x6, t) out = self.out(self.block5(x6)) return outとします。
結果
その2で試した時間刻みを100にしたバージョンをベースに比較します。左側がTemporalEmbeddingありで、右側がなしに対応します。
TemporalEmbeddingなしの場合はノイズが多いように見えるので、画像の明るさをGIMPを使って上げたものが下図です。TemporalEmbeddingありではノイズが見えませんが、TemporalEmbeddingなしではノイズがはっきり見えるケースが多くなっています。
まとめ
時刻の埋め込みは効果があるということを確認できました。
0 件のコメント :
コメントを投稿