in src/accelerate/utils/megatron_lm.py [0:0]
def forward(self, **batch_data):
    # During training, we use train_step().
    # model(**batch_data) performs the following operations by delegating them to `self.train_step`:
    # 1. Prepare **batch_data for Tensor, Pipeline and Model Parallelism.
    # 2. Set grads to zero.
    # 3. Forward pass and backward pass using Pipeline Parallelism.
    # 4. Empty unused memory.
    # 5. Reduce gradients.
    # 6. Update parameters.
    # 7. Gather params when using Distributed Optimizer (Data Parallelism).
    # 8. Update learning rate if a scheduler is specified.
    # 9. Empty unused memory.
    # 10. Average loss across microbatches and across DP ranks.
    #
    # During evaluation, we use eval_step().
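    #
    # Minimal usage sketch (illustrative only; assumes Accelerate was set up with its
    # Megatron-LM plugin, and none of these calls are part of this file):
    #
    #   model, optimizer, scheduler, dataloader = accelerator.prepare(model, optimizer, scheduler, dataloader)
    #   for batch in dataloader:
    #       outputs = model(**batch)  # lands in this forward(); a full train/eval step runs inside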
    args = get_args()
    if self.module[0].training:
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = self.train_step(**batch_data)
        self.iteration += 1
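        # Global batch consumed this iteration: each data-parallel rank processes
        # `micro_batch_size` samples for each of `get_num_microbatches()` microbatches.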
        batch_size = mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
        args.consumed_train_samples += batch_size
        self.num_floating_point_operations_so_far += num_floating_point_operations(args, batch_size)
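        # `num_floating_point_operations` is Megatron-LM's analytic per-iteration FLOPs
        # estimate; the running total lets the trainer track cumulative compute.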
        if args.tensorboard_dir is not None:
            # Logging.
            loss_scale = self.optimizer.get_loss_scale().item()
            params_norm = None
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(self.model)
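            # `training_log` is Megatron-LM's training-loop logger: it writes these metrics to
            # tensorboard/stdout and returns an updated `report_memory_flag` so the one-time
            # memory report is not re-emitted on every iteration.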
            self.report_memory_flag = training_log(
                loss_dict,
                self.total_loss_dict,
                self.optimizer.param_groups[0]["lr"],
                self.iteration,
                loss_scale,
                self.report_memory_flag,
                skipped_iter,
                grad_norm,
                params_norm,
                num_zeros_in_grad,
            )
    else:
        loss_dict = self.eval_step(**batch_data)
        if args.tensorboard_dir is not None:
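            # Keep a per-key running sum of eval losses plus an iteration count, so that a
            # per-key mean (sum / num_iters) can be reported later.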
            for key in loss_dict:
                self.eval_total_loss_dict[key] = (
                    self.eval_total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
                )
                self.eval_total_loss_dict[key + "_num_iters"] = self.eval_total_loss_dict.get(
                    key + "_num_iters", torch.cuda.FloatTensor([0.0])
                ) + torch.cuda.FloatTensor([1.0])
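    # Fold every scalar (0-dim) entry of `loss_dict` into a single loss; tensor-valued
    # entries such as `logits` are skipped by the shape check below.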
    loss = torch.tensor(0.0, device=torch.cuda.current_device())
    for key in loss_dict:
        if len(loss_dict[key].shape) == 0:
            loss += loss_dict[key]
    logits = None
    if "logits" in loss_dict:
        logits = loss_dict["logits"]
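    # `model_output_class` is set by the train-step handler to a transformers-style
    # ModelOutput (e.g. CausalLMOutputWithCrossAttentions), so downstream Hugging Face
    # code can consume `loss`/`logits` as usual.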
    if self.train_step_handler.model_output_class is not None:
        return self.train_step_handler.model_output_class(loss=loss, logits=logits)
    return loss