# _log_training — excerpt from vision/m4/training/trainer.py

    def _log_training(self, curr_opt_step, train_task, train_logs):
        """Finalize accumulated training metrics, emit them to every enabled sink, and reset.

        For each metric key in ``train_logs`` (per-dataset keys plus, presumably, an
        "all" aggregate — the wandb branch below indexes ``["all"]``; confirm against
        the accumulator setup), the accumulated per-token loss and z-loss are averaged
        over the number of per-device batches seen since the last log, and throughput
        (tflops) / power (watt/s) are derived from the accumulated forward+backward
        wall time. The main process then writes the formatted logs to stdout, appends
        them as a JSON line under ``save_dir/logs``, and — when enabled — sends them
        to wandb via ``self.accelerator.log``.

        Args:
            curr_opt_step: Current optimizer step, used as the wandb step index.
            train_task: rich progress task; rendered for iteration/elapsed columns.
            train_logs: Dict of accumulators and derived metrics; mutated in place,
                then reset via ``self._reset_train_logs``.

        Returns:
            The reset ``train_logs`` dict.
        """
        for key in train_logs["per_token_loss_acc"]:
            # Hoist the shared denominator once per key instead of re-indexing it.
            # NOTE(review): only None is guarded against — a 0 count/time would raise
            # ZeroDivisionError; presumably upstream guarantees non-zero when set.
            num_batches = train_logs["num_per_device_batches_since_training_logged"][key]
            if num_batches is not None:
                train_logs["per_token_loss"][key] = train_logs["per_token_loss_acc"][key] / num_batches
                train_logs["z_loss"][key] = train_logs["z_loss_acc"][key] / num_batches
            else:
                train_logs["per_token_loss"][key] = None
                train_logs["z_loss"][key] = None

            fwd_bwd_time = train_logs["fwd_bwd_time_since_training_logged"][key]
            if fwd_bwd_time is not None:
                train_logs["tflops"][key] = train_logs["tflop_counter_since_training_logged"][key] / fwd_bwd_time
                train_logs["watt/s"][key] = (
                    train_logs["total_energy_delta_since_training_logged"][key] / fwd_bwd_time
                )
            else:
                train_logs["tflops"][key] = None
                train_logs["watt/s"][key] = None

        if self.accelerator.is_main_process:
            print_log = ""
            progress = f"{str(MofNCompleteColumn().render(train_task)):>12} {TaskProgressColumn().render(train_task)}"
            print_log += f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] iteration: {progress} | "
            elapsed_time = TimeElapsedColumn().render(train_task)
            print_log += f"elapsed time: {elapsed_time} | "

            print_log += self.format_train_logs(train_logs, logger_type=LoggingTypes.PRINT)

            # TODO: Allow mem usage to be logged according to LoggingTypes passed in hparams
            if self.hparams.train_log_mem_usage:
                print_log += mem_usage_formatted(LoggingTypes.PRINT)

            print(print_log)

            jsonl_logs = {
                "iteration": progress.strip(),
                "elapsed_time": str(elapsed_time),
                "set": "train",
            }
            jsonl_logs.update(self.format_train_logs(train_logs, logger_type=LoggingTypes.JSONL))
            if self.hparams.train_log_mem_usage:
                jsonl_logs.update(mem_usage_formatted(LoggingTypes.JSONL))

            # One JSONL file per job when a job id is available, else a shared default.
            if self.hparams.job_id is not None:
                log_jsonl_file = self.hparams.save_dir / "logs" / f"{self.hparams.job_id}_logs.jsonl"
            else:
                log_jsonl_file = self.hparams.save_dir / "logs" / "logs.jsonl"

            log_jsonl_file.parent.mkdir(parents=True, exist_ok=True)

            with open(log_jsonl_file, "a") as f:
                f.write(json.dumps(jsonl_logs) + "\n")

            if self.hparams.wandb_enable:
                if LoggingTypes.WANDB not in self.hparams.train_logging_per_dataset_info:
                    # Per-dataset breakdown disabled: keep only the "all" aggregate
                    # for nested dicts, scalars as-is.
                    source_logs = {
                        key: value["all"] if isinstance(value, dict) else value
                        for key, value in train_logs.items()
                    }
                else:
                    source_logs = train_logs
                # Drop None values (top-level and nested) as wandb doesn't support them.
                # Built as a fresh dict in one pass so no dict is mutated while iterated.
                filtered_train_logs = {
                    k: ({k2: v2 for k2, v2 in v.items() if v2 is not None} if isinstance(v, dict) else v)
                    for k, v in source_logs.items()
                    if v is not None
                }
                self.accelerator.log({**filtered_train_logs, **self._get_additional_step_logs()}, step=curr_opt_step)

        train_logs = self._reset_train_logs(train_logs)
        return train_logs