in example/mnist_with_meterlogger.py [0:0]
def main():
    """Train the MNIST network defined by ``f`` with a MeterLogger attached.

    The model weights live in a flat ``params`` dict that is handed to ``f``
    together with each input batch. ``Engine`` drives the loop and invokes
    the hooks registered at the bottom of this function. After every training
    epoch a full pass over the test iterator is run, and the train and test
    meters are printed and reset.
    """
    # Initial weight/bias tensors. Key order and call order are kept fixed:
    # the *_init helpers may consume RNG state, and the optimizer receives
    # the parameters in this order.
    tensors = {
        'conv0.weight': conv_init(1, 50, 5),
        'conv0.bias': torch.zeros(50),
        'conv1.weight': conv_init(50, 50, 5),
        'conv1.bias': torch.zeros(50),
        'linear2.weight': linear_init(800, 512),
        'linear2.bias': torch.zeros(512),
        'linear3.weight': linear_init(512, 10),
        'linear3.bias': torch.zeros(10),
    }
    # Wrap each tensor so autograd tracks gradients for it.
    params = {
        name: Variable(tensor, requires_grad=True)
        for name, tensor in tensors.items()
    }

    optimizer = torch.optim.SGD(
        params.values(),
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0005,
    )
    engine = Engine()
    mlog = MeterLogger(nclass=10, title="mnist_meterlogger")

    def forward_pass(sample):
        """Return ``(loss, logits)`` for one batch.

        ``sample[0]`` holds the raw inputs (scaled by 1/255 here), ``sample[1]``
        the class labels, and ``sample[2]`` the train/test flag appended by
        the ``on_sample`` hook.
        """
        images, labels, is_train = sample[0], sample[1], sample[2]
        inputs = Variable(images.float() / 255.0)
        targets = Variable(torch.LongTensor(labels))
        logits = f(params, inputs, is_train)
        return F.cross_entropy(logits, targets), logits

    def tag_sample(state):
        # Pass the engine's train/test flag along with the batch so that
        # forward_pass can forward it to f.
        state['sample'].append(state['train'])

    def record_batch(state):
        # Online plotter: feed the batch loss and predictions into the meters.
        mlog.update_loss(state['loss'], meter='loss')
        mlog.update_meter(
            state['output'],
            state['sample'][1],
            meters={'accuracy', 'map', 'confusion'},
        )

    def start_epoch(state):
        mlog.timer.reset()
        # Show a progress bar over the training batches.
        state['iterator'] = tqdm(state['iterator'])

    def report(mode, epoch):
        # Print the accumulated meters for ``mode``, then clear them.
        mlog.print_meter(mode=mode, iepoch=epoch)
        mlog.reset_meter(mode=mode, iepoch=epoch)

    def end_epoch(state):
        report("Train", state['epoch'])
        # Validate on the held-out iterator at the end of every epoch.
        engine.test(forward_pass, get_iterator(False))
        report("Test", state['epoch'])

    for event, handler in (
        ('on_sample', tag_sample),
        ('on_forward', record_batch),
        ('on_start_epoch', start_epoch),
        ('on_end_epoch', end_epoch),
    ):
        engine.hooks[event] = handler

    engine.train(
        forward_pass, get_iterator(True), maxepoch=10, optimizer=optimizer)