in chatlearn/runtime/decorator.py [0:0]
def preprocess_compute(func, trainable):
"""
1. if not trainable, merge a list of dict into one dict, i.e., merge inputs of forward_step.
2. split a list of data for data_parallel, this is used for train_step
3. convert output to cpu
"""
def inner(self, *args, **kwargs):
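# Resolve any remote futures (e.g. Ray object refs) in args into local values before processing.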
args = future.get(args)
assert isinstance(args, (list, tuple)), f"expect args to be a list or tuple, got {type(args)}, args: {args}."
batched_data_list = [None] * len(args)
if not trainable:
self._logger.info(f"{LOG_START} start to merge data for {self.name} replica {self.replica_id}.")
self._logger.info(f"{LOG_START} preprocess_compute model {self.name} replica {self.replica_id} \
has inputs from {len(args)} input node.")
for idx, arg_obj in enumerate(args):
batched_data_list[idx] = arg_obj
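# CHATLEARN_REGROUP_TAG carries per-producer chunks to be concatenated along the batch dim;
# INDEX_TAG then selects the rows intended for this replica.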
if CHATLEARN_REGROUP_TAG in arg_obj:
batched_data_list[idx] = regroup_by_concat_along_batch(arg_obj[CHATLEARN_REGROUP_TAG])
if INDEX_TAG in arg_obj:
batched_data_list[idx] = slice_by_index_along_batch(batched_data_list[idx], arg_obj[INDEX_TAG])
assert isinstance(batched_data_list[idx], dict), \
    f"expect the output of each input node for {self.name} to be a dict, got {type(batched_data_list[idx])}, arg: {batched_data_list[idx]}"
if all(isinstance(batched_data, dict) for batched_data in batched_data_list):
merged = {}
for batched_data in batched_data_list:
merged.update(batched_data)
args = [merged]
self._logger.info(f"{LOG_START} complete to merge data for {self.name}.")
def get_kwarg(key):
    return kwargs.pop(key, False)
to_empty_cache = get_kwarg('to_empty_cache')
to_onload = get_kwarg('to_onload')
to_offload = get_kwarg('to_offload')
is_last_batch = get_kwarg('is_last_batch')
is_eval = get_kwarg('is_eval')
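# Onload model states back to GPU when requested; vLLM replicas fan the call out to their worker actors.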
if to_onload:
if isinstance(self, VLLMModuleV2):
self.onload_for_workers()
else:
self.onload()
generation_batch_size = self.module_args.generation_batch_size
final_results = None
if not trainable and generation_batch_size:
# split into micro-batches if generation_batch_size < input_batch, then concat the results;
# this happens when different models have different batch sizes
input_batch = 0
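# Infer the incoming batch size from the length of any value in the merged input dict.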
if len(args) > 0:
for value in args[0].values():
input_batch = len(value)
break
input_data = args[0]
else:
input_data = None
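# Split only when a positive generation_batch_size is set and the input batch is larger;
# models exposing generate_vllm handle batching internally and are not split here.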
if generation_batch_size != -1 and input_data is not None and input_batch > generation_batch_size and not hasattr(self, 'generate_vllm'):
args = list(args)
batches = split_along_batch(input_data, generation_batch_size)
results = []
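# Run the wrapped func once per micro-batch, passing the replica's iteration counter
# when the func accepts it, and collect the CPU-side results.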
for batch in batches:
args[0] = batch
if 'iteration' in inspect.signature(func).parameters:
kwargs["iteration"] = self._iteration
ret = func(self, *args, **kwargs)
self._iteration += 1
ret = utils.to_device('cpu', ret)
results.append(ret)
# for models with DP/EP, results must be returned from all ranks;
# for models with TP/PP, only the last rank returns results
if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \
or isinstance(self, VLLMModuleV2):
final_results = concat_along_batch(results)
else:
if 'iteration' in inspect.signature(func).parameters:
kwargs["iteration"] = self._iteration
ret = func(self, *args, **kwargs)
ret = utils.to_device('cpu', ret)
self._iteration += 1
final_results = None
# for models with DP/EP, results must be returned from all ranks;
# for models with TP/PP, only the last rank returns results
if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \
or isinstance(self, VLLMModuleV2):
final_results = ret
else:
if 'iteration' in inspect.signature(func).parameters:
kwargs["iteration"] = self._train_iteration
self._train_iteration += 1
ret = func(self, *args, **kwargs)
ret = utils.to_device('cpu', ret)
if self.is_last_rank():
final_results = ret
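# Post-step resource management requested by the engine: release caches/CUDA graphs and offload states.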
if to_empty_cache:
if isinstance(self, VLLMModuleV2):
self.empty_cuda_graph_for_workers()
self.empty_cache_for_workers()
else:
self.empty_cache()
if to_offload:
if isinstance(self, VLLMModuleV2):
self.offload_for_workers()
else:
self.offload()
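# After the last (non-eval) batch of an episode, advance the consumed-sample counter
# by one episode's worth of samples.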
if is_last_batch and not is_eval:
self.runtime_args.consumed_samples += self.runtime_args.sample_per_episode
return final_results
return inner
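# --- Illustrative usage sketch (not part of the original file) ---
# A minimal, hypothetical example of how this decorator could wrap a module's
# forward_step. `MyPolicy` is a placeholder (a real module would derive from a
# ChatLearn module base class providing self._logger, module_args, etc.):
#
#   class MyPolicy:
#       def forward_step(self, data, iteration=0):
#           return {"logits": self.model(data["tokens"])}
#
#   # wrap the unbound method: the returned `inner` receives `self` plus the batched inputs
#   MyPolicy.forward_step = preprocess_compute(MyPolicy.forward_step, trainable=False)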