PyTorchでMNISTを動かしてみました。CPUのみで動作します。
#
# Ref: https://github.com/pytorch/examples/blob/master/mnist/main.py
# Ref: https://qiita.com/ryu1104/items/76126a1d2ce22c59fe97
#
# Requirements:
# pyenvでpythonをインストールするときはliblzma-devが必要。
# pip install pylzma
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class Model(nn.Module):
    """Small CNN classifier for 28x28 single-channel MNIST digits.

    Two strided 5x5 convolutions shrink the image 28 -> 12 -> 4 spatially
    (so the flattened feature vector is 128 * 4 * 4 = 2048), followed by
    two fully-connected layers producing raw logits for the 10 classes.
    Dropout is applied after the first fully-connected layer.
    """

    def __init__(self):
        super().__init__()
        # 1x28x28 -> 64x12x12 -> 128x4x4
        self.conv1 = nn.Conv2d(1, 64, (5, 5), stride=(2, 2))
        self.conv2 = nn.Conv2d(64, 128, (5, 5), stride=(2, 2))
        self.linear1 = nn.Linear(2048, 256)
        self.linear2 = nn.Linear(256, 10)
        self.dropout1 = nn.Dropout(0.5)

    def forward(self, x):
        """Return raw class logits of shape (batch, 10) for input images x."""
        slope = 0.02  # negative slope shared by every LeakyReLU below
        h = F.leaky_relu(self.conv1(x), negative_slope=slope)
        h = F.leaky_relu(self.conv2(h), negative_slope=slope)
        h = torch.flatten(h, 1)
        h = F.leaky_relu(self.linear1(h), negative_slope=slope)
        h = self.dropout1(h)
        return self.linear2(h)
def main():
torch.manual_seed(123)
dataset_train = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
dataset_test = datasets.MNIST(root="./data", train=False, download=True, transform=transforms.ToTensor())
loader_train = DataLoader(dataset_train, batch_size=32, shuffle=True)
loader_test = DataLoader(dataset_test, batch_size=128)
model = Model()
opt = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(10):
model.train()
loss_train = 0
for batch_index, (x, t) in enumerate(loader_train):
opt.zero_grad()
y = model(x)
loss = F.cross_entropy(y, t, reduction="sum")
loss_train += loss.item()
loss.backward()
opt.step()
loss_train /= len(loader_train.dataset)
# Test
model.eval()
loss_test = 0
correct = 0
with torch.no_grad():
for x, t in loader_test:
y = model(x)
loss_test += F.cross_entropy(y, t, reduction="sum").item()
pred = y.argmax(dim=1, keepdim=True)
a = t.view_as(pred)
correct += pred.eq(a).sum().item()
loss_test /= len(loader_test.dataset)
acc_test = 100.0 * correct / len(loader_test.dataset)
print("epoch={} loss_train={} loss_test={} acc_test={}".format(epoch, loss_train, loss_test, acc_test))
if __name__ == "__main__":
main()
実行結果は以下の通りとなります。

epoch=0 loss_train=0.15844096369811644 loss_test=0.04875367822442204 acc_test=98.39
epoch=1 loss_train=0.05745177720999345 loss_test=0.03826701421057806 acc_test=98.81
epoch=2 loss_train=0.04121023142867489 loss_test=0.03071835657870397 acc_test=98.93
epoch=3 loss_train=0.03073415454996381 loss_test=0.031469110992277276 acc_test=99.05
epoch=4 loss_train=0.024984311143002317 loss_test=0.033686953871918376 acc_test=99.04
epoch=5 loss_train=0.02134333044563282 loss_test=0.04148742442613293 acc_test=98.79
epoch=6 loss_train=0.017344313688603386 loss_test=0.043801980772903655 acc_test=98.99
epoch=7 loss_train=0.015290370488148755 loss_test=0.04075671738231176 acc_test=99.09
epoch=8 loss_train=0.0152512503207066 loss_test=0.04280102529609985 acc_test=99.0
epoch=9 loss_train=0.015672046943081695 loss_test=0.043737064159646434 acc_test=98.98
0 件のコメント :
コメントを投稿