in modules/SwissArmyTransformer/sat/training/deepspeed_training.py [0:0]
def train_step(data_iterator, model, optimizer, lr_scheduler,
args, timers, hooks=None, single_step=False, **kwargs):
"""Single training step."""
if hooks is None:
hooks = {}
lm_loss_total, metrics_total, count, metrics_count = 0.0, {}, 0, {}
forward_step = hooks['forward_step']
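    # One call to train_step processes micro-batches until DeepSpeed completes a full
    # optimizer update (a gradient-accumulation boundary), or exactly one micro-batch
    # when single_step is set.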
while True:
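        # NVTX ranges (for profilers such as Nsight Systems) are emitted only once the
        # current iteration reaches args.profiling; args.profiling == -1 disables them.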
profiling_flag = (args.profiling != -1 and args.iteration >= args.profiling)
# Forward model for one step.
if profiling_flag:
torch.cuda.nvtx.range_push("forward")
timers('forward').start()
forward_ret = forward_step(data_iterator, model, args, timers, **kwargs)
if isinstance(forward_ret, tuple):
lm_loss, metrics = forward_ret
else:
lm_loss, metrics = forward_ret, {}
timers('forward').stop()
if profiling_flag:
torch.cuda.nvtx.range_pop()
        # Check for nan/inf in the forward results so they cannot interfere with the
        # loss scaler, and all-reduce the metrics along the way.
if profiling_flag:
torch.cuda.nvtx.range_push("loss_and_metrics")
lm_loss_reduced = lm_loss.detach().clone()
torch.distributed.all_reduce(lm_loss_reduced.data)
lm_loss_reduced.data = lm_loss_reduced.data / args.world_size
loss_checker = lm_loss_reduced
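        # For non-eval metrics, -100 is the sentinel for "not reported on this rank":
        # cnt counts the ranks that did report, the local value is zeroed so it does not
        # pollute the sum, and after the all-reduce the sum is averaged over the
        # reporting ranks only (or set back to -100 if no rank reported).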
for name in metrics:
            if 'eval' not in name:
metrics[name] = metrics[name].detach().clone()
if metrics[name].data.item() == -100:
cnt = torch.zeros(1, dtype=torch.int64, device=metrics[name].data.device)
metrics[name].data = torch.tensor(0., device=metrics[name].data.device)
else:
cnt = torch.ones(1, dtype=torch.int64, device=metrics[name].data.device)
torch.distributed.all_reduce(metrics[name].data)
torch.distributed.all_reduce(cnt)
if cnt.item() == 0:
metrics[name].data = torch.tensor(-100, device=metrics[name].data.device)
else:
                    metrics[name].data /= cnt.cpu().item()  # average over ranks that reported, not args.world_size
loss_checker = loss_checker + metrics[name]
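        # Bail out before backward if any reduced loss/metric is nan or inf, so bad values
        # never reach the loss scaler; the caller sees this as a skipped iteration.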
if loss_checker.isnan().any() or loss_checker.isinf().any():
            print_all('Skipping backward and optimizer step: found nan or inf in forward loss/metrics!')
return lm_loss.detach(), 1, metrics
# Accumulate the statistics
lm_loss_total += lm_loss_reduced
for name in metrics:
if name not in metrics_total:
metrics_total[name] = torch.tensor(0.0, device=metrics[name].data.device)
if name not in metrics_count:
metrics_count[name] = 0
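            # Only micro-batches that actually reported this metric (value != -100)
            # contribute to its running sum and count.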
if metrics[name].data.item() != -100:
metrics_total[name] += metrics[name]
metrics_count[name] += 1
count += 1
if profiling_flag:
torch.cuda.nvtx.range_pop()
if profiling_flag:
torch.cuda.nvtx.range_push("backward")
# Calculate gradients, reduce across processes, and clip.
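        # backward_step is defined elsewhere in this file; under DeepSpeed it is expected
        # to go through model.backward(), which also handles fp16 loss scaling.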
timers('backward').start()
backward_step(optimizer, model, lm_loss, args, timers)
timers('backward').stop()
if profiling_flag:
torch.cuda.nvtx.range_pop()
# Update parameters.
skipped_iter, complete = 0, False
if profiling_flag:
torch.cuda.nvtx.range_push("optimizer")
timers('optimizer').start()
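        # model.step() applies the optimizer update only at a gradient-accumulation
        # boundary; on intermediate micro-batches DeepSpeed just advances its internal
        # micro-step counter. The LR scheduler is stepped only if the fp16 loss scaler
        # did not overflow; otherwise the iteration is reported as skipped.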
if args.deepspeed:
if model.is_gradient_accumulation_boundary():
model.step()
complete = True
if not (args.fp16 and optimizer.overflow):
lr_scheduler.step()
else:
skipped_iter = 1
else:
model.step()
else:
raise ValueError('Currently, we only support training with deepspeed.')
timers('optimizer').stop()
if profiling_flag:
torch.cuda.nvtx.range_pop()
if complete or single_step:
break
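    # Average the accumulated loss over the number of micro-batches, and each metric over
    # the micro-batches in which it was reported; metrics never reported stay at -100.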
lm_loss_total /= count
    metrics_total = {
        key: torch.tensor(-100, device=value.data.device) if metrics_count[key] == 0
        else value / metrics_count[key]
        for key, value in metrics_total.items()
    }
return lm_loss_total, skipped_iter, metrics_total
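# Minimal usage sketch (hypothetical names; the real caller is the training loop in this
# file). The only required hook is 'forward_step', returning either a loss tensor or a
# (loss, metrics_dict) tuple:
#
#     hooks = {'forward_step': my_forward_step}
#     loss, skipped_iter, metrics = train_step(
#         train_data_iterator, model, optimizer, lr_scheduler,
#         args, timers, hooks=hooks)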