ViT4MNIST/mnist_vit.py:

import torch
import torch.nn.functional as F
import torchvision
import time
from vit_pytorch import ViT
from torch.utils.mobile_optimizer import optimize_for_mobile

torch.manual_seed(42)

DOWNLOAD_PATH = 'data/mnist'
BATCH_SIZE_TRAIN = 100
BATCH_SIZE_TEST = 1000

# 0.1307 and 0.3081 are the mean and std computed on the MNIST training set
transform_mnist = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))])

train_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=True, download=True,
                                       transform=transform_mnist)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE_TRAIN,
                                           shuffle=True)

test_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=False, download=True,
                                      transform=transform_mnist)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE_TEST,
                                          shuffle=True)


def train_epoch(model, optimizer, data_loader, loss_history):
    total_samples = len(data_loader.dataset)
    model.train()

    for i, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('[' + '{:5}'.format(i * len(data)) + '/' + '{:5}'.format(total_samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_loader)) + '%)] Loss: ' +
                  '{:6.4f}'.format(loss.item()))
            loss_history.append(loss.item())


def evaluate(model, data_loader, loss_history):
    model.eval()

    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0

    with torch.no_grad():
        for data, target in data_loader:
            output = F.log_softmax(model(data), dim=1)
            loss = F.nll_loss(output, target, reduction='sum')
            _, pred = torch.max(output, dim=1)

            total_loss += loss.item()
            correct_samples += pred.eq(target).sum()

    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    print('\nAverage test loss: ' + '{:.4f}'.format(avg_loss) +
          ' Accuracy:' + '{:5}'.format(correct_samples) + '/' +
          '{:5}'.format(total_samples) + ' (' +
          '{:4.2f}'.format(100.0 * correct_samples / total_samples) + '%)\n')


N_EPOCHS = 25

start_time = time.time()
model = ViT(image_size=28, patch_size=7, num_classes=10, channels=1,
            dim=64, depth=6, heads=8, mlp_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

train_loss_history, test_loss_history = [], []
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    train_epoch(model, optimizer, train_loader, train_loss_history)
    evaluate(model, test_loader, test_loss_history)

print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')

# one final pass over the test set (the outputs are not used further)
with torch.no_grad():
    for data, target in test_loader:
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target, reduction='sum')
        _, pred = torch.max(output, dim=1)

# the original trained model
torch.save(model, "vit4mnist.pt")
model = torch.load("vit4mnist.pt")
model.eval()

# dynamically quantize the Linear layers to int8, then trace to TorchScript
# and apply the mobile-specific graph optimizations
quantized_model = torch.quantization.quantize_dynamic(
    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
dummy_input = torch.zeros(1, 1, 28, 28)
ts_model = torch.jit.trace(quantized_model, dummy_input)
optimized_torchscript_model = optimize_for_mobile(ts_model)

# the quantized, scripted, and optimized model
optimized_torchscript_model._save_for_lite_interpreter("app/src/main/assets/vit4mnist.ptl")
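A quick way to sanity-check the conversion (a minimal sketch, not part of the original script): with `model`, `optimized_torchscript_model`, and `test_loader` still in scope after the script above, run one test batch through both the float model and the quantized, mobile-optimized TorchScript module and compare their predictions. On Android, the saved .ptl file itself is loaded with org.pytorch.LiteModuleLoader.

# Sanity check (sketch, assumes the script above has just run): compare the
# float model against the mobile-optimized quantized module on one test batch.
with torch.no_grad():
    data, target = next(iter(test_loader))
    float_pred = model(data).argmax(dim=1)
    mobile_pred = optimized_torchscript_model(data).argmax(dim=1)
    # fraction of predictions unchanged by quantization, and quantized accuracy
    agreement = float_pred.eq(mobile_pred).float().mean().item()
    accuracy = mobile_pred.eq(target).float().mean().item()
    print('float vs. quantized agreement: {:.4f}, quantized accuracy: {:.4f}'
          .format(agreement, accuracy))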