in torchnet/engine/engine.py [0:0]
def train(self, network, iterator, maxepoch, optimizer):
    """Run the training loop, firing hooks at every stage.

    A mutable ``state`` dict is threaded through the whole loop and passed
    to every hook, so hooks can inspect (or alter) the run. Hooks fire in
    this order::

        on_start
          on_start_epoch            (once per epoch)
            on_sample               (once per sample, before the forward)
            on_forward              (inside each closure evaluation,
                                     after ``loss.backward()``)
            on_update               (after ``optimizer.step``)
          on_end_epoch
        on_end

    Args:
        network: Callable invoked as ``network(sample)``; it must return a
            ``(loss, output)`` pair where ``loss`` has a ``backward()``
            method.
        iterator: Iterable of samples; it is iterated once per epoch.
        maxepoch: Training stops once ``state['epoch']`` reaches this value.
        optimizer: Object providing ``zero_grad()`` and ``step(closure)``.

    Returns:
        The ``state`` dict. ``state['output']`` and ``state['loss']`` are
        reset to ``None`` at the end of every closure evaluation, so they
        are only populated while the ``on_forward`` hook runs.
    """
    state = {
        'network': network,
        'iterator': iterator,
        'maxepoch': maxepoch,
        'optimizer': optimizer,
        'epoch': 0,
        't': 0,
        'train': True,
    }

    def closure():
        # Clear gradients inside the closure rather than once before
        # ``step``: an optimizer may evaluate the closure several times per
        # step, and each evaluation must start from zeroed gradients or
        # ``backward()`` would accumulate across evaluations.
        state['optimizer'].zero_grad()
        loss, output = state['network'](state['sample'])
        state['output'] = output
        state['loss'] = loss
        loss.backward()
        self.hook('on_forward', state)
        # to free memory in save_for_backward
        state['output'] = None
        state['loss'] = None
        return loss

    self.hook('on_start', state)
    # Everything is re-read from ``state`` on each pass so that changes made
    # by hooks (e.g. to 'maxepoch' or 'iterator') take effect.
    while state['epoch'] < state['maxepoch']:
        self.hook('on_start_epoch', state)
        for sample in state['iterator']:
            state['sample'] = sample
            self.hook('on_sample', state)
            state['optimizer'].step(closure)
            self.hook('on_update', state)
            state['t'] += 1
        state['epoch'] += 1
        self.hook('on_end_epoch', state)
    self.hook('on_end', state)
    return state