2024/05/14

KANでMNIST

KANとは


Kolmogorov–Arnold Networkの略で、Multi-Layer Perceptron (MLP) の代わりに使えるニューラルネットワークです。

MLPでは、入力データに対して重み付き線形和を計算し、活性化関数(例えばReLU)に通す、という処理を層の数だけ繰り返します。 和をとったあとに活性化関数を適用することになります。 入力が\(N_{\rm in}\)次元、出力が\(N_{\rm out}\)次元のとき、活性化関数は\(N_{\rm out}\)回だけ実行されます。

KANでは、入力データに対してB-スプライン曲線で学習可能にした活性化関数を通し、和を取る、という処理を層の数だけ繰り返します。 入力が\(N_{\rm in}\)次元、出力が\(N_{\rm out}\)次元のとき、活性化関数は\(N_{\rm in} \times N_{\rm out}\)回だけ実行されます。 また、活性化関数は学習可能なので、\(N_{\rm in} \times N_{\rm out}\)個のそれぞれ異なる形状の活性化関数が存在します。

B-スプライン曲線で活性化関数を書くといっても、お絵描きするときのように2Dの曲線を書いて活性化関数にするのではなく、基底関数の線形和を計算するだけです。 具体的には \[ {\rm spline}(x) = \sum_i c_i B_i(x) \] となります。\(B_i\)はB-スプライン曲線で指定した点を使って曲線を描くときにどのように補間するか(どの割合で点の位置を混ぜるか)を計算する関数です。 それを学習可能な\(c_i\)で混ぜて活性化関数とするわけです。

活性化関数\(\phi(x)\)には\({\rm spline}(x)\)を直接そのまま使うのではなく \[\phi(x) = w (b(x) + {\rm spline}(x))\] を用います。ここで、 \[b(x)={\rm silu}(x)=\frac{x}{1+e^{-x}} \] です。\(w\)は学習可能な重みです。ただし、Githubで公開されているコードを読むと \[\phi(x) = w_{\rm base} b(x) + w_{\rm sp} {\rm spline}(x)\] が使われているように見えます(KANLayer.pyを参照)。

B-スプライン曲線についてはhttps://techblog.kayac.com/generate-curves-using-b-splineとかhttp://web.mit.edu/hyperbook/Patrikalakis-Maekawa-Cho/node17.htmlが分かりやすいです。

KANの面白いところは、B-スプライン曲線で作った活性化関数が \(x^2\)、\({\rm exp}(x)\)、\({\rm sin}(x)\)、\({\rm log}(x)\)、\({\rm sqrt}(x)\)、\({\rm abs}(x)\) のようなユーザーが指定できる関数に十分近い場合はそれに置き換えてしまうことができる点です。

任意の入力に対して出力が0になる活性化関数を正則化によって増やし、それらを除去していくと、有効な活性化関数が人間が理解できる程度に少なくなることがあります。このとき、入力\(x\)に対して出力を\(y=f(x)\)で計算できる場合、関数\(f\)をユーザーが指定した活性化関数を使って作った合成関数と線形和、例えば \(f(x) = 1.2 \times {\rm sin} (x^2 - 0.3) + 0.5\) のような人間がみて分かる数式で出力することができます。

論文はhttps://arxiv.org/abs/2404.19756で、 コードはhttps://github.com/KindXiaoming/pykanにあります。 ここではリビジョン e6078bc8 を使います。

MNISTで学習させてみる


KAN Layerを使ってモデルを作り、MNISTで学習させてみます。

すべてKAN Layerで作ることもできるのですが、MNISTの画像は28×28=784と大きく、これを入力として32次元のベクトルを出力するようにすると、1層だけで784×32=25088個ものB-スプライン曲線を学習することになります。実行自体はできるのですが、非常に遅いため、ここでは最初にConv2Dで次元数を減らしてからKAN Layerを利用することにします。

学習用のコードは以下のとおりです。なお、著者が公開しているpykanのKAN.pyをベースに色々書き換えているので、もとのコードのライセンスに従い、このコードの部分はMITライセンスとします。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from pykan.kan.KANLayer import KANLayer
import sys

def initialize_seed(seed=0):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

class ConvMLP(nn.Module):
    def __init__(self, fc_layers: list[int], device):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, stride=2, device=device)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5, stride=2, device=device)
        n_in = 16
        fcs = []
        for fc in fc_layers:
            fcs.append(nn.Linear(n_in, fc, device=device))
            n_in = fc
        self.fcs = nn.ModuleList(fcs)

    def forward(self, x):
        x = self.conv2(F.relu(F.max_pool2d(self.conv1(x), 2)))
        x = x.reshape(x.shape[0], -1)
        for fc in self.fcs:
            x = fc(F.relu(x))
        return x

    def update_grid_from_samples(self, x):
        pass

    def regularize(self, lambda_l1, lambda_entropy, lambda_coef, lambda_coefdiff, small_mag_threshold=1e-16, small_reg_factor=1.0):
        return 0.0

# Modified version of KAN in pykan/kan/KAN.py
class ConvKAN(nn.Module):
    def __init__(self,
                 width: list[int],
                 grid=5,
                 k=3,
                 noise_scale=0.1,
                 noise_scale_base=0.1,
                 base_fun=torch.nn.SiLU(),
                 bias_trainable=True,
                 grid_eps=1.0,
                 grid_range=[-1, 1],
                 sp_trainable=True,
                 sb_trainable=True,
                 device="cpu"):
        super().__init__()

        ### Initialize feature extraction layers
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, stride=2, device=device)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5, stride=2, device=device)
        width.insert(0, 16)

        ### Initialize KAN layers
        self.biases = []
        self.act_fun = []
        self.depth = len(width) - 1
        self.width = width

        for l in range(self.depth):
            # splines
            scale_base = 1 / np.sqrt(width[l]) + (torch.randn(width[l] * width[l + 1], ) * 2 - 1) * noise_scale_base
            sp_batch = KANLayer(in_dim=width[l],
                                out_dim=width[l + 1],
                                num=grid,
                                k=k,
                                noise_scale=noise_scale,
                                scale_base=scale_base,
                                scale_sp=1.0,
                                base_fun=base_fun,
                                grid_eps=grid_eps,
                                grid_range=grid_range,
                                sp_trainable=sp_trainable,
                                sb_trainable=sb_trainable,
                                device=device)
            self.act_fun.append(sp_batch)

            # bias
            bias = nn.Linear(width[l + 1], 1, bias=False, device=device).requires_grad_(bias_trainable)
            bias.weight.data *= 0.0
            self.biases.append(bias)

        self.biases = nn.ModuleList(self.biases)
        self.act_fun = nn.ModuleList(self.act_fun)

    def forward(self, x):
        # Extract features by conv
        x = self.conv2(F.relu(F.max_pool2d(self.conv1(x), 2)))
        x = x.reshape(x.shape[0], -1)

        # Run KAN layers
        self.acts = [x] # acts shape: (batch, width[l])
        self.acts_scale = []

        for l in range(self.depth):
            x, preacts, postacts, postspline = self.act_fun[l](x)
            grid_reshape = self.act_fun[l].grid.reshape(self.width[l + 1], self.width[l], -1)
            input_range = grid_reshape[:, :, -1] - grid_reshape[:, :, 0] + 1e-4
            output_range = torch.mean(torch.abs(postacts), dim=0)
            self.acts_scale.append(output_range / input_range)

            x = x + self.biases[l].weight
            self.acts.append(x)

        return x

    def update_grid_from_samples(self, x):
        for l in range(self.depth):
            self.forward(x)
            self.act_fun[l].update_grid_from_samples(self.acts[l])

    def regularize(self, lambda_l1, lambda_entropy, lambda_coef, lambda_coefdiff, small_mag_threshold=1e-16, small_reg_factor=1.0):
        def nonlinear(x, th, factor):
            return (x < th) * x * factor + (x > th) * (x + (factor - 1) * th)

        reg_ = 0.
        for i in range(len(self.acts_scale)):
            vec = self.acts_scale[i].reshape(-1, )
            vec_sum = torch.sum(vec)
            if vec_sum == 0.0:
                continue

            p = vec / vec_sum
            l1 = torch.sum(nonlinear(vec, th=small_mag_threshold, factor=small_reg_factor))
            entropy = - torch.sum(p * torch.log2(p + 1e-4))
            reg_ += lambda_l1 * l1 + lambda_entropy * entropy  # both l1 and entropy

        # regularize coefficient to encourage spline to be zero
        for i in range(len(self.act_fun)):
            coeff_l1 = torch.sum(torch.mean(torch.abs(self.act_fun[i].coef), dim=1))
            coeff_diff_l1 = torch.sum(torch.mean(torch.abs(torch.diff(self.act_fun[i].coef)), dim=1))
            reg_ += lambda_coef * coeff_l1 + lambda_coefdiff * coeff_diff_l1

        return reg_

def calc_accuracy(ys):
    rs = []
    for y, label in ys:
        rs.append((torch.argmax(y, dim=1) == label).float())
    r = torch.cat(rs, dim=0)
    return torch.mean(r)*100.0

def train(model,
          train_loader,
          test_loader,
          max_epoch,
          lamb=0.0,
          lambda_l1=1.0,
          lambda_entropy=2.0,
          lambda_coef=0.0,
          lambda_coefdiff=0.0,
          update_grid=True,
          grid_update_freq=10,
          loss_fn=torch.nn.CrossEntropyLoss(),
          lr=0.002,
          device="cpu"):

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(max_epoch):
        model.train()
        n_samples = 0
        max_samples = len(train_loader.dataset)
        ys = []
        for iter, (x, label) in enumerate(train_loader):
            x = x.to(device)
            label = label.to(device)
            if iter % grid_update_freq == 0 and update_grid:
                model.update_grid_from_samples(x)
            y = model(x)
            ys.append((y, label))
            loss = loss_fn(y, label)
            reg_ = model.regularize(lambda_l1, lambda_entropy, lambda_coef, lambda_coefdiff)
            loss = loss + lamb * reg_
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            n_samples += len(x)
            if iter % 100 == 0:
                print(f"Epoch: {epoch} [{n_samples}/{max_samples}] Loss: {loss.item():.6f}")

        # Calc train accuracy
        train_acc = calc_accuracy(ys)

        # Calc test accuracy
        model.eval()
        ys = []
        with torch.no_grad():
            for iter, (x, label) in enumerate(test_loader):
                x = x.to(device)
                label = label.to(device)
                y = model(x)
                ys.append((y, label))
        test_acc = calc_accuracy(ys)
        print(f"Epoch: {epoch} [{n_samples}/{max_samples}] Loss: {loss.item():.6f} Acc(Train): {train_acc} Acc(Test): {test_acc}")

    return

def main(mode):
    initialize_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    reg_lambda = 0.0
    update_grid = True,
    if mode == "kan":
        model = ConvKAN(width=[20, 10], device=device)
    elif mode == "kan-no-update-grid":
        model = ConvKAN(width=[20, 10], device=device)
        update_grid = False
    elif mode == "kan-reg":
        model = ConvKAN(width=[20, 10], device=device)
        reg_lambda = 0.003
    elif mode == "mlp":
        model = ConvMLP(fc_layers=[20, 10], device=device)
    else:
        return
    train_loader = DataLoader(datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor()), batch_size=128, shuffle=True)
    test_loader = DataLoader(datasets.MNIST("./data", train=False, download=True, transform=transforms.ToTensor()), batch_size=128, shuffle=False)
    train(model, train_loader, test_loader, max_epoch=5, lamb=reg_lambda, update_grid=update_grid, device=device)

if __name__ == "__main__":
    main(sys.argv[1])

ConvMLP


比較用としてMLPを使った単純なモデルで学習させてみた結果が以下です。
Epoch: 0 [128/60000] Loss: 2.330844
Epoch: 0 [12928/60000] Loss: 0.527279
Epoch: 0 [25728/60000] Loss: 0.264145
Epoch: 0 [38528/60000] Loss: 0.268547
Epoch: 0 [51328/60000] Loss: 0.270200
Epoch: 0 [60000/60000] Loss: 0.137060 Acc(Train): 82.13166809082031 Acc(Test): 93.66999816894531
Epoch: 1 [128/60000] Loss: 0.211759
Epoch: 1 [12928/60000] Loss: 0.172556
Epoch: 1 [25728/60000] Loss: 0.216685
Epoch: 1 [38528/60000] Loss: 0.224366
Epoch: 1 [51328/60000] Loss: 0.166860
Epoch: 1 [60000/60000] Loss: 0.297034 Acc(Train): 94.53333282470703 Acc(Test): 95.87000274658203
Epoch: 2 [128/60000] Loss: 0.169406
Epoch: 2 [12928/60000] Loss: 0.060187
Epoch: 2 [25728/60000] Loss: 0.065257
Epoch: 2 [38528/60000] Loss: 0.209189
Epoch: 2 [51328/60000] Loss: 0.137371
Epoch: 2 [60000/60000] Loss: 0.062734 Acc(Train): 95.69499969482422 Acc(Test): 96.29000091552734
Epoch: 3 [128/60000] Loss: 0.161130
Epoch: 3 [12928/60000] Loss: 0.094211
Epoch: 3 [25728/60000] Loss: 0.137475
Epoch: 3 [38528/60000] Loss: 0.143321
Epoch: 3 [51328/60000] Loss: 0.086744
Epoch: 3 [60000/60000] Loss: 0.295325 Acc(Train): 96.2300033569336 Acc(Test): 96.52999877929688
Epoch: 4 [128/60000] Loss: 0.084705
Epoch: 4 [12928/60000] Loss: 0.114923
Epoch: 4 [25728/60000] Loss: 0.071916
Epoch: 4 [38528/60000] Loss: 0.093224
Epoch: 4 [51328/60000] Loss: 0.111265
Epoch: 4 [60000/60000] Loss: 0.049419 Acc(Train): 96.69833374023438 Acc(Test): 97.1500015258789

ConvKAN


KANを使ったモデルで学習させてみた結果が以下です。Conv2Dが色々吸収してしまっているのかもしれませんが、違いがほとんどありません。 パラメータ数はConvMLPよりConvKANのほうが多いです。
Epoch: 0 [128/60000] Loss: 2.308975
Epoch: 0 [12928/60000] Loss: 0.608790
Epoch: 0 [25728/60000] Loss: 0.333690
Epoch: 0 [38528/60000] Loss: 0.214927
Epoch: 0 [51328/60000] Loss: 0.171603
Epoch: 0 [60000/60000] Loss: 0.099660 Acc(Train): 88.74166870117188 Acc(Test): 96.1500015258789
Epoch: 1 [128/60000] Loss: 0.073898
Epoch: 1 [12928/60000] Loss: 0.215137
Epoch: 1 [25728/60000] Loss: 0.109934
Epoch: 1 [38528/60000] Loss: 0.126619
Epoch: 1 [51328/60000] Loss: 0.066091
Epoch: 1 [60000/60000] Loss: 0.188135 Acc(Train): 96.20833587646484 Acc(Test): 96.88999938964844
Epoch: 2 [128/60000] Loss: 0.061085
Epoch: 2 [12928/60000] Loss: 0.078620
Epoch: 2 [25728/60000] Loss: 0.045636
Epoch: 2 [38528/60000] Loss: 0.052172
Epoch: 2 [51328/60000] Loss: 0.037537
Epoch: 2 [60000/60000] Loss: 0.149782 Acc(Train): 96.87333679199219 Acc(Test): 95.19000244140625
Epoch: 3 [128/60000] Loss: 0.040984
Epoch: 3 [12928/60000] Loss: 0.102282
Epoch: 3 [25728/60000] Loss: 0.017132
Epoch: 3 [38528/60000] Loss: 0.043684
Epoch: 3 [51328/60000] Loss: 0.126490
Epoch: 3 [60000/60000] Loss: 0.057794 Acc(Train): 97.1483383178711 Acc(Test): 97.47000122070312
Epoch: 4 [128/60000] Loss: 0.056742
Epoch: 4 [12928/60000] Loss: 0.087390
Epoch: 4 [25728/60000] Loss: 0.046712
Epoch: 4 [38528/60000] Loss: 0.058726
Epoch: 4 [51328/60000] Loss: 0.217805
Epoch: 4 [60000/60000] Loss: 0.137546 Acc(Train): 97.54500579833984 Acc(Test): 97.52999877929688

ConvKANでgridの更新なし


KANを使ったモデルでgridの更新なしで学習させてみた結果が以下です。B-スプライン曲線による活性化関数は処理できる入力値の範囲が決まっており、gridの更新なしというのは、その範囲の調整を行わないということです。今回の設定では特に効果が無いようです。
Epoch: 0 [128/60000] Loss: 2.308975
Epoch: 0 [12928/60000] Loss: 0.502888
Epoch: 0 [25728/60000] Loss: 0.330806
Epoch: 0 [38528/60000] Loss: 0.242808
Epoch: 0 [51328/60000] Loss: 0.165698
Epoch: 0 [60000/60000] Loss: 0.191306 Acc(Train): 86.51166534423828 Acc(Test): 95.06999969482422
Epoch: 1 [128/60000] Loss: 0.090638
Epoch: 1 [12928/60000] Loss: 0.238503
Epoch: 1 [25728/60000] Loss: 0.128466
Epoch: 1 [38528/60000] Loss: 0.166372
Epoch: 1 [51328/60000] Loss: 0.120421
Epoch: 1 [60000/60000] Loss: 0.113729 Acc(Train): 95.87166595458984 Acc(Test): 96.37999725341797
Epoch: 2 [128/60000] Loss: 0.094055
Epoch: 2 [12928/60000] Loss: 0.101774
Epoch: 2 [25728/60000] Loss: 0.053376
Epoch: 2 [38528/60000] Loss: 0.050028
Epoch: 2 [51328/60000] Loss: 0.049892
Epoch: 2 [60000/60000] Loss: 0.045015 Acc(Train): 96.88333129882812 Acc(Test): 96.44000244140625
Epoch: 3 [128/60000] Loss: 0.091558
Epoch: 3 [12928/60000] Loss: 0.116339
Epoch: 3 [25728/60000] Loss: 0.035658
Epoch: 3 [38528/60000] Loss: 0.047689
Epoch: 3 [51328/60000] Loss: 0.121902
Epoch: 3 [60000/60000] Loss: 0.066078 Acc(Train): 97.5 Acc(Test): 97.30999755859375
Epoch: 4 [128/60000] Loss: 0.069760
Epoch: 4 [12928/60000] Loss: 0.048372
Epoch: 4 [25728/60000] Loss: 0.042325
Epoch: 4 [38528/60000] Loss: 0.073130
Epoch: 4 [51328/60000] Loss: 0.130568
Epoch: 4 [60000/60000] Loss: 0.046441 Acc(Train): 97.72833251953125 Acc(Test): 97.43999481201172

ConvKANで正則化あり


KANを使ったモデルで正則化ありで実行してみます。正則化のロスの重み\(\lambda\)は0.003にしています。
Epoch: 0 [128/60000] Loss: 2.449388
Epoch: 0 [12928/60000] Loss: 0.862422
Epoch: 0 [25728/60000] Loss: 0.502983
Epoch: 0 [38528/60000] Loss: 0.569132
Epoch: 0 [51328/60000] Loss: 0.391648
Epoch: 0 [60000/60000] Loss: 0.360877 Acc(Train): 89.41166687011719 Acc(Test): 95.72999572753906
Epoch: 1 [128/60000] Loss: 0.345524
Epoch: 1 [12928/60000] Loss: 0.449618
Epoch: 1 [25728/60000] Loss: 0.375242
Epoch: 1 [38528/60000] Loss: 0.334639
Epoch: 1 [51328/60000] Loss: 0.313511
Epoch: 1 [60000/60000] Loss: 0.374076 Acc(Train): 96.25166320800781 Acc(Test): 96.95999908447266
Epoch: 2 [128/60000] Loss: 0.237053
Epoch: 2 [12928/60000] Loss: 0.721758
Epoch: 2 [25728/60000] Loss: 0.468227
Epoch: 2 [38528/60000] Loss: 0.489218
Epoch: 2 [51328/60000] Loss: 0.366972
Epoch: 2 [60000/60000] Loss: 0.461017 Acc(Train): 90.59166717529297 Acc(Test): 94.61000061035156
Epoch: 3 [128/60000] Loss: 0.391373
Epoch: 3 [12928/60000] Loss: 0.405158
Epoch: 3 [25728/60000] Loss: 0.296408
Epoch: 3 [38528/60000] Loss: 0.306373
Epoch: 3 [51328/60000] Loss: 0.362922
Epoch: 3 [60000/60000] Loss: 0.335211 Acc(Train): 94.8933334350586 Acc(Test): 95.55999755859375
Epoch: 4 [128/60000] Loss: 0.343693
Epoch: 4 [12928/60000] Loss: 0.300828
Epoch: 4 [25728/60000] Loss: 0.310633
Epoch: 4 [38528/60000] Loss: 0.402329
Epoch: 4 [51328/60000] Loss: 0.401533
Epoch: 4 [60000/60000] Loss: 0.272522 Acc(Train): 95.64000701904297 Acc(Test): 95.79000091552734
\(\lambda=0.005\)にすると、以下のようになり途中でモデルが崩壊しました。
Epoch: 0 [128/60000] Loss: 2.542997
Epoch: 0 [12928/60000] Loss: 0.992837
Epoch: 0 [25728/60000] Loss: 0.666648
Epoch: 0 [38528/60000] Loss: 0.642962
Epoch: 0 [51328/60000] Loss: 0.496544
Epoch: 0 [60000/60000] Loss: 0.467141 Acc(Train): 88.85499572753906 Acc(Test): 95.80999755859375
Epoch: 1 [128/60000] Loss: 0.465318
Epoch: 1 [12928/60000] Loss: 0.551555
Epoch: 1 [25728/60000] Loss: 0.480575
Epoch: 1 [38528/60000] Loss: 0.532717
Epoch: 1 [51328/60000] Loss: 0.389326
Epoch: 1 [60000/60000] Loss: 0.464731 Acc(Train): 95.5183334350586 Acc(Test): 96.31999969482422
Epoch: 2 [128/60000] Loss: 0.360698
Epoch: 2 [12928/60000] Loss: 0.528383
Epoch: 2 [25728/60000] Loss: 0.377956
Epoch: 2 [38528/60000] Loss: 0.333951
Epoch: 2 [51328/60000] Loss: 0.379329
Epoch: 2 [60000/60000] Loss: 0.398343 Acc(Train): 94.8566665649414 Acc(Test): 96.06999969482422
Epoch: 3 [128/60000] Loss: 0.374904
Epoch: 3 [12928/60000] Loss: 0.928802
Epoch: 3 [25728/60000] Loss: 0.401091
Epoch: 3 [38528/60000] Loss: 0.539629
Epoch: 3 [51328/60000] Loss: 223.694611
Epoch: 3 [60000/60000] Loss: 3.871729 Acc(Train): 76.33499908447266 Acc(Test): 10.520000457763672
Epoch: 4 [128/60000] Loss: 4.122091
Epoch: 4 [12928/60000] Loss: 3.226183
Epoch: 4 [25728/60000] Loss: 2.973773
Epoch: 4 [38528/60000] Loss: 2.831389
Epoch: 4 [51328/60000] Loss: 2.789084
Epoch: 4 [60000/60000] Loss: 2.598794 Acc(Train): 9.819999694824219 Acc(Test): 10.09999942779541

0 件のコメント :