in train/model.py [0:0]
def fit(self, train_iter, optimizer, lr_scheduler,
eval_iter=None,
metrics=metric.Accuracy(topk=1),
epoch_start=0,
epoch_end=10000,
precise_bn=False,
precise_bn_steps=500,
epoch_div_factor=(torch.distributed.get_world_size() if (torch.distributed.is_available() and torch.distributed.is_initialized()) else 1),
**kwargs):
"""
checking
"""
if kwargs:
logging.warning("Unknown kwargs: {}".format(kwargs))
assert torch.cuda.is_available(), "only the GPU version is supported"
"""
setup iterator
"""
precise_bn_steps = 0 if not precise_bn else precise_bn_steps
epoch_freeze_step = int(round(0.2*precise_bn_steps))
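# budget the per-epoch steps: the iterator length is split evenly across workers
# (epoch_div_factor defaults to the distributed world size), and when precise BN is
# enabled the last precise_bn_steps iterations of each epoch are reserved for it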
epoch_train_steps = int(len(train_iter.batch_sampler) / epoch_div_factor)
epoch_eval_steps = int(len(eval_iter.batch_sampler) / epoch_div_factor) if eval_iter is not None else 0
if (len(train_iter.batch_sampler) - epoch_train_steps) > precise_bn_steps:
# train iter is sufficient
epoch_term_steps = epoch_train_steps + precise_bn_steps
else:
epoch_term_steps = epoch_train_steps
epoch_train_steps = epoch_train_steps - precise_bn_steps
assert epoch_train_steps > 0, "train_steps < precise_bn_steps ({} vs. {})".format(epoch_train_steps, precise_bn_steps)
logging.warning(">> using the last {} iterations for computing the precise batchnorm.".format(precise_bn_steps))
"""
start the main loop
"""
for i_epoch in range(epoch_start, epoch_end):
self.callback_kwargs['epoch'] = i_epoch
epoch_start_time = time.time()
###########
# 1] TRAINING
###########
metrics.reset()
self.net.train()
sum_batch_inst = 0
sum_batch_elapse = 0.
sum_update_elapse = 0.
batch_start_time = time.time()
logging.info("Start epoch {:d}, iter stride {:d}, train steps {:d}, eval steps: {:d}".format( \
i_epoch, epoch_div_factor, epoch_train_steps, epoch_eval_steps))
for i_batch, (data, target) in enumerate(train_iter):
if i_batch >= epoch_term_steps:
break
if precise_bn and i_batch == epoch_train_steps:
logging.info("Compute precise batchnorm: {} to {}.".format(epoch_train_steps, epoch_term_steps))
# TODO: better way to sync running_mean / running_var
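# current workaround: dump a checkpoint, wait for the file to appear on disk, then
# reload it so every replica enters the precise-BN phase with the same
# running_mean / running_var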
self.save_checkpoint(epoch=i_epoch+1, optimizer_state=optimizer.state_dict())
while not os.path.exists(self.get_checkpoint_path(epoch=i_epoch+1)):
time.sleep(1)
time.sleep(5)
self.load_checkpoint(epoch=i_epoch+1)
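# reset metrics so the precise-BN iterations are reported separately (the 'bn-' prefix below)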
metrics.reset()
self.callback_kwargs['batch'] = i_batch
update_start_time = time.time()
# [forward] making next step
outputs, losses = self.forward(data, target)
# [backward]
if i_batch < epoch_train_steps:
optimizer.zero_grad()
for loss in losses: loss.backward()
self.adjust_learning_rate(optimizer=optimizer, lr=lr_scheduler.update())
optimizer.step()
elif i_batch < (epoch_term_steps - epoch_freeze_step):
# for precise bn (stage 1)
optimizer.zero_grad()
for loss in losses: loss.backward()
self.adjust_learning_rate(optimizer=optimizer, lr=lr_scheduler.get_lr())
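# note: `visiable=["precise.bn"]` is a hook of the project's custom optimizer (not part of
# torch.optim); it presumably restricts this update to the precise-BN related parameters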
optimizer.step(visiable=["precise.bn"])
else:
# for precise bn (stage 2)
# update running mean/var only (done in the forward pass)
pass
self.callback_kwargs['lr'] = lr_scheduler.get_lr()
# [evaluation] update train metric
metrics.update([output.data.cpu() for output in outputs],
target.cpu(),
[loss.data.cpu() for loss in losses])
# timing each batch
sum_batch_elapse += time.time() - batch_start_time
sum_update_elapse += time.time() - update_start_time
sum_batch_inst += 1
if (i_batch % self.step_callback_freq) == 0:
name_value_prefix = 'tr-' if i_batch < epoch_train_steps else 'bn-'
self.callback_kwargs['namevals'] = metrics.get_name_value(prefix=name_value_prefix)
metrics.reset()
# speed monitor
self.callback_kwargs['batch_elapse'] = sum_batch_elapse / sum_batch_inst
self.callback_kwargs['update_elapse'] = sum_update_elapse / sum_batch_inst
sum_update_elapse = 0
sum_batch_elapse = 0
sum_batch_inst = 0
# callbacks
self.step_end_callback()
# save checkpoint in case of unexpected interrupt
if (i_batch % 500) == 0 and i_batch < epoch_train_steps:
self.callback_kwargs['epoch_elapse'] = time.time() - epoch_start_time
self.callback_kwargs['optimizer_dict'] = optimizer.state_dict()
self.epoch_end_callback()
# end of current train iter
batch_start_time = time.time()
###########
# 2] END OF EPOCH
###########
self.callback_kwargs['epoch_elapse'] = time.time() - epoch_start_time
self.callback_kwargs['optimizer_dict'] = optimizer.state_dict()
self.epoch_end_callback()
###########
# 3] Evaluation
###########
if (eval_iter is not None) \
and ((i_epoch+1) % max(1, int(self.save_checkpoint_freq/2))) == 0:
logging.info("Start evaluating epoch {:d}:".format(i_epoch))
metrics.reset()
self.net.eval()
sum_batch_elapse = 0.
sum_batch_inst = 0
sum_forward_elapse = 0.
with torch.no_grad():
batch_start_time = time.time()
for i_batch, (data, target) in enumerate(eval_iter):
forward_start_time = time.time()
outputs, losses = self.forward(data, target)
metrics.update([output.data.cpu() for output in outputs],
target.cpu(),
[loss.data.cpu() for loss in losses])
sum_forward_elapse += time.time() - forward_start_time
sum_batch_elapse += time.time() - batch_start_time
batch_start_time = time.time()
sum_batch_inst += 1
if i_batch >= epoch_eval_steps:
break
# evaluation callbacks
self.callback_kwargs['batch'] = sum_batch_inst
self.callback_kwargs['batch_elapse'] = sum_batch_elapse / sum_batch_inst
self.callback_kwargs['update_elapse'] = sum_forward_elapse / sum_batch_inst
self.callback_kwargs['namevals'] = metrics.get_name_value(prefix='ts-')
self.step_end_callback()
logging.info("Optimization done!")