# chatlearn/utils/timer.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Timer"""
import time
import torch
from .logger import logger
class _Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
self._num = 0
self._cuda_available = torch.cuda.is_available()

    def cuda_sync(self):
        """Synchronize the CUDA device (if available) so GPU work is included in timings."""
        if self._cuda_available:
            torch.cuda.synchronize()

    def start(self):
        """Start the timer."""
        assert not self.started_, f'timer {self.name_} has already been started'
        self.cuda_sync()
        self.start_time = time.time()
        self.started_ = True
        self._num += 1

    def stop(self):
        """Stop the timer."""
        assert self.started_, f'timer {self.name_} has not been started'
        self.cuda_sync()
        self.elapsed_ += (time.time() - self.start_time)
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False
        self._num = 0

    def elapsed(self, reset=True, return_num=False):
        """Return the accumulated elapsed time, optionally resetting the timer."""
        started_ = self.started_
        # If timing is in progress, stop first so the total is up to date.
        if self.started_:
            self.stop()
        # Get the elapsed time and the number of start/stop cycles.
        elapsed_ = self.elapsed_
        num_ = self._num
        # Reset the elapsed time.
        if reset:
            self.reset()
        # If timing was in progress, restart the timer.
        if started_:
            self.start()
        if return_num:
            return elapsed_, num_
        return elapsed_
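
# A minimal usage sketch for `_Timer` (illustrative only; in practice timers
# are created and accessed through the `Timers` group below):
#
#     timer = _Timer("forward")
#     timer.start()
#     ...                              # timed region
#     timer.stop()
#     seconds = timer.elapsed(reset=False)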


class Timers:
    """Group of named timers, created lazily on first access."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        # Create the timer on first use, so callers can simply write
        # `timers("step").start()`.
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timer values to a TensorBoard writer."""
        # Currently, torch.utils.tensorboard's add_scalars makes each timer
        # its own run, which pollutes the runs list, so we add each timer
        # as an individual scalar instead.
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + '-time', value, iteration)
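
    # Example (a sketch, assuming a `torch.utils.tensorboard.SummaryWriter`):
    #
    #     writer = SummaryWriter(log_dir="runs/chatlearn")
    #     timers.write(["forward", "backward"], writer, iteration=step)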

    def log(self, names=None, normalizer=1.0, reset=True, return_dict=False, e2e_cost=None, skip_zero=True):
        """Log a group of timers, reported in minutes."""
        assert normalizer > 0.0
        name2log = {}
        if names is not None:
            for name in names:
                if name not in self.timers:
                    logger.warning(f"{name} not in timers, ignore it.")
        string = 'time (min)'
        if e2e_cost is not None:
            string += ' | e2e_cost: {:.2f}'.format(e2e_cost)
        for name in self.timers:
            # Timers that were not requested are reset but not logged.
            if names is not None and name not in names:
                self.timers[name].reset()
                continue
            elapsed_time, num = self.timers[name].elapsed(reset=reset, return_num=True)
            # Less than 1 us: treat the timer as not having executed.
            if skip_zero and elapsed_time < 1e-6:
                continue
            # Convert seconds to minutes and apply the normalizer.
            elapsed_time = elapsed_time / 60.0 / normalizer
            if num >= 1:
                avg_elapsed_time = elapsed_time / num
                string += ' | {}: {:.2f}(avg: {:.2f})'.format(name, elapsed_time, avg_elapsed_time)
            else:
                string += ' | {}: {:.2f}'.format(name, elapsed_time)
            if return_dict:
                name2log[name] = elapsed_time
        if return_dict:
            return string, name2log
        return string
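

if __name__ == "__main__":
    # Minimal self-check sketch (not part of the original module): time a
    # short sleep under a named timer and print the aggregated log line.
    # Note that log() reports values in minutes, so short regions show 0.00.
    timers = Timers()
    timers("sleep").start()
    time.sleep(0.01)
    timers("sleep").stop()
    print(timers.log(names=["sleep"]))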