2023/10/26

Diffusion MNIST その3

はじめに


その1で試したDiffusion MNISTについて、 ノイズの強さを表す時刻\(t\)をニューラルネットワークに伝えないとどうなるのかを見ていきます。

方法


https://github.com/MarceloGennari/diffusion_mnist のConditionalUNetの\(t\)が関連する行、つまり、TemporalEmbedding部分をコメントアウトします。具体的には
class ConditionalUNet(UNet):
    (省略)
    def forward(self, x: Tensor, t: Tensor, label: Tensor) -> Tensor:
        x0 = x #self.embedding1(x, t)
        x1 = self.block1(x0)
        x1 = self.label_emb1(x1, label)
        #x1 = self.embedding2(x1, t)
        x2 = self.block2(self.down1(x1))
        x2 = self.label_emb2(x2, label)
        #x2 = self.embedding3(x2, t)
        crossed = self.label_emb3(self.block3(self.down2(x2)), label)
        x3 = self.up1(self.attention1(crossed))
        x4 = torch.cat([x2, x3], dim=1)
        #x4 = self.embedding4(x4, t)
        x5 = self.up2(self.label_emb4(self.block4(x4), label))
        x6 = torch.cat([x5, x1], dim=1)
        x6 = self.label_emb5(x6, label)
        #x6 = self.embedding5(x6, t)
        out = self.out(self.block5(x6))
        return out
とします。

結果


その2で試した時間刻みを100にしたバージョンをベースに比較します。左側がTemporalEmbeddingありで、右側がなしに対応します。
t=50 TemporalEmbeddingあり
t=50 TemporalEmbeddingなし

t=0 TemporalEmbeddingあり
t=0 TemporalEmbeddingなし

TemporalEmbeddingなしの場合はノイズが多いように見えるので、画像の明るさをGIMPを使って上げたものが下図です。TemporalEmbeddingありではノイズが見えませんが、TemporalEmbeddingなしではノイズがはっきり見えるケースが多くなっています。
t=0 TemporalEmbeddingあり
t=0 TemporalEmbeddingなし

まとめ


時刻の埋め込みは効果があるということを確認できました。

0 件のコメント :