# example/mnist.py
def main():
    """Train a small CNN on MNIST with a torchnet Engine, printing per-epoch
    training metrics and running a validation pass after every epoch."""
    # Parameter tensors: two conv layers followed by two linear layers.
    weights = {
        'conv0.weight': conv_init(1, 50, 5), 'conv0.bias': torch.zeros(50),
        'conv1.weight': conv_init(50, 50, 5), 'conv1.bias': torch.zeros(50),
        'linear2.weight': linear_init(800, 512), 'linear2.bias': torch.zeros(512),
        'linear3.weight': linear_init(512, 10), 'linear3.bias': torch.zeros(10),
    }
    # Wrap every tensor so autograd tracks gradients for the optimizer.
    weights = {name: Variable(t, requires_grad=True) for name, t in weights.items()}

    sgd = torch.optim.SGD(
        weights.values(), lr=0.01, momentum=0.9, weight_decay=0.0005)

    trainer = Engine()
    loss_meter = tnt.meter.AverageValueMeter()
    err_meter = tnt.meter.ClassErrorMeter(accuracy=True)

    def forward_pass(sample):
        # sample is (images, labels, train_flag); the flag is appended
        # by the on_sample hook and forwarded to the model function f.
        images = Variable(sample[0].float() / 255.0)
        labels = Variable(torch.LongTensor(sample[1]))
        logits = f(weights, images, sample[2])
        return F.cross_entropy(logits, labels), logits

    def clear_meters():
        err_meter.reset()
        loss_meter.reset()

    def on_sample(state):
        # Tag each sample with the engine's current train/eval mode.
        state['sample'].append(state['train'])

    def on_forward(state):
        # Accumulate accuracy and average loss over the epoch.
        err_meter.add(state['output'].data,
                      torch.LongTensor(state['sample'][1]))
        loss_meter.add(state['loss'].data[0])

    def on_start_epoch(state):
        clear_meters()
        state['iterator'] = tqdm(state['iterator'])

    def on_end_epoch(state):
        print('Training loss: %.4f, accuracy: %.2f%%' % (loss_meter.value()[0], err_meter.value()[0]))
        # do validation at the end of each epoch
        clear_meters()
        trainer.test(forward_pass, get_iterator(False))
        print('Testing loss: %.4f, accuracy: %.2f%%' % (loss_meter.value()[0], err_meter.value()[0]))

    trainer.hooks['on_sample'] = on_sample
    trainer.hooks['on_forward'] = on_forward
    trainer.hooks['on_start_epoch'] = on_start_epoch
    trainer.hooks['on_end_epoch'] = on_end_epoch
    trainer.train(forward_pass, get_iterator(True), maxepoch=10, optimizer=sgd)