ViT4MNIST/mnist_vit.py:
import torch
import torch.nn.functional as F
import torchvision
import time
from vit_pytorch import ViT
from torch.utils.mobile_optimizer import optimize_for_mobile
torch.manual_seed(42)
DOWNLOAD_PATH = 'data/mnist'
BATCH_SIZE_TRAIN = 100
BATCH_SIZE_TEST = 1000
# 0.1307 and 0.3081 are the mean and std computed on the MNIST training set
transform_mnist = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))])

train_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=True, download=True,
                                       transform=transform_mnist)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE_TRAIN, shuffle=True)

test_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=False, download=True,
                                      transform=transform_mnist)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE_TEST, shuffle=True)
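
# Optional sanity check (my addition, not part of the original script): recompute
# the normalization statistics quoted above from the raw training images.
# MNIST exposes the raw uint8 images via .data; scaling by 255 matches ToTensor().
# Expect values close to mean~0.1307 and std~0.3081.
def _check_mnist_stats():
    pixels = train_set.data.float() / 255.0
    print('mean: {:.4f}  std: {:.4f}'.format(pixels.mean().item(), pixels.std().item()))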
def train_epoch(model, optimizer, data_loader, loss_history):
    total_samples = len(data_loader.dataset)
    model.train()

    for i, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        # log-softmax + NLL loss is equivalent to cross-entropy on the raw logits
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        # log progress and record the training loss every 100 batches
        if i % 100 == 0:
            print('[' + '{:5}'.format(i * len(data)) + '/' + '{:5}'.format(total_samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_loader)) + '%)] Loss: ' +
                  '{:6.4f}'.format(loss.item()))
            loss_history.append(loss.item())
def evaluate(model, data_loader, loss_history):
    model.eval()
    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0

    with torch.no_grad():
        for data, target in data_loader:
            output = F.log_softmax(model(data), dim=1)
            loss = F.nll_loss(output, target, reduction='sum')
            _, pred = torch.max(output, dim=1)

            total_loss += loss.item()
            # .item() converts the 0-dim tensor to a Python int for clean formatting
            correct_samples += pred.eq(target).sum().item()

    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    print('\nAverage test loss: ' + '{:.4f}'.format(avg_loss) +
          ' Accuracy:' + '{:5}'.format(correct_samples) + '/' +
          '{:5}'.format(total_samples) + ' (' +
          '{:4.2f}'.format(100.0 * correct_samples / total_samples) + '%)\n')
N_EPOCHS = 25
start_time = time.time()
model = ViT(image_size=28, patch_size=7, num_classes=10, channels=1,
            dim=64, depth=6, heads=8, mlp_dim=128)
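# Shape check (my addition): patch_size=7 splits each 28x28 image into
# (28/7)**2 = 16 patches, and the classification head emits one logit per class.
assert model(torch.zeros(1, 1, 28, 28)).shape == (1, 10)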
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
train_loss_history, test_loss_history = [], []
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    train_epoch(model, optimizer, train_loader, train_loss_history)
    evaluate(model, test_loader, test_loss_history)
print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')
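
# Optional (my addition, assumes matplotlib is installed): visualize the loss
# curves gathered above. Train points are logged every 100 batches, test points
# once per epoch, so they are drawn on separate axes.
def _plot_losses():
    import matplotlib.pyplot as plt
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.plot(train_loss_history)
    ax1.set_title('train loss (per 100 batches)')
    ax2.plot(test_loss_history)
    ax2.set_title('test loss (per epoch)')
    plt.show()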
# one final inference pass over the test set to sanity-check the trained model
# before export; the results themselves are discarded
with torch.no_grad():
    for data, target in test_loader:
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target, reduction='sum')
        _, pred = torch.max(output, dim=1)
# the original trained model
torch.save(model, "vit4mnist.pt")
model = torch.load("vit4mnist.pt")
model.eval()
# dynamically quantize the Linear layers: weights are stored as int8,
# activations are quantized on the fly at inference time
quantized_model = torch.quantization.quantize_dynamic(
    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
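
# Rough size comparison (my addition): since dynamic quantization stores Linear
# weights as int8, the serialized quantized model should be noticeably smaller.
# The helper name and temp path are mine, purely illustrative.
def _model_size_mb(m, path='tmp_size_check.pt'):
    import os
    torch.save(m.state_dict(), path)
    size_mb = os.path.getsize(path) / 1e6
    os.remove(path)
    return size_mb
print('fp32: {:.2f} MB -> int8: {:.2f} MB'.format(
    _model_size_mb(model), _model_size_mb(quantized_model)))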
# trace the quantized model with a dummy input of the expected shape (N, C, H, W)
dummy_input = torch.zeros(1, 1, 28, 28)
ts_model = torch.jit.trace(quantized_model, dummy_input)
optimized_torchscript_model = optimize_for_mobile(ts_model)
# the quantized, scripted, and optimized model
optimized_torchscript_model._save_for_lite_interpreter("app/src/main/assets/vit4mnist.ptl")
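
# Optional round-trip check (my addition): reload the exported file with the lite
# interpreter and run the dummy input through it. _load_for_lite_interpreter is a
# private PyTorch API, so treat this as a development-time check only.
from torch.jit.mobile import _load_for_lite_interpreter
lite_model = _load_for_lite_interpreter("app/src/main/assets/vit4mnist.ptl")
print(lite_model(dummy_input).argmax(dim=1))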