in modules/SwissArmyTransformer/sat/training/deepspeed_training.py
def train(model, optimizer, lr_scheduler,
          train_data, val_data, timers, args,
          summary_writer=None, hooks=None):
    """Train the model.

    Returns (iteration, skipped_iters): the final iteration count and the
    number of training steps that were skipped.
    """
    # Avoid the mutable-default-argument pitfall for `hooks`.
    if hooks is None:
        hooks = {}
    if train_data is not None:
        train_data_iterator = iter(train_data)
    else:
        train_data_iterator = None
    if val_data is not None:
        val_data_iterator = iter(val_data)
    else:
        val_data_iterator = None
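    # NOTE: train_step pulls batches from these iterators; the loaders are
    # assumed to cycle (or be long enough) to cover args.train_iters steps.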
    # Turn on training mode, which enables dropout.
    model.train()
    # Track the language-model loss and any auxiliary metrics.
    total_lm_loss = 0.0
    total_metrics = defaultdict(float)
    total_metrics_cnt = defaultdict(int)
    # Iteration bookkeeping.
    skipped_iters = 0
    timers('interval time').start()
    report_memory_flag = True
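    # Main training loop. With args.profiling set to an iteration index, CUDA
    # profiler capture starts at that iteration and each subsequent step is
    # wrapped in an NVTX range (for tools such as Nsight Systems).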
    while args.iteration < args.train_iters:
        if args.profiling != -1 and args.iteration == args.profiling:
            torch.cuda.cudart().cudaProfilerStart()
        if args.profiling != -1 and args.iteration >= args.profiling:
            torch.cuda.nvtx.range_push("iteration{}".format(args.iteration))
        lm_loss, skipped_iter, metrics = train_step(train_data_iterator,
                                                    model,
                                                    optimizer,
                                                    lr_scheduler,
                                                    args, timers, hooks=hooks)
        # skipped_iter counts whether train_step skipped the update
        # (e.g. on an fp16 loss-scale overflow).
        skipped_iters += skipped_iter
        if args.profiling != -1 and args.iteration >= args.profiling:
            torch.cuda.nvtx.range_pop()
        args.iteration += 1
        # Update losses.
        total_lm_loss += lm_loss.data.detach().float()
        for name in metrics:
            if 'eval' not in name:
                assert len(metrics[name].shape) == 0, \
                    'non-eval metrics must be scalar tensors'
                value = metrics[name].data.detach().float().item()
                # Values of -99 or below act as a sentinel and are skipped,
                # so a metric only counts on steps where it was produced.
                if value > -99:
                    total_metrics[name] += value
                    total_metrics_cnt[name] += 1
        # Logging.
        if args.iteration % args.log_interval == 0:
            learning_rate = optimizer.param_groups[0]['lr']
            avg_lm_loss = total_lm_loss.item() / args.log_interval
            # Average each auxiliary metric over the steps that reported it
            # (not over log_interval, since some metrics can be absent).
            avg_metrics = {}
            for key in total_metrics:
                avg_metrics[key] = total_metrics[key] / total_metrics_cnt[key]
            elapsed_time = timers('interval time').elapsed()
            report_iteration_metrics(summary_writer, optimizer, learning_rate, avg_lm_loss,
                                     elapsed_time * 1000.0 / args.log_interval,
                                     args.iteration, args.train_iters, args,
                                     avg_metrics)
            total_lm_loss = 0.0
            total_metrics = defaultdict(float)
            total_metrics_cnt = defaultdict(int)
            if report_memory_flag:
                report_memory('after {} iterations'.format(args.iteration))
                report_memory_flag = False
            timers.log(['forward', 'backward', 'allreduce', 'optimizer',
                        'batch generator', 'data loader'],
                       normalizer=args.log_interval)
        # Checkpointing.
        if args.save and args.save_interval and args.iteration % args.save_interval == 0:
            save_checkpoint(args.iteration, model, optimizer, lr_scheduler, args)
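        # Note: save_checkpoint is assumed to handle its own cross-rank
        # coordination; this loop only gates it on args.save and the interval.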
        # Evaluation.
        if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
            if args.strict_eval:
                # Sweep the full validation set instead of a fixed
                # number of batches.
                val_data_iterator = iter(val_data)
                eval_iters = len(val_data)
            else:
                eval_iters = args.eval_iters
            prefix = 'iteration {}'.format(args.iteration)
            evaluate_and_print_results(
                prefix, val_data_iterator, model, eval_iters, args, timers, False,
                step=args.iteration, split='val', summary_writer=summary_writer, hooks=hooks)
        if args.exit_interval and args.iteration % args.exit_interval == 0:
            # Synchronize all ranks, then terminate the whole job.
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_all('rank: {} | time: {} | exiting the program at iteration {}'.format(
                rank, time_str, args.iteration), flush=True)
            exit()
    if args.profiling != -1:
        torch.cuda.cudart().cudaProfilerStop()
    return args.iteration, skipped_iters
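
# A minimal wiring sketch (illustrative assumptions: `get_args` and `Timers`
# are taken to come from SAT's training utilities, and the model, optimizer,
# lr_scheduler, and data loaders are built elsewhere in this module's setup):
#
#   args = get_args()
#   timers = Timers()
#   # ... build model, optimizer, lr_scheduler, train_data, val_data ...
#   iteration, skipped_iters = train(model, optimizer, lr_scheduler,
#                                    train_data, val_data, timers, args,
#                                    summary_writer=None, hooks=None)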