utils/benchmark_logger.py (64 lines of code) (raw):

""" Copyright (c) 2021, Alibaba Group; Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """
import functools
import logging
import sys

import torch

from utils.timer import Timer
from log_analyzer.log import Log, LogItem


class LoggerFactory:
    """Factory for stdout stream loggers with a uniform timestamped format."""

    @staticmethod
    def create_logger(name=None, level=logging.INFO):
        """create a logger

        Args:
            name (str): name of the logger
            level: level of logger

        Raises:
            ValueError if name is None
        """
        if name is None:
            raise ValueError("name for logger cannot be None")

        formatter = logging.Formatter("[%(asctime)s] [%(levelname)s] %(message)s")

        logger_ = logging.getLogger(name)
        logger_.setLevel(level)
        logger_.propagate = False
        # Only attach a handler the first time this name is configured;
        # otherwise repeated calls would duplicate every emitted record.
        if not logger_.handlers:
            ch = logging.StreamHandler(stream=sys.stdout)
            ch.setLevel(level)
            ch.setFormatter(formatter)
            logger_.addHandler(ch)
        return logger_


logger = LoggerFactory.create_logger(name="LLM_Comm_Benchmark", level=logging.INFO)


class BenchLogger:
    """Accumulates per-op and per-epoch communication timings into a Log.

    Per-op timings are captured via the `log_timing` decorator; epoch wall
    time is captured via `end_epoch`. Only rank 0 prints to the console.
    """

    def __init__(self):
        self.comm_log = Log()          # accumulated LogItem records
        self.enable = True
        self.timer = Timer()           # per-op timer
        self.epoch_timer = Timer(use_host_timer=True)  # wall-clock epoch timer
        self.epoch = 0
        self.epoch_timer.start()

    def log_timing(self, name):
        """Return a decorator that times *func* and records its LogItem.

        The wrapped function must receive a LogItem among its positional
        arguments; its elapsed_time is filled in (0 for 'overlap' items,
        which are not charged any wall time) and appended to the comm log.
        """
        def decorator(func):
            @functools.wraps(func)  # preserve the wrapped function's metadata
            def wrapper(*args, **kwargs):
                self.timer.start()
                result = func(*args, **kwargs)
                elapsed_time_ms = self.timer.stop()
                # Locate the LogItem positional argument; fail with a clear
                # error instead of leaking a bare StopIteration from next().
                log_item = next(
                    (item for item in args if isinstance(item, LogItem)), None
                )
                if log_item is None:
                    raise ValueError(
                        "log_timing requires a LogItem positional argument"
                    )
                if log_item.additional == 'overlap':
                    log_item.elapsed_time = 0
                else:
                    log_item.elapsed_time = elapsed_time_ms
                self.comm_log.add_comm_log(log_item)
                if torch.distributed.get_rank() == 0:
                    logger.info(log_item.view_as_ds_log())
                return result
            return wrapper
        return decorator

    def end_epoch(self, log_item):
        """Close the current epoch: record its wall time, then restart the timer."""
        # Ensure all queued GPU work is included in the measured epoch time.
        torch.cuda.synchronize()
        elapsed_time_ms = self.epoch_timer.stop()
        if torch.distributed.get_rank() == 0:
            logger.info(
                f"[RANK 0] --------epoch {self.epoch} | micro_step time {elapsed_time_ms:.2f} ---------\n"
            )
        log_item.elapsed_time = elapsed_time_ms
        self.comm_log.add_comm_log(log_item)
        self.epoch += 1
        self.epoch_timer.start()

    def dump_log(self, filename):
        """Write the accumulated log to *filename*; return the csv filename."""
        csv_filename = self.comm_log.dump(filename)
        return csv_filename

    def analyze_comm_log(self, print_fn=logger.info):
        """Run the comm-log analysis, printing via *print_fn*."""
        return self.comm_log.analyze(print_fn)

    def analyze_comm_time(self, print_fn=logger.info):
        """Run the comm-time analysis, printing via *print_fn*."""
        return self.comm_log.analyze_time(print_fn)


# Module-level singleton used by the benchmark harness.
bench_logger = BenchLogger()