def generate_step_one_model()

in chatlearn/runtime/executor.py [0:0]


    def generate_step_one_model(self, model_node, replica, in_queue, out_queue, step_num, func_name="forward_step",
                                to_empty_cache=None, is_eval=False, to_onload=None, to_offload=None, micro_batch_index=None):
        """
        Run one step for a model replica and publish the selected outputs to the output queue(s).

        Args:
            model_node: node object; its ``model.module_args`` drives how outputs are selected
            replica: replica running the step; ``len(replica.dp_rank_to_actors)`` is its number of dp ranks
            in_queue: Queue, forwarded to ``generate_step_one_model_internal``
            out_queue: Queue, or list of Queue; every queue receives every selected result
            step_num: int
            func_name: str
            to_empty_cache: None or boolean
            is_eval: forwarded to ``generate_step_one_model_internal``
            to_onload: forwarded to ``generate_step_one_model_internal``
            to_offload: forwarded to ``generate_step_one_model_internal``
            micro_batch_index: forwarded to ``generate_step_one_model_internal``

        Returns:
            Tuple ``(out_queue, remote_refs)`` where ``remote_refs`` is the list of the
            first element of every entry produced by the internal step.
        """
        model = model_node.model
        # step_outputs is a list of tuples, each tuple being (remote_refs, mb)
        step_outputs = self.generate_step_one_model_internal(
            model_node, in_queue, step_num, replica, func_name, to_empty_cache,
            is_eval, to_onload, to_offload, micro_batch_index
        )

        dp_size = len(replica.dp_rank_to_actors)
        module_args = model.module_args
        if module_args.zero_size != 1:
            # Every output entry is forwarded as-is.
            selected = step_outputs
        elif module_args.expert_model_parallel_size == 1 and dp_size == 1:
            # With (tp > 1 or pp > 1) and ep = 1 the list has one entry per Actor and
            # all entries are the same, so only the last one is forwarded.
            selected = [step_outputs[-1]]
        else:
            # Otherwise forward the last entry of each dp rank's equally sized group.
            total = len(step_outputs)
            assert total % dp_size == 0, (
                f"The number of outputs ({total}) must be divisible by "
                f"the number of dp_ranks ({dp_size}) in a replica."
            )
            stride = total // dp_size
            selected = [step_outputs[last] for last in range(stride - 1, total, stride)]

        # Treat a single queue as a one-element list so both cases share one loop.
        sinks = out_queue if isinstance(out_queue, list) else [out_queue]
        for sink in sinks:
            for refs, mb in selected:
                sink.put(encode_data(mb, refs))

        # To ensure all Actors are finished synchronously, the remote refs of *all* outputs
        # are returned (not only the selected ones). ray wait does not support tuple type,
        # hence a list of lists.
        return out_queue, [entry[0] for entry in step_outputs]