def main()

in example/mnist_with_meterlogger.py [0:0]


def main():
    """Train the functional MNIST model, logging metrics through MeterLogger.

    Builds a dict of raw parameter tensors, wraps each one in a
    ``Variable`` that requires gradients, and optimizes them with SGD
    (lr=0.01, momentum=0.9, weight_decay=0.0005). An ``Engine`` drives
    training for 10 epochs over ``get_iterator(True)``. After each training
    epoch it prints and resets the "Train" meters, runs a validation pass
    over ``get_iterator(False)``, then prints and resets the "Test" meters.
    """
    # NOTE(review): conv_init/linear_init argument meanings (presumably
    # in/out channels + kernel size, in/out features) are defined elsewhere.
    raw_tensors = {
        'conv0.weight': conv_init(1, 50, 5),
        'conv0.bias': torch.zeros(50),
        'conv1.weight': conv_init(50, 50, 5),
        'conv1.bias': torch.zeros(50),
        'linear2.weight': linear_init(800, 512),
        'linear2.bias': torch.zeros(512),
        'linear3.weight': linear_init(512, 10),
        'linear3.bias': torch.zeros(10),
    }
    params = {
        name: Variable(tensor, requires_grad=True)
        for name, tensor in raw_tensors.items()
    }

    optimizer = torch.optim.SGD(
        params.values(),
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0005,
    )

    engine = Engine()

    mlog = MeterLogger(nclass=10, title="mnist_meterlogger")

    def forward_pass(sample):
        # sample[0]: inputs (scaled by 1/255; presumably raw pixel values),
        # sample[1]: class labels, sample[2]: the train/eval flag appended by
        # tag_sample_with_mode below.
        scaled = sample[0].float() / 255.0
        inputs = Variable(scaled)
        targets = Variable(torch.LongTensor(sample[1]))
        logits = f(params, inputs, sample[2])
        loss = F.cross_entropy(logits, targets)
        return loss, logits

    def tag_sample_with_mode(state):
        # Forward the engine's train flag to the network via the sample list.
        batch, is_training = state['sample'], state['train']
        batch.append(is_training)

    def log_batch(state):
        # online ploter
        mlog.update_loss(state['loss'], meter='loss')
        mlog.update_meter(
            state['output'],
            state['sample'][1],
            meters={'accuracy', 'map', 'confusion'},
        )

    def begin_epoch(state):
        mlog.timer.reset()
        state['iterator'] = tqdm(state['iterator'])

    def flush_meters(mode, epoch):
        # Report the accumulated meters for `mode`, then clear them.
        mlog.print_meter(mode=mode, iepoch=epoch)
        mlog.reset_meter(mode=mode, iepoch=epoch)

    def finish_epoch(state):
        flush_meters("Train", state['epoch'])
        # do validation at the end of each epoch
        engine.test(forward_pass, get_iterator(False))
        flush_meters("Test", state['epoch'])

    hook_table = (
        ('on_sample', tag_sample_with_mode),
        ('on_forward', log_batch),
        ('on_start_epoch', begin_epoch),
        ('on_end_epoch', finish_epoch),
    )
    for event, callback in hook_table:
        engine.hooks[event] = callback

    engine.train(forward_pass, get_iterator(True), maxepoch=10,
                 optimizer=optimizer)