chatlearn/runtime/decorator.py (187 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """module runtime decorator""" import inspect import traceback import torch from torch.cuda import nvtx import ray from chatlearn.models.vllm_module_v2 import VLLMModuleV2 from chatlearn.utils import future from chatlearn.utils import utils from chatlearn.utils.constant import CHATLEARN_REGROUP_TAG, INDEX_TAG from chatlearn.utils.constant import LOG_START from chatlearn.utils.global_vars import _EXIT_ACTOR_NAME, set_wrap_func from chatlearn.utils.utils import execute from chatlearn.utils.utils import regroup_by_concat_along_batch, slice_by_index_along_batch def monitor_error(func, func_name): def inner(self, *args, **kwargs): try: return func(self, *args, **kwargs) except Exception as e: self._logger.exception(f"Catch exception ========= in {self.name} {func_name} {e}") exit_actor = ray.get_actor(_EXIT_ACTOR_NAME) traceback_msg = f"{traceback.format_exc()}" address = self.get_address() ray.get(exit_actor.add_error_node_and_msg.remote(address, traceback_msg)) future.wait(self.error_signal.set_address.remote(address)) # for other error, we raise in the corresponding workers if self.is_master_node(): for line in traceback_msg.split("\n"): self._logger.exception(line) execute("ray stop") raise return inner def timeit(func, func_name): def inner(self, *args, **kwargs): if self.runtime_args.nsys: nvtx.range_push(func_name) if self.is_last_rank(): # for the class inherited from base, it may call multiple times, so use the first start time if not self.timers(func_name).started_: self.timers(func_name).start() ret = func(self, *args, **kwargs) self.timers(func_name).stop() else: ret = func(self, *args, **kwargs) if self.profiler is not None and self._iteration > 0 and self._iteration <=2 and self.replica_id == 0 \ and func_name in ["forward_step", "train_step"]: self.profiler.step() if self.profiler is not None and self._iteration ==3 and self.replica_id == 0 and func_name in ["forward_step", "train_step"]: self.profiler.stop() if self.runtime_args.nsys: nvtx.range_pop() return ret return inner def split_along_batch(batch, new_batch_size): assert isinstance(batch, (list, tuple, dict)), \ "batch type {} is not supported".format(type(batch)) if isinstance(batch, (list, tuple)): bs = len(batch[0]) keys = range(len(batch)) else: bs = len(next(iter(batch.values()))) keys = batch.keys() accum_bs = 0 new_batches = [] while accum_bs < bs: if isinstance(batch, (list, tuple)): new_batch = [batch[key][accum_bs:min(accum_bs + new_batch_size, bs)] for key in keys] else: new_batch = {key: batch[key][accum_bs:min(accum_bs + new_batch_size, bs)] for key in keys} accum_bs += new_batch_size new_batches.append(new_batch) return new_batches def concat_along_batch(tensors): batched = {} if tensors[0] is None: return batched for key in tensors[0].keys(): to_batch = [results[key] for results in tensors] if isinstance(to_batch[0], torch.Tensor): batched[key] = torch.concat(to_batch) elif isinstance(to_batch[0], list): batched[key] = [] for seq in to_batch: batched[key].extend(seq) else: raise Exception(f"unknown types key: {key} and {type(to_batch[0])} to concat") return batched def preprocess_compute(func, trainable): """ 1. if not trainable, merge a list of dict into one dict, i.e., merge inputs of forward_step. 2. split a list of data for data_parallel, this is used for train_step 3. convert output to cpu """ def inner(self, *args, **kwargs): args = future.get(args) assert isinstance(args, (list, tuple)), f"expect args is a list, while {type(args)}, args: {args}." batched_data_list = [None] * len(args) if not trainable: self._logger.info(f"{LOG_START} start to merge data for {self.name} replica {self.replica_id}.") self._logger.info(f"{LOG_START} preprocess_compute model {self.name} replica {self.replica_id} \ has inputs from {len(args)} input node.") for idx, arg_obj in enumerate(args): batched_data_list[idx] = arg_obj if CHATLEARN_REGROUP_TAG in arg_obj: batched_data_list[idx] = regroup_by_concat_along_batch(arg_obj[CHATLEARN_REGROUP_TAG]) if INDEX_TAG in arg_obj: batched_data_list[idx] = slice_by_index_along_batch(batched_data_list[idx], arg_obj[INDEX_TAG]) assert isinstance(batched_data_list[idx], dict), \ f"expect output arg for {self.name} to be a dict, while {type(batched_data_list[idx])}, arg: {batched_data_list[idx]}" if all(isinstance(batched_data, dict) for batched_data in batched_data_list): merged = {} for batched_data in batched_data_list: merged.update(batched_data) args = [merged] self._logger.info(f"{LOG_START} complete to merge data for {self.name}.") def get_kwarg(key): return kwargs.pop(key) if key in kwargs else False to_empty_cache = get_kwarg('to_empty_cache') to_onload = get_kwarg('to_onload') to_offload = get_kwarg('to_offload') is_last_batch = get_kwarg('is_last_batch') is_eval = get_kwarg('is_eval') if to_onload: if isinstance(self, VLLMModuleV2): self.onload_for_workers() else: self.onload() generation_batch_size = self.module_args.generation_batch_size final_results = None if not trainable and generation_batch_size: # split into micro-batches if generation_batch_size < input_batch, then concat the results # this happens when different models have difference batch sizes input_batch = 0 if len(args) > 0: for value in args[0].values(): input_batch = len(value) break input_data = args[0] else: input_data = None if generation_batch_size != -1 and input_data is not None and input_batch > generation_batch_size and not hasattr(self, 'generate_vllm'): args = list(args) batches = split_along_batch(input_data, generation_batch_size) results = [] for batch in batches: args[0] = batch if 'iteration' in inspect.signature(func).parameters: kwargs["iteration"] = self._iteration ret = func(self, *args, **kwargs) self._iteration += 1 ret = utils.to_device('cpu', ret) results.append(ret) # for model with DP/EP, we need to return results from all ranks # for model with TP/PP, only return the results from last rank if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \ or isinstance(self, VLLMModuleV2): final_results = concat_along_batch(results) else: if 'iteration' in inspect.signature(func).parameters: kwargs["iteration"] = self._iteration ret = func(self, *args, **kwargs) ret = utils.to_device('cpu', ret) self._iteration += 1 final_results = None # for model with DP/EP, we need to return results from all ranks # for model with TP/PP, only return the results from last rank if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \ or isinstance(self, VLLMModuleV2): final_results = ret else: if 'iteration' in inspect.signature(func).parameters: kwargs["iteration"] = self._train_iteration self._train_iteration += 1 ret = func(self, *args, **kwargs) ret = utils.to_device('cpu', ret) if self.is_last_rank(): final_results = ret if to_empty_cache: if isinstance(self, VLLMModuleV2): self.empty_cuda_graph_for_workers() self.empty_cache_for_workers() else: self.empty_cache() if to_offload: if isinstance(self, VLLMModuleV2): self.offload_for_workers() else: self.offload() if is_last_batch and not is_eval: self.runtime_args.consumed_samples += self.runtime_args.sample_per_episode return final_results return inner def decorate_class_func(cls, func_name, decorator, *args, **kwargs): if not hasattr(cls, func_name): return func = getattr(cls, func_name) if func.__qualname__.startswith(decorator.__name__): # already decorated # This usually occurs when one class inherits from another class, # for example, if 'reference' inherits from 'policy', then methods like 'offload_optimizer_states' # would be decorated in the base class, eliminating the need for repeated decoration. return new_func = decorator(func, *args, **kwargs) set_wrap_func(func, new_func) setattr(cls, func_name, new_func)