# torchbenchmark/models/maml/__init__.py
import numpy as np
import random
import time
import torch
from argparse import Namespace
from .meta import Meta
from pathlib import Path
from typing import Tuple
from ...util.model import BenchmarkModel
from torchbenchmark.tasks import OTHER
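
# Force deterministic cuDNN kernels and disable the autotuner so repeated
# benchmark runs are reproducible.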
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class Model(BenchmarkModel):
    task = OTHER.OTHER_TASKS
    DEFAULT_TRAIN_BSIZE = 1
    DEFAULT_EVAL_BSIZE = 1
    ALLOW_CUSTOMIZE_BSIZE = False

    def __init__(self, test, device, jit, batch_size=None, extra_args=[]):
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
        # load from disk or synthesize data
        use_data_file = False
        debug_print = False
        root = str(Path(__file__).parent)
        # MAML few-shot hyperparameters: 5-way, 1-shot, 15 query images per class.
        args = Namespace(**{
            'n_way': 5,               # classes per task
            'k_spt': 1,               # support (shot) images per class
            'k_qry': 15,              # query images per class
            'imgsz': 28,              # input image height/width
            'imgc': 1,                # input image channels
            'task_num': 32,           # meta-batch size (tasks per iteration)
            'meta_lr': 1e-3,          # outer-loop (meta) learning rate
            'update_lr': 0.4,         # inner-loop learning rate
            'update_step': 5,         # inner-loop adaptation steps at train time
            'update_step_test': 10,   # inner-loop adaptation steps at test time
        })
        # Learner architecture consumed by Meta: each entry is (op_name, op_params);
        # conv2d params follow [out_ch, in_ch, kernel_h, kernel_w, stride, padding].
        config = [
            ('conv2d', [64, args.imgc, 3, 3, 2, 0]),
            ('relu', [True]),
            ('bn', [64]),
            ('conv2d', [64, 64, 3, 3, 2, 0]),
            ('relu', [True]),
            ('bn', [64]),
            ('conv2d', [64, 64, 3, 3, 2, 0]),
            ('relu', [True]),
            ('bn', [64]),
            ('conv2d', [64, 64, 2, 2, 1, 0]),
            ('relu', [True]),
            ('bn', [64]),
            ('flatten', []),
            ('linear', [args.n_way, 64]),
        ]
        self.module = Meta(args, config).to(device)

        if use_data_file:
            # Load a pre-serialized meta-batch shipped alongside this file.
            self.example_inputs = torch.load(f'{root}/batch.pt')
            self.example_inputs = tuple(torch.from_numpy(i).to(self.device) for i in self.example_inputs)
        else:
            # Synthesize data parameterized by arg values:
            # (support images, support labels, query images, query labels).
            self.example_inputs = (
                torch.randn(args.task_num, args.n_way, args.imgc, args.imgsz, args.imgsz).to(device),
                torch.randint(0, args.n_way, [args.task_num, args.n_way], dtype=torch.long).to(device),
                torch.randn(args.task_num, args.n_way * args.k_qry, args.imgc, args.imgsz, args.imgsz).to(device),
                torch.randint(0, args.n_way, [args.task_num, args.n_way * args.k_qry], dtype=torch.long).to(device))

        if debug_print:
            # Print input shapes for quick sanity checks.
            for example_input in self.example_inputs:
                print(example_input.shape)

    def get_module(self):
        return self.module, self.example_inputs

    def eval(self, niter=1) -> Tuple[torch.Tensor]:
        for _ in range(niter):
            out = self.module(*self.example_inputs)
        return (out,)

    def train(self, niter=1):
        for _ in range(niter):
            self.module(*self.example_inputs)

    def eval_in_nograd(self):
        # MAML adapts via inner-loop gradient steps even at evaluation time,
        # so eval cannot run under torch.no_grad().
        return False
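

# --- Usage sketch (illustrative only, not part of the benchmark harness) ---
# A minimal way to exercise this model directly, assuming the constructor
# signature above; the harness normally supplies test/device/jit itself.
#
#   from torchbenchmark.models.maml import Model
#
#   m = Model(test="eval", device="cpu", jit=False)
#   module, example_inputs = m.get_module()
#   (out,) = m.eval(niter=1)    # one meta-batch of 32 synthetic 5-way tasks
#   m.train(niter=1)            # one training iteration on the same batch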