chatlearn/utils/logger.py (66 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""logging"""

import logging
import os

import torch
from tqdm import tqdm


def _escape_percent(value):
    """Return ``str(value)`` with every ``%`` doubled.

    The result can be embedded in a %-style ``logging.Formatter`` format
    string and will be rendered literally instead of being parsed as a
    conversion specifier.
    """
    return str(value).replace("%", "%%")


def setup_logger(log_level=None, model_name=None, ip_addr=None):
    """Create (or reset) a ChatLearn logger that writes to a stream handler.

    Any handlers already attached to the logger are removed, so calling this
    function repeatedly for the same name never duplicates output.

    Args:
        log_level: Level applied to both the logger and its handler.
            Defaults to ``logging.INFO`` when ``None``.
        model_name: Optional suffix; the logger is named
            ``"ChatLearn-<model_name>"`` when given, ``"ChatLearn"`` otherwise.
        ip_addr: Optional address shown in every record. When given, the
            value of the ``RANK`` environment variable (``0`` if unset) is
            shown as well.

    Returns:
        The configured ``logging.Logger``.
    """
    logger_name = "ChatLearn" if model_name is None else f"ChatLearn-{model_name}"
    _logger = logging.getLogger(logger_name)
    _logger.handlers.clear()
    # Do not forward records to the root logger; this logger owns its output.
    _logger.propagate = False
    if log_level is None:
        log_level = logging.INFO
    _logger.setLevel(log_level)

    if ip_addr is None:
        prefix = "[%(asctime)s %(name)s]"
    else:
        rank = os.environ.get("RANK", 0)
        # ip_addr and rank are pasted into a %-style format string, so a
        # literal '%' (e.g. an IPv6 zone id such as "fe80::1%eth0") has to be
        # escaped or every record would fail to format.
        prefix = (
            "[%(asctime)s %(name)s "
            + _escape_percent(ip_addr)
            + " RANK:"
            + _escape_percent(rank)
            + "]"
        )

    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(
        prefix + " (%(filename)s %(lineno)d): %(levelname)s %(message)s"))
    handler.setLevel(log_level)
    _logger.addHandler(handler)
    return _logger


logger = setup_logger()


def _should_log_on_this_process():
    """Return True unless torch.distributed is initialized with a non-zero rank."""
    dist = torch.distributed
    # is_available() must be checked first: PyTorch builds without distributed
    # support do not expose is_initialized() / get_rank().
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() == 0


def log_rank_0(msg, custom_logger=None):
    """Log ``msg`` at INFO level, only on rank 0 when torch.distributed is initialized.

    Args:
        msg: Message to log.
        custom_logger: Logger to use instead of the module-level ``logger``.
    """
    _logger = custom_logger if custom_logger is not None else logger
    if _should_log_on_this_process():
        _logger.info(msg)


def debug_rank_0(msg, custom_logger=None):
    """Log ``msg`` at DEBUG level, only on rank 0 when torch.distributed is initialized.

    Args:
        msg: Message to log.
        custom_logger: Logger to use instead of the module-level ``logger``.
    """
    _logger = custom_logger if custom_logger is not None else logger
    if _should_log_on_this_process():
        _logger.debug(msg)


class logging_tqdm(tqdm):
    """logging tqdm

    A ``tqdm`` subclass whose ``display`` sends the rendered progress text to
    a logger via ``Logger.info``.
    """

    def __init__(
            self,
            *args,
            tqdm_logger=None,
            mininterval: float = 1,
            bar_format: str = '{desc}{percentage:3.0f}%{r_bar}',
            desc: str = 'progress: ',
            **kwargs):
        """Initialize the progress bar.

        Args:
            *args: Positional arguments forwarded to ``tqdm``.
            tqdm_logger: Logger that receives progress lines; the module-level
                ``logger`` is used when ``None``.
            mininterval: Forwarded to ``tqdm``.
            bar_format: Forwarded to ``tqdm``.
            desc: Forwarded to ``tqdm``.
            **kwargs: Remaining keyword arguments forwarded to ``tqdm``.
        """
        # Assigned before super().__init__() so that the ``logger`` property
        # is already usable if the base initializer triggers ``display``.
        self._logger = tqdm_logger
        super().__init__(
            *args,
            mininterval=mininterval,
            bar_format=bar_format,
            desc=desc,
            **kwargs
        )

    @property
    def logger(self):
        """Logger used for progress output (module-level ``logger`` by default)."""
        if self._logger is not None:
            return self._logger
        return logger

    def display(self, msg=None, pos=None):  # pylint: disable=unused-argument
        """Log the current progress line at INFO level.

        Args:
            msg: Text to log; the bar's own string rendering is used when
                empty or ``None``.
            pos: Ignored; kept so the signature matches the call made by
                ``tqdm``.
        """
        if not self.n:
            # skip progress bar before having processed anything
            return
        if not msg:
            msg = str(self)
        self.logger.info('%s', msg)