in optimum/neuron/trainers.py
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
    # We always reduce the loss, even when we do not use it, to avoid compiling a new graph.
    # This communication is not costly.
    if self.state.global_step > self._globalstep_last_logged:
        from neuronx_distributed.parallel_layers.parallel_state import (
            get_data_parallel_replica_groups,
            get_data_parallel_size,
            model_parallel_is_initialized,
        )

        if model_parallel_is_initialized():
            dp_size = get_data_parallel_size()
        else:
            dp_size = xr.world_size()
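
        # Dividing by the data-parallel size before the sum all-reduce makes the result the
        # mean loss across data-parallel ranks.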
        tr_loss_div = tr_loss / dp_size
        reduced_tr_loss = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div, groups=get_data_parallel_replica_groups())
        reduced_tr_loss = reduced_tr_loss.detach()

        if self.control.should_log:
            xm.mark_step()
            tr_loss.zero_()
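
            # Logging runs in a step closure so that the host-side `.item()` calls below execute
            # only after the step's graph has run, instead of forcing an early synchronization here.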
            def log_closure(self, reduced_tr_loss, grad_norm):
                # We need to check that self.state.global_step > self._globalstep_last_logged because if two
                # closures are added in a row (which can happen at the end of training), the second one would
                # otherwise divide by zero when averaging the loss, since at that point:
                # self.state.global_step == self._globalstep_last_logged
                if is_main_worker_for_metrics() and self.state.global_step > self._globalstep_last_logged:
                    logs: Dict[str, float] = {}
                    tr_loss_scalar = reduced_tr_loss.to("cpu").item()
                    logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
                    logs["learning_rate"] = self._get_learning_rate()
                    if grad_norm is not None:
                        logs["grad_norm"] = (
                            grad_norm.detach().to("cpu").item() if isinstance(grad_norm, torch.Tensor) else grad_norm
                        )

                    self._total_loss_scalar += tr_loss_scalar
                    self.store_flos()
                    self.log(logs, start_time)

                self._globalstep_last_logged = self.state.global_step

            xm.add_step_closure(log_closure, (self, reduced_tr_loss, grad_norm))

    metrics = None
    if self.control.should_evaluate:
        metrics = self._evaluate(trial, ignore_keys_for_eval)
        is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)

        if self.args.save_strategy == SaveStrategy.BEST:
            self.control.should_save = is_new_best_metric

    if self.control.should_save:
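
        # Saving also goes through a step closure so the checkpoint is written only after the
        # current step's graph has executed.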
        def save_closure(self, model, trial):
            self._save_checkpoint(model, trial)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

        xm.add_step_closure(save_closure, (self, model, trial))
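
Both the logging and the saving paths above rely on torch_xla step closures. The following is a minimal, standalone sketch of that pattern, not part of the trainer (it assumes a working torch_xla install and an available XLA device; `log_loss` is an illustrative name): work registered with xm.add_step_closure runs only after xm.mark_step() has executed the accumulated graph, so host-side calls like .item() do not force an extra device synchronization mid-step.

import torch
import torch_xla.core.xla_model as xm

# Minimal sketch of the step-closure pattern (assumes torch_xla and an XLA
# device; not part of optimum-neuron).
device = xm.xla_device()
loss = torch.zeros((), device=device)

def log_loss(loss_tensor):
    # Invoked after the step's graph has run, so .item() reads a materialized
    # value instead of triggering an extra partial-graph execution.
    print(f"loss = {loss_tensor.item():.4f}")

for step in range(3):
    loss = loss + 1.0  # stands in for a real forward/backward step
    xm.add_step_closure(log_loss, (loss,))
    xm.mark_step()  # executes the pending graph, then runs the registered closures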