def _maybe_log_save_evaluate()

in optimum/neuron/trainers.py

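The excerpt below assumes the following module-level imports from trainers.py; the import paths for SaveStrategy and is_main_worker_for_metrics are assumptions based on typical transformers / optimum-neuron layouts, not verified against the file:

    from typing import Dict

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr

    # Assumed paths: SaveStrategy from transformers' trainer utilities,
    # is_main_worker_for_metrics from optimum-neuron's utilities.
    from transformers.trainer_utils import SaveStrategy

    from optimum.neuron.utils import is_main_worker_for_metrics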

    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
        # We always reduce the loss, even when we do not use it, to avoid creating a new graph.
        # This communication is not costly.

        if self.state.global_step > self._globalstep_last_logged:
            from neuronx_distributed.parallel_layers.parallel_state import (
                get_data_parallel_replica_groups,
                get_data_parallel_size,
                model_parallel_is_initialized,
            )

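            # With model parallelism initialized, average only across the data-parallel
            # replicas; otherwise every rank in the world is a data-parallel replica.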
            if model_parallel_is_initialized():
                dp_size = get_data_parallel_size()
            else:
                dp_size = xr.world_size()

            tr_loss_div = tr_loss / dp_size

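            # Dividing by dp_size before the all-reduce SUM makes the result the mean
            # loss across the data-parallel group.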
            reduced_tr_loss = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div, groups=get_data_parallel_replica_groups())

            reduced_tr_loss = reduced_tr_loss.detach()

            if self.control.should_log:
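                # Materialize pending ops before zeroing the running loss, so the reset
                # starts a fresh accumulation window for the next logging interval.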
                xm.mark_step()
                tr_loss.zero_()

                def log_closure(self, reduced_tr_loss, grad_norm):
                    # We need to check that self.state.global_step > self._globalstep_last_logged because if two
                    # closures are added in a row (which can happen at the end of training), the second one would
                    # otherwise divide by zero below, since at that point:
                    # self.state.global_step == self._globalstep_last_logged
                    if is_main_worker_for_metrics() and self.state.global_step > self._globalstep_last_logged:
                        logs: Dict[str, float] = {}
                        tr_loss_scalar = reduced_tr_loss.to("cpu").item()

                        logs["loss"] = round(
                            tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4
                        )
                        logs["learning_rate"] = self._get_learning_rate()

                        if grad_norm is not None:
                            logs["grad_norm"] = (
                                grad_norm.detach().to("cpu").item()
                                if isinstance(grad_norm, torch.Tensor)
                                else grad_norm
                            )

                        self._total_loss_scalar += tr_loss_scalar
                        self.store_flos()
                        self.log(logs, start_time)

                    self._globalstep_last_logged = self.state.global_step

                xm.add_step_closure(log_closure, (self, reduced_tr_loss, grad_norm))

        metrics = None
        if self.control.should_evaluate:
            metrics = self._evaluate(trial, ignore_keys_for_eval)
            is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)

            if self.args.save_strategy == SaveStrategy.BEST:
                self.control.should_save = is_new_best_metric

        if self.control.should_save:

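            # Saving is also deferred to a step closure, so checkpointing runs only
            # after the pending step has executed.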
            def save_closure(self, model, trial):
                self._save_checkpoint(model, trial)
                self.control = self.callback_handler.on_save(self.args, self.state, self.control)

            xm.add_step_closure(save_closure, (self, model, trial))
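
For context, xm.add_step_closure defers host-side work (reading tensor values, writing checkpoints) until the pending XLA step has executed, which is why both logging and saving are wrapped in closures above. A minimal standalone sketch of that pattern, independent of this trainer (tensor and function names are illustrative):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    x = torch.randn(4, 4, device=device)
    loss = (x * x).mean()  # lazy XLA tensor, not yet materialized

    def report(loss_tensor):
        # Runs after the step's graph has executed, so .item() reads an
        # already-materialized value instead of forcing an extra execution.
        print("loss:", loss_tensor.item())

    xm.add_step_closure(report, (loss,))
    xm.mark_step()  # cuts and executes the graph; queued closures run afterwards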