in chatlearn/synchronizer/parameter_sync.py [0:0]
def get_src_and_dst_dp_ranks(self, is_except_routed_experts=False):
"""
    Return:
        The DP group lists for the src and dst models: [[DP-0], [DP-1], ..., [DP-N]]
"""
dst_dp_ranks = self.dst_model.all_ranks
local_src_ranks = future.get(self.src_model.replicas[0].get_local_param_ranks())
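    # dst_dp_ranks covers all replicas of the dst model; local_src_ranks holds
    # (local_ranks, dp_rank) pairs fetched from src replica 0.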
if local_src_ranks[0] is None or dst_dp_ranks is None:
if self._debug:
            logger.warning(
                f"DEBUG MODE! src_dp_ranks {local_src_ranks} or dst_dp_ranks: {dst_dp_ranks} is None, "
                "make sure they have values in a real application.")
return local_src_ranks, dst_dp_ranks
else:
raise Exception(f"src_dp_ranks {local_src_ranks} or dst_dp_ranks {dst_dp_ranks} should not be None")
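    # Bucket the src ranks by their data-parallel rank.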
dp_rank_to_ranks = defaultdict(list)
for local_ranks, dp_rank in local_src_ranks:
dp_rank_to_ranks[dp_rank].append(local_ranks[dp_rank])
if is_except_routed_experts:
        # For weights other than routed experts, the expert-parallel (EP) dimension is reused for data parallelism.
        # TODO-1: This logic is a little complicated; it would be better to move it into a separate function.
        # TODO-2: This logic is HEP-specific; it would be better to call it from class ParameterSyncGroupwithHEP.
src_hep_size = self.num_src_expert_parallel * self.num_src_tensor_parallel
new_dict = defaultdict(list)
idx = 0
for dp_rank, values in dp_rank_to_ranks.items():
            assert len(values) % src_hep_size == 0, (
                f"len of values ({len(values)}) for dp_rank {dp_rank} must be divisible by hep size ({src_hep_size})"
                f" when get_src_and_dst_dp_ranks is called with is_except_routed_experts=True."
            )
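            # Split this DP group's ranks into per-pipeline-stage blocks of src_hep_size (= ep_size * tp_size) ranks.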
pp_blocks = [values[i:i + src_hep_size] for i in range(0, len(values), src_hep_size)]
sub_blocks_per_pp = []
for block in pp_blocks:
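                # Each sub-block holds the tensor-parallel ranks of one expert-parallel group within this stage.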
sub_block_size = src_hep_size // self.num_src_expert_parallel
sub_blocks = [block[i:i + sub_block_size] for i in range(0, src_hep_size, sub_block_size)]
sub_blocks_per_pp.append(sub_blocks)
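            # Merge the i-th expert-parallel slice from every pipeline stage into a single DP group.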
for i in range(self.num_src_expert_parallel):
merged_group = []
for sub_blocks in sub_blocks_per_pp:
merged_group.extend(sub_blocks[i])
new_dict[idx].extend(merged_group)
idx += 1
        src_dp_ranks = [ranks for _, ranks in sorted(new_dict.items())]
else:
        src_dp_ranks = [ranks for _, ranks in sorted(dp_rank_to_ranks.items())]
return src_dp_ranks, dst_dp_ranks
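
# A minimal standalone sketch (not part of chatlearn) illustrating the HEP
# regrouping above. It assumes ep_size=2 and tp_size=2 (so hep_size=4) and one
# DP group whose ranks span two pipeline stages; `values`, `ep_size`, and
# `tp_size` are illustrative names, not chatlearn attributes.
from collections import defaultdict

ep_size, tp_size = 2, 2
hep_size = ep_size * tp_size               # 4 ranks per pipeline-stage block
values = [0, 1, 2, 3, 4, 5, 6, 7]          # two pipeline stages x hep_size ranks

# Step 1: split into per-pipeline-stage blocks.
pp_blocks = [values[i:i + hep_size] for i in range(0, len(values), hep_size)]
assert pp_blocks == [[0, 1, 2, 3], [4, 5, 6, 7]]

# Step 2: split each stage block into EP sub-blocks of tp_size ranks.
sub_block_size = hep_size // ep_size       # == tp_size
sub_blocks_per_pp = [
    [block[i:i + sub_block_size] for i in range(0, hep_size, sub_block_size)]
    for block in pp_blocks
]
assert sub_blocks_per_pp == [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]

# Step 3: merge the i-th EP slice from every stage into one DP group.
groups = defaultdict(list)
for i in range(ep_size):
    for sub_blocks in sub_blocks_per_pp:
        groups[i].extend(sub_blocks[i])
assert dict(groups) == {0: [0, 1, 4, 5], 1: [2, 3, 6, 7]}
# Each EP slice becomes its own DP group, spanning all pipeline stages.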