chatlearn/schedule/metric_manager.py

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""metric manager"""

import traceback

import wandb
from torch.utils.tensorboard import SummaryWriter

from chatlearn.utils.constant import LOG_START
from chatlearn.utils.logger import logger


class MetricManager:
    """Metric manager that fans metrics out to the enabled logging backends
    (TensorBoard and/or Weights & Biases)."""

    def __init__(self, global_args):
        self.global_args = global_args
        self.runtime_args = global_args.runtime_args
        self._setup_tensorboard()
        self._setup_wandb()
        # Only backends that were set up successfully are registered.
        self.writer_dict = {}
        if self.tensorboard_writer:
            self.writer_dict['tensorboard'] = self.tensorboard_writer
        if self.wandb_writer:
            self.writer_dict['wandb'] = self.wandb_writer

    def _setup_tensorboard(self):
        if (
            self.runtime_args.log_args_dict is None
            or not self.runtime_args.log_args_dict.get('enable_tensorboard')
        ):
            self.tensorboard_writer = None
            logger.info("tensorboard is disabled in engine.")
            return
        try:
            self.tensorboard_writer = SummaryWriter(
                log_dir=self.runtime_args.log_args_dict['tensorboard_dir'],
                max_queue=99999
            )
        except Exception:
            self.tensorboard_writer = None
            logger.warning(f"{LOG_START} setup tensorboard failed, tensorboard_writer is set to empty.")
        else:
            logger.info(f"{LOG_START} setup tensorboard success.")

    def _setup_wandb(self):
        if (
            self.runtime_args.log_args_dict is None
            or not self.runtime_args.log_args_dict.get('enable_wandb')
        ):
            self.wandb_writer = None
            logger.info("wandb is disabled in engine.")
            return
        try:
            wandb_kwargs = {
                'dir': self.runtime_args.log_args_dict['wandb_dir'],
                'project': self.runtime_args.log_args_dict['wandb_project'],
                'id': self.runtime_args.log_args_dict['wandb_id'],
                'name': self.runtime_args.log_args_dict['wandb_name'],
                'resume': self.runtime_args.log_args_dict['wandb_resume'],
                'config': self.global_args,
            }
            logger.info(f"WANDB_ARGS: {wandb_kwargs}")
            wandb.init(**wandb_kwargs)
        except Exception:
            traceback.print_exc()
            self.wandb_writer = None
            logger.warning(f"{LOG_START} setup wandb failed, wandb_writer is set to empty.")
        else:
            # The wandb module itself serves as the writer handle once init succeeds.
            self.wandb_writer = wandb
            logger.info(f"{LOG_START} setup wandb success.")

    def log(self, prefix: str, global_step: int, scalar_dict):
        """Log `scalar_dict` (a mapping of names to values, or a bare scalar)
        under `prefix` at `global_step` to every enabled backend."""
        prefix = prefix.rstrip('/')
        logger.info(f"step {global_step} prefix {prefix}: logging metric {scalar_dict}")
        for writer_name in self.writer_dict:
            if writer_name == 'tensorboard':
                self._tensorboard_scalar_dict(prefix, global_step, scalar_dict)
            if writer_name == 'wandb':
                self._wandb_scalar_dict(prefix, global_step, scalar_dict)

    def _tensorboard_scalar_dict(self, prefix, global_step, scalar_dict):
        if isinstance(scalar_dict, (float, int)):
            # A bare scalar is logged directly under the prefix.
            self.tensorboard_writer.add_scalar(prefix, scalar_dict, global_step)
        else:
            for key, value in scalar_dict.items():
                name = f"{prefix}/{key}".lstrip('/')
                self.tensorboard_writer.add_scalar(name, value, global_step)

    def _wandb_scalar_dict(self, prefix, global_step, scalar_dict):
        if isinstance(scalar_dict, (float, int)):
            self.wandb_writer.log({prefix: scalar_dict}, step=global_step)
        else:
            # Prefix every key so related metrics group together in the wandb UI.
            scalar_dict_with_prefix = {}
            for key, value in scalar_dict.items():
                name = f"{prefix}/{key}".lstrip('/')
                scalar_dict_with_prefix[name] = value
            self.wandb_writer.log(scalar_dict_with_prefix, step=global_step)

    def stop(self):
        # Close the tensorboard writer so queued events are flushed to disk
        # (max_queue is large, so events may otherwise sit unflushed).
        if self.tensorboard_writer:
            self.tensorboard_writer.close()
        if self.wandb_writer:
            self.wandb_writer.finish()
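

# ---------------------------------------------------------------------------
# Usage sketch: a minimal illustration of how MetricManager might be driven.
# The SimpleNamespace objects below are hypothetical stand-ins for chatlearn's
# parsed argument objects; only the fields MetricManager actually reads are
# provided. With wandb disabled, only the tensorboard path is exercised.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical argument objects (real callers pass chatlearn's args).
    runtime_args = SimpleNamespace(log_args_dict={
        'enable_tensorboard': True,
        'tensorboard_dir': '/tmp/tb_logs',
        'enable_wandb': False,
    })
    global_args = SimpleNamespace(runtime_args=runtime_args)

    manager = MetricManager(global_args)
    # Dict form: each key is logged as "<prefix>/<key>".
    manager.log("train", global_step=0, scalar_dict={"loss": 1.23, "lr": 1e-4})
    # Scalar form: the value is logged directly under the prefix.
    manager.log("eval", global_step=0, scalar_dict=0.98)
    manager.stop()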