def preprocess_compute()

in chatlearn/runtime/decorator.py


def preprocess_compute(func, trainable):
    """
    1. If not trainable, merge a list of dicts into one dict, i.e., merge the inputs of forward_step.
    2. Split a list of data for data parallel; this is used for train_step.
    3. Convert the output to CPU.
    """

    def inner(self, *args, **kwargs):
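        # Resolve any remote futures (e.g. Ray object refs) in the inputs into concrete values.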
        args = future.get(args)
        assert isinstance(args, (list, tuple)), f"expect args to be a list or tuple, got {type(args)}, args: {args}."
        batched_data_list = [None] * len(args)
        if not trainable:
            self._logger.info(f"{LOG_START} start to merge data for {self.name} replica {self.replica_id}.")
            self._logger.info(f"{LOG_START} preprocess_compute model {self.name} replica {self.replica_id} \
                has inputs from {len(args)} input node.")

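            # Each element of args comes from one upstream node: rebuild its batched dict by
            # concatenating regrouped pieces and slicing by index tags where present.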
            for idx, arg_obj in enumerate(args):
                batched_data_list[idx] = arg_obj
                if CHATLEARN_REGROUP_TAG in arg_obj:
                    batched_data_list[idx] = regroup_by_concat_along_batch(arg_obj[CHATLEARN_REGROUP_TAG])
                if INDEX_TAG in arg_obj:
                    batched_data_list[idx] = slice_by_index_along_batch(batched_data_list[idx], arg_obj[INDEX_TAG])
                assert isinstance(batched_data_list[idx], dict), \
                    f"expect arg for {self.name} to be a dict, got {type(batched_data_list[idx])}, arg: {batched_data_list[idx]}"
            if all(isinstance(batched_data, dict) for batched_data in batched_data_list):
                merged = {}
                for batched_data in batched_data_list:
                    merged.update(batched_data)
                args = [merged]

            self._logger.info(f"{LOG_START} complete to merge data for {self.name}.")

        def get_kwarg(key):
            return kwargs.pop(key, False)
        to_empty_cache = get_kwarg('to_empty_cache')
        to_onload = get_kwarg('to_onload')
        to_offload = get_kwarg('to_offload')
        is_last_batch = get_kwarg('is_last_batch')
        is_eval = get_kwarg('is_eval')

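        # Engine control flags: optionally onload the model before compute; VLLMModuleV2
        # dispatches these operations to its workers.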
        if to_onload:
            if isinstance(self, VLLMModuleV2):
                self.onload_for_workers()
            else:
                self.onload()
        generation_batch_size = self.module_args.generation_batch_size
        final_results = None
        if not trainable and generation_batch_size:
            # split into micro-batches if generation_batch_size < input_batch, then concat the results
            # this happens when different models have different batch sizes
            input_batch = 0
            if len(args) > 0:
                for value in args[0].values():
                    input_batch = len(value)
                    break
                input_data = args[0]
            else:
                input_data = None
            if generation_batch_size != -1 and input_data is not None and input_batch > generation_batch_size and not hasattr(self, 'generate_vllm'):
                args = list(args)
                batches = split_along_batch(input_data, generation_batch_size)
                results = []
                for batch in batches:
                    args[0] = batch
                    if 'iteration' in inspect.signature(func).parameters:
                        kwargs["iteration"] = self._iteration
                    ret = func(self, *args, **kwargs)
                    self._iteration += 1
                    ret = utils.to_device('cpu', ret)
                    results.append(ret)
                # for models with DP/EP, we need to return results from all ranks
                # for models with TP/PP, only return the results from the last rank
                if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \
                        or isinstance(self, VLLMModuleV2):
                    final_results = concat_along_batch(results)
            else:
                if 'iteration' in inspect.signature(func).parameters:
                    kwargs["iteration"] = self._iteration
                ret = func(self, *args, **kwargs)
                ret = utils.to_device('cpu', ret)
                self._iteration += 1
                final_results = None
                # for models with DP/EP, we need to return results from all ranks
                # for models with TP/PP, only return the results from the last rank
                if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \
                        or isinstance(self, VLLMModuleV2):
                    final_results = ret
        else:
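            # Trainable path: run the wrapped train step once on the full batch, move the
            # result to CPU, and only the last rank returns it.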
            if 'iteration' in inspect.signature(func).parameters:
                kwargs["iteration"] = self._train_iteration
            self._train_iteration += 1
            ret = func(self, *args, **kwargs)
            ret = utils.to_device('cpu', ret)
            if self.is_last_rank():
                final_results = ret
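        # Post-compute housekeeping: empty caches / offload if requested, and advance
        # consumed_samples at the end of a non-eval episode.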
        if to_empty_cache:
            if isinstance(self, VLLMModuleV2):
                self.empty_cuda_graph_for_workers()
                self.empty_cache_for_workers()
            else:
                self.empty_cache()
        if to_offload:
            if isinstance(self, VLLMModuleV2):
                self.offload_for_workers()
            else:
                self.offload()
        if is_last_batch and not is_eval:
            self.runtime_args.consumed_samples += self.runtime_args.sample_per_episode
        return final_results

    return inner
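
For reference, a minimal, self-contained sketch of the micro-batching pattern used in the non-trainable branch above: split one input dict into chunks of generation_batch_size along the batch dimension, run the wrapped function on each chunk, then concatenate the per-chunk results. The helpers below are simplified stand-ins for chatlearn's split_along_batch / concat_along_batch utilities, and forward_step is a hypothetical wrapped function.

def split_along_batch(data, batch_size):
    # Split every value in the dict into chunks of batch_size along dim 0.
    n = len(next(iter(data.values())))
    return [{k: v[i:i + batch_size] for k, v in data.items()}
            for i in range(0, n, batch_size)]

def concat_along_batch(results):
    # Concatenate the per-chunk output dicts back into one dict.
    return {k: [item for r in results for item in r[k]] for k in results[0]}

def forward_step(data):
    # Hypothetical per-micro-batch compute.
    return {"logits": [t * 2 for t in data["tokens"]]}

inputs = {"tokens": list(range(10))}
outputs = concat_along_batch([forward_step(b) for b in split_along_batch(inputs, 4)])
assert len(outputs["logits"]) == 10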