KANとは
Kolmogorov–Arnold Networkの略で、Multi-Layer Perceptron (MLP) の代わりに使えるニューラルネットワークです。 MLPでは、入力データに対して重み付き線形和を計算し、活性化関数(例えばReLU)に通す、という処理を層の数だけ繰り返します。 和をとったあとに活性化関数を適用することになります。 入力が\(N_{\rm in}\)次元、出力が\(N_{\rm out}\)次元のとき、活性化関数は\(N_{\rm out}\)回だけ実行されます。 KANでは、入力データに対してB-スプライン曲線で学習可能にした活性化関数を通し、和を取る、という処理を層の数だけ繰り返します。 入力が\(N_{\rm in}\)次元、出力が\(N_{\rm out}\)次元のとき、活性化関数は\(N_{\rm in} \times N_{\rm out}\)回だけ実行されます。 また、活性化関数は学習可能なので、\(N_{\rm in} \times N_{\rm out}\)個のそれぞれ異なる形状の活性化関数が存在します。 B-スプライン曲線で活性化関数を書くといっても、お絵描きするときのように2Dの曲線を書いて活性化関数にするのではなく、基底関数の線形和を計算するだけです。 具体的には \[ {\rm spline}(x) = \sum_i c_i B_i(x) \] となります。\(B_i\)はB-スプライン曲線で指定した点を使って曲線を描くときにどのように補間するか(どの割合で点の位置を混ぜるか)を計算する関数です。 それを学習可能な\(c_i\)で混ぜて活性化関数とするわけです。 活性化関数\(\phi(x)\)には\({\rm spline}(x)\)を直接そのまま使うのではなく \[\phi(x) = w (b(x) + {\rm spline}(x))\] を用います。ここで、 \[b(x)={\rm silu}(x)=\frac{x}{1+e^{-x}} \] です。\(w\)は学習可能な重みです。ただし、Githubで公開されているコードを読むと \[\phi(x) = w_{\rm base} b(x) + w_{\rm sp} {\rm spline}(x)\] が使われているように見えます(KANLayer.pyを参照)。 B-スプライン曲線についてはhttps://techblog.kayac.com/generate-curves-using-b-splineとかhttp://web.mit.edu/hyperbook/Patrikalakis-Maekawa-Cho/node17.htmlが分かりやすいです。 KANの面白いところは、B-スプライン曲線で作った活性化関数が \(x^2\)、\({\rm exp}(x)\)、\({\rm sin}(x)\)、\({\rm log}(x)\)、\({\rm sqrt}(x)\)、\({\rm abs}(x)\) のようなユーザーが指定できる関数に十分近い場合はそれに置き換えてしまうことができる点です。 任意の入力に対して出力が0になる活性化関数を正則化によって増やし、それらを除去していくと、有効な活性化関数が人間が理解できる程度に少なくなることがあります。このとき、入力\(x\)に対して出力を\(y=f(x)\)で計算できる場合、関数\(f\)をユーザーが指定した活性化関数を使って作った合成関数と線形和、例えば \(f(x) = 1.2 \times {\rm sin} (x^2 - 0.3) + 0.5\) のような人間がみて分かる数式で出力することができます。 論文はhttps://arxiv.org/abs/2404.19756で、 コードはhttps://github.com/KindXiaoming/pykanにあります。 ここではリビジョン e6078bc8 を使います。
MNISTで学習させてみる
KAN Layerを使ってモデルを作り、MNISTで学習させてみます。 すべてKAN Layerで作ることもできるのですが、MNISTの画像は28×28=784と大きく、これを入力として32次元のベクトルを出力するようにすると、1層だけで784×32=25088個ものB-スプライン曲線を学習することになります。実行自体はできるのですが、非常に遅いため、ここでは最初にConv2Dで次元数を減らしてからKAN Layerを利用することにします。 学習用のコードは以下のとおりです。なお、著者が公開しているpykanのKAN.pyをベースに色々書き換えているので、もとのコードのライセンスに従い、このコードの部分はMITライセンスとします。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 | import matplotlib.pyplot as plt import numpy as np import random import torch import torch.nn as nn import torch.nn.functional as F from torchvision import datasets, transforms from torch.utils.data import DataLoader from pykan.kan.KANLayer import KANLayer import sys def initialize_seed(seed=0): torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) class ConvMLP(nn.Module): def __init__(self, fc_layers: list[int], device): super().__init__() self.conv1 = nn.Conv2d(1, 8, kernel_size=5, stride=2, device=device) self.conv2 = nn.Conv2d(8, 16, kernel_size=5, stride=2, device=device) n_in = 16 fcs = [] for fc in fc_layers: fcs.append(nn.Linear(n_in, fc, device=device)) n_in = fc self.fcs = nn.ModuleList(fcs) def forward(self, x): x = self.conv2(F.relu(F.max_pool2d(self.conv1(x), 2))) x = x.reshape(x.shape[0], -1) for fc in self.fcs: x = fc(F.relu(x)) return x def update_grid_from_samples(self, x): pass def regularize(self, lambda_l1, lambda_entropy, lambda_coef, lambda_coefdiff, small_mag_threshold=1e-16, small_reg_factor=1.0): return 0.0 # Modified version of KAN in pykan/kan/KAN.py class ConvKAN(nn.Module): def __init__(self, width: list[int], grid=5, k=3, noise_scale=0.1, noise_scale_base=0.1, base_fun=torch.nn.SiLU(), bias_trainable=True, grid_eps=1.0, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, device="cpu"): super().__init__() ### Initialize feature extraction layers self.conv1 = nn.Conv2d(1, 8, kernel_size=5, stride=2, device=device) self.conv2 = nn.Conv2d(8, 16, kernel_size=5, stride=2, device=device) width.insert(0, 16) ### Initialize KAN layers self.biases = [] self.act_fun = [] self.depth = len(width) - 1 self.width = width for l in range(self.depth): # splines scale_base = 1 / np.sqrt(width[l]) + (torch.randn(width[l] * width[l + 1], ) * 2 - 1) * noise_scale_base sp_batch = KANLayer(in_dim=width[l], out_dim=width[l + 1], num=grid, k=k, noise_scale=noise_scale, scale_base=scale_base, scale_sp=1.0, base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, device=device) self.act_fun.append(sp_batch) # bias bias = nn.Linear(width[l + 1], 1, bias=False, device=device).requires_grad_(bias_trainable) bias.weight.data *= 0.0 self.biases.append(bias) self.biases = nn.ModuleList(self.biases) self.act_fun = nn.ModuleList(self.act_fun) def forward(self, x): # Extract features by conv x = self.conv2(F.relu(F.max_pool2d(self.conv1(x), 2))) x = x.reshape(x.shape[0], -1) # Run KAN layers self.acts = [x] # acts shape: (batch, width[l]) self.acts_scale = [] for l in range(self.depth): x, preacts, postacts, postspline = self.act_fun[l](x) grid_reshape = self.act_fun[l].grid.reshape(self.width[l + 1], self.width[l], -1) input_range = grid_reshape[:, :, -1] - grid_reshape[:, :, 0] + 1e-4 output_range = torch.mean(torch.abs(postacts), dim=0) self.acts_scale.append(output_range / input_range) x = x + self.biases[l].weight self.acts.append(x) return x def update_grid_from_samples(self, x): for l in range(self.depth): self.forward(x) self.act_fun[l].update_grid_from_samples(self.acts[l]) def regularize(self, lambda_l1, lambda_entropy, lambda_coef, lambda_coefdiff, small_mag_threshold=1e-16, small_reg_factor=1.0): def nonlinear(x, th, factor): return (x < th) * x * factor + (x > th) * (x + (factor - 1) * th) reg_ = 0. for i in range(len(self.acts_scale)): vec = self.acts_scale[i].reshape(-1, ) vec_sum = torch.sum(vec) if vec_sum == 0.0: continue p = vec / vec_sum l1 = torch.sum(nonlinear(vec, th=small_mag_threshold, factor=small_reg_factor)) entropy = - torch.sum(p * torch.log2(p + 1e-4)) reg_ += lambda_l1 * l1 + lambda_entropy * entropy # both l1 and entropy # regularize coefficient to encourage spline to be zero for i in range(len(self.act_fun)): coeff_l1 = torch.sum(torch.mean(torch.abs(self.act_fun[i].coef), dim=1)) coeff_diff_l1 = torch.sum(torch.mean(torch.abs(torch.diff(self.act_fun[i].coef)), dim=1)) reg_ += lambda_coef * coeff_l1 + lambda_coefdiff * coeff_diff_l1 return reg_ def calc_accuracy(ys): rs = [] for y, label in ys: rs.append((torch.argmax(y, dim=1) == label).float()) r = torch.cat(rs, dim=0) return torch.mean(r)*100.0 def train(model, train_loader, test_loader, max_epoch, lamb=0.0, lambda_l1=1.0, lambda_entropy=2.0, lambda_coef=0.0, lambda_coefdiff=0.0, update_grid=True, grid_update_freq=10, loss_fn=torch.nn.CrossEntropyLoss(), lr=0.002, device="cpu"): optimizer = torch.optim.Adam(model.parameters(), lr=lr) for epoch in range(max_epoch): model.train() n_samples = 0 max_samples = len(train_loader.dataset) ys = [] for iter, (x, label) in enumerate(train_loader): x = x.to(device) label = label.to(device) if iter % grid_update_freq == 0 and update_grid: model.update_grid_from_samples(x) y = model(x) ys.append((y, label)) loss = loss_fn(y, label) reg_ = model.regularize(lambda_l1, lambda_entropy, lambda_coef, lambda_coefdiff) loss = loss + lamb * reg_ optimizer.zero_grad() loss.backward() optimizer.step() n_samples += len(x) if iter % 100 == 0: print(f"Epoch: {epoch} [{n_samples}/{max_samples}] Loss: {loss.item():.6f}") # Calc train accuracy train_acc = calc_accuracy(ys) # Calc test accuracy model.eval() ys = [] with torch.no_grad(): for iter, (x, label) in enumerate(test_loader): x = x.to(device) label = label.to(device) y = model(x) ys.append((y, label)) test_acc = calc_accuracy(ys) print(f"Epoch: {epoch} [{n_samples}/{max_samples}] Loss: {loss.item():.6f} Acc(Train): {train_acc} Acc(Test): {test_acc}") return def main(mode): initialize_seed(123) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") reg_lambda = 0.0 update_grid = True, if mode == "kan": model = ConvKAN(width=[20, 10], device=device) elif mode == "kan-no-update-grid": model = ConvKAN(width=[20, 10], device=device) update_grid = False elif mode == "kan-reg": model = ConvKAN(width=[20, 10], device=device) reg_lambda = 0.003 elif mode == "mlp": model = ConvMLP(fc_layers=[20, 10], device=device) else: return train_loader = DataLoader(datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor()), batch_size=128, shuffle=True) test_loader = DataLoader(datasets.MNIST("./data", train=False, download=True, transform=transforms.ToTensor()), batch_size=128, shuffle=False) train(model, train_loader, test_loader, max_epoch=5, lamb=reg_lambda, update_grid=update_grid, device=device) if __name__ == "__main__": main(sys.argv[1]) |
ConvMLP
比較用としてMLPを使った単純なモデルで学習させてみた結果が以下です。
Epoch: 0 [128/60000] Loss: 2.330844 Epoch: 0 [12928/60000] Loss: 0.527279 Epoch: 0 [25728/60000] Loss: 0.264145 Epoch: 0 [38528/60000] Loss: 0.268547 Epoch: 0 [51328/60000] Loss: 0.270200 Epoch: 0 [60000/60000] Loss: 0.137060 Acc(Train): 82.13166809082031 Acc(Test): 93.66999816894531 Epoch: 1 [128/60000] Loss: 0.211759 Epoch: 1 [12928/60000] Loss: 0.172556 Epoch: 1 [25728/60000] Loss: 0.216685 Epoch: 1 [38528/60000] Loss: 0.224366 Epoch: 1 [51328/60000] Loss: 0.166860 Epoch: 1 [60000/60000] Loss: 0.297034 Acc(Train): 94.53333282470703 Acc(Test): 95.87000274658203 Epoch: 2 [128/60000] Loss: 0.169406 Epoch: 2 [12928/60000] Loss: 0.060187 Epoch: 2 [25728/60000] Loss: 0.065257 Epoch: 2 [38528/60000] Loss: 0.209189 Epoch: 2 [51328/60000] Loss: 0.137371 Epoch: 2 [60000/60000] Loss: 0.062734 Acc(Train): 95.69499969482422 Acc(Test): 96.29000091552734 Epoch: 3 [128/60000] Loss: 0.161130 Epoch: 3 [12928/60000] Loss: 0.094211 Epoch: 3 [25728/60000] Loss: 0.137475 Epoch: 3 [38528/60000] Loss: 0.143321 Epoch: 3 [51328/60000] Loss: 0.086744 Epoch: 3 [60000/60000] Loss: 0.295325 Acc(Train): 96.2300033569336 Acc(Test): 96.52999877929688 Epoch: 4 [128/60000] Loss: 0.084705 Epoch: 4 [12928/60000] Loss: 0.114923 Epoch: 4 [25728/60000] Loss: 0.071916 Epoch: 4 [38528/60000] Loss: 0.093224 Epoch: 4 [51328/60000] Loss: 0.111265 Epoch: 4 [60000/60000] Loss: 0.049419 Acc(Train): 96.69833374023438 Acc(Test): 97.1500015258789
ConvKAN
KANを使ったモデルで学習させてみた結果が以下です。Conv2Dが色々吸収してしまっているのかもしれませんが、違いがほとんどありません。 パラメータ数はConvMLPよりConvKANのほうが多いです。
Epoch: 0 [128/60000] Loss: 2.308975 Epoch: 0 [12928/60000] Loss: 0.608790 Epoch: 0 [25728/60000] Loss: 0.333690 Epoch: 0 [38528/60000] Loss: 0.214927 Epoch: 0 [51328/60000] Loss: 0.171603 Epoch: 0 [60000/60000] Loss: 0.099660 Acc(Train): 88.74166870117188 Acc(Test): 96.1500015258789 Epoch: 1 [128/60000] Loss: 0.073898 Epoch: 1 [12928/60000] Loss: 0.215137 Epoch: 1 [25728/60000] Loss: 0.109934 Epoch: 1 [38528/60000] Loss: 0.126619 Epoch: 1 [51328/60000] Loss: 0.066091 Epoch: 1 [60000/60000] Loss: 0.188135 Acc(Train): 96.20833587646484 Acc(Test): 96.88999938964844 Epoch: 2 [128/60000] Loss: 0.061085 Epoch: 2 [12928/60000] Loss: 0.078620 Epoch: 2 [25728/60000] Loss: 0.045636 Epoch: 2 [38528/60000] Loss: 0.052172 Epoch: 2 [51328/60000] Loss: 0.037537 Epoch: 2 [60000/60000] Loss: 0.149782 Acc(Train): 96.87333679199219 Acc(Test): 95.19000244140625 Epoch: 3 [128/60000] Loss: 0.040984 Epoch: 3 [12928/60000] Loss: 0.102282 Epoch: 3 [25728/60000] Loss: 0.017132 Epoch: 3 [38528/60000] Loss: 0.043684 Epoch: 3 [51328/60000] Loss: 0.126490 Epoch: 3 [60000/60000] Loss: 0.057794 Acc(Train): 97.1483383178711 Acc(Test): 97.47000122070312 Epoch: 4 [128/60000] Loss: 0.056742 Epoch: 4 [12928/60000] Loss: 0.087390 Epoch: 4 [25728/60000] Loss: 0.046712 Epoch: 4 [38528/60000] Loss: 0.058726 Epoch: 4 [51328/60000] Loss: 0.217805 Epoch: 4 [60000/60000] Loss: 0.137546 Acc(Train): 97.54500579833984 Acc(Test): 97.52999877929688
ConvKANでgridの更新なし
KANを使ったモデルでgridの更新なしで学習させてみた結果が以下です。B-スプライン曲線による活性化関数は処理できる入力値の範囲が決まっており、gridの更新なしというのは、その範囲の調整を行わないということです。今回の設定では特に効果が無いようです。
Epoch: 0 [128/60000] Loss: 2.308975 Epoch: 0 [12928/60000] Loss: 0.502888 Epoch: 0 [25728/60000] Loss: 0.330806 Epoch: 0 [38528/60000] Loss: 0.242808 Epoch: 0 [51328/60000] Loss: 0.165698 Epoch: 0 [60000/60000] Loss: 0.191306 Acc(Train): 86.51166534423828 Acc(Test): 95.06999969482422 Epoch: 1 [128/60000] Loss: 0.090638 Epoch: 1 [12928/60000] Loss: 0.238503 Epoch: 1 [25728/60000] Loss: 0.128466 Epoch: 1 [38528/60000] Loss: 0.166372 Epoch: 1 [51328/60000] Loss: 0.120421 Epoch: 1 [60000/60000] Loss: 0.113729 Acc(Train): 95.87166595458984 Acc(Test): 96.37999725341797 Epoch: 2 [128/60000] Loss: 0.094055 Epoch: 2 [12928/60000] Loss: 0.101774 Epoch: 2 [25728/60000] Loss: 0.053376 Epoch: 2 [38528/60000] Loss: 0.050028 Epoch: 2 [51328/60000] Loss: 0.049892 Epoch: 2 [60000/60000] Loss: 0.045015 Acc(Train): 96.88333129882812 Acc(Test): 96.44000244140625 Epoch: 3 [128/60000] Loss: 0.091558 Epoch: 3 [12928/60000] Loss: 0.116339 Epoch: 3 [25728/60000] Loss: 0.035658 Epoch: 3 [38528/60000] Loss: 0.047689 Epoch: 3 [51328/60000] Loss: 0.121902 Epoch: 3 [60000/60000] Loss: 0.066078 Acc(Train): 97.5 Acc(Test): 97.30999755859375 Epoch: 4 [128/60000] Loss: 0.069760 Epoch: 4 [12928/60000] Loss: 0.048372 Epoch: 4 [25728/60000] Loss: 0.042325 Epoch: 4 [38528/60000] Loss: 0.073130 Epoch: 4 [51328/60000] Loss: 0.130568 Epoch: 4 [60000/60000] Loss: 0.046441 Acc(Train): 97.72833251953125 Acc(Test): 97.43999481201172
ConvKANで正則化あり
KANを使ったモデルで正則化ありで実行してみます。正則化のロスの重み\(\lambda\)は0.003にしています。
Epoch: 0 [128/60000] Loss: 2.449388 Epoch: 0 [12928/60000] Loss: 0.862422 Epoch: 0 [25728/60000] Loss: 0.502983 Epoch: 0 [38528/60000] Loss: 0.569132 Epoch: 0 [51328/60000] Loss: 0.391648 Epoch: 0 [60000/60000] Loss: 0.360877 Acc(Train): 89.41166687011719 Acc(Test): 95.72999572753906 Epoch: 1 [128/60000] Loss: 0.345524 Epoch: 1 [12928/60000] Loss: 0.449618 Epoch: 1 [25728/60000] Loss: 0.375242 Epoch: 1 [38528/60000] Loss: 0.334639 Epoch: 1 [51328/60000] Loss: 0.313511 Epoch: 1 [60000/60000] Loss: 0.374076 Acc(Train): 96.25166320800781 Acc(Test): 96.95999908447266 Epoch: 2 [128/60000] Loss: 0.237053 Epoch: 2 [12928/60000] Loss: 0.721758 Epoch: 2 [25728/60000] Loss: 0.468227 Epoch: 2 [38528/60000] Loss: 0.489218 Epoch: 2 [51328/60000] Loss: 0.366972 Epoch: 2 [60000/60000] Loss: 0.461017 Acc(Train): 90.59166717529297 Acc(Test): 94.61000061035156 Epoch: 3 [128/60000] Loss: 0.391373 Epoch: 3 [12928/60000] Loss: 0.405158 Epoch: 3 [25728/60000] Loss: 0.296408 Epoch: 3 [38528/60000] Loss: 0.306373 Epoch: 3 [51328/60000] Loss: 0.362922 Epoch: 3 [60000/60000] Loss: 0.335211 Acc(Train): 94.8933334350586 Acc(Test): 95.55999755859375 Epoch: 4 [128/60000] Loss: 0.343693 Epoch: 4 [12928/60000] Loss: 0.300828 Epoch: 4 [25728/60000] Loss: 0.310633 Epoch: 4 [38528/60000] Loss: 0.402329 Epoch: 4 [51328/60000] Loss: 0.401533 Epoch: 4 [60000/60000] Loss: 0.272522 Acc(Train): 95.64000701904297 Acc(Test): 95.79000091552734\(\lambda=0.005\)にすると、以下のようになり途中でモデルが崩壊しました。
Epoch: 0 [128/60000] Loss: 2.542997 Epoch: 0 [12928/60000] Loss: 0.992837 Epoch: 0 [25728/60000] Loss: 0.666648 Epoch: 0 [38528/60000] Loss: 0.642962 Epoch: 0 [51328/60000] Loss: 0.496544 Epoch: 0 [60000/60000] Loss: 0.467141 Acc(Train): 88.85499572753906 Acc(Test): 95.80999755859375 Epoch: 1 [128/60000] Loss: 0.465318 Epoch: 1 [12928/60000] Loss: 0.551555 Epoch: 1 [25728/60000] Loss: 0.480575 Epoch: 1 [38528/60000] Loss: 0.532717 Epoch: 1 [51328/60000] Loss: 0.389326 Epoch: 1 [60000/60000] Loss: 0.464731 Acc(Train): 95.5183334350586 Acc(Test): 96.31999969482422 Epoch: 2 [128/60000] Loss: 0.360698 Epoch: 2 [12928/60000] Loss: 0.528383 Epoch: 2 [25728/60000] Loss: 0.377956 Epoch: 2 [38528/60000] Loss: 0.333951 Epoch: 2 [51328/60000] Loss: 0.379329 Epoch: 2 [60000/60000] Loss: 0.398343 Acc(Train): 94.8566665649414 Acc(Test): 96.06999969482422 Epoch: 3 [128/60000] Loss: 0.374904 Epoch: 3 [12928/60000] Loss: 0.928802 Epoch: 3 [25728/60000] Loss: 0.401091 Epoch: 3 [38528/60000] Loss: 0.539629 Epoch: 3 [51328/60000] Loss: 223.694611 Epoch: 3 [60000/60000] Loss: 3.871729 Acc(Train): 76.33499908447266 Acc(Test): 10.520000457763672 Epoch: 4 [128/60000] Loss: 4.122091 Epoch: 4 [12928/60000] Loss: 3.226183 Epoch: 4 [25728/60000] Loss: 2.973773 Epoch: 4 [38528/60000] Loss: 2.831389 Epoch: 4 [51328/60000] Loss: 2.789084 Epoch: 4 [60000/60000] Loss: 2.598794 Acc(Train): 9.819999694824219 Acc(Test): 10.09999942779541
0 件のコメント :
コメントを投稿