in chatlearn/runtime/executor.py [0:0]
def generate_step_one_model(self, model_node, replica, in_queue, out_queue, step_num, func_name="forward_step",
                            to_empty_cache=None, is_eval=False, to_onload=None, to_offload=None, micro_batch_index=None):
    """
    Run one `func_name` step on a single replica of the model and put the results into `out_queue`.

    Args:
        model_node: node wrapping the DistModel to run (`model_node.model`)
        replica: the model replica to execute on
        in_queue: Queue to read input micro-batches from
        out_queue: Queue, or list of Queues, to put results into
        step_num: int
        func_name: str, name of the model method to call, defaults to "forward_step"
        to_empty_cache: None or bool
        is_eval: bool, whether this is an evaluation step
        to_onload: None or bool
        to_offload: None or bool
        micro_batch_index: None or int

    Returns:
        A tuple of (out_queue, remote_refs), where remote_refs are the remote object refs
        produced by all actors in this step.
    """
    model = model_node.model
    # output is a list of tuples, each tuple is (remote_refs, mb)
    output = self.generate_step_one_model_internal(model_node, in_queue, step_num, replica, func_name, to_empty_cache,
                                                   is_eval, to_onload, to_offload, micro_batch_index)
    num_dp_rank = len(replica.dp_rank_to_actors)
    if model.module_args.zero_size == 1:
        # If ep == 1 and there is a single dp rank (e.g. only tp > 1 or pp > 1) for the current
        # model, `output` is a list with one entry per Actor and all entries are the same,
        # so we only put output[-1] into out_queue.
        # Otherwise (ep > 1 or multiple dp ranks), we put the last output of each dp rank
        # into out_queue.
        if model.module_args.expert_model_parallel_size == 1 and num_dp_rank == 1:
            result = [output[-1]]
        else:
            num_output = len(output)
            assert num_output % num_dp_rank == 0, (
                f"The number of outputs ({num_output}) must be divisible by "
                f"the number of dp_ranks ({num_dp_rank}) in a replica."
            )
            interval = num_output // num_dp_rank
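            # e.g. with 8 outputs and 2 dp ranks, interval = 4, so indices 3 and 7 are kept below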
            result = [output[i] for i in range(interval - 1, num_output, interval)]
    else:
        result = output
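    # Fan the selected (remote_refs, micro_batch) pairs out to every output queue.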
    if isinstance(out_queue, list):
        for oq in out_queue:
            for res, mb in result:
                oq.put(encode_data(mb, res))
    else:
        for res, mb in result:
            out_queue.put(encode_data(mb, res))
    # To ensure all Actors finish synchronously, all remote refs should be returned.
    # Note that ray.wait does not support tuples, so we return a list of lists.
    remote_refs = [item[0] for item in output]
    return out_queue, remote_refs
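
The last-per-dp-rank selection above can be illustrated in isolation. A minimal sketch, using a hypothetical `output` list of (remote_refs, micro_batch) placeholders instead of real Ray object refs:

output = [(f"ref_{i}", f"mb_{i}") for i in range(8)]  # pretend 8 actor outputs
num_dp_rank = 2

num_output = len(output)                              # 8
interval = num_output // num_dp_rank                  # 4 actors per dp rank
result = [output[i] for i in range(interval - 1, num_output, interval)]

print(result)  # [('ref_3', 'mb_3'), ('ref_7', 'mb_7')] -> last output of each dp rank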