in chatlearn/synchronizer/parameter_sync.py
def build_rank_mapping_two_stage(self, add_recv_actor_fn=None):
# Set up the rank mapping from src parameters to dst parameters.
# Ranks are collected for a single src_model replica; replicas are handled by cycling below.
if add_recv_actor_fn is None:
add_recv_actor_stage1_fn = self.add_recv_actor
add_recv_actor_stage2_fn = self.add_recv_actor_stage2
else:
assert len(add_recv_actor_fn) == 2, (
"The length of add_recv_actor_fn should be 2: the first element is the handler for "
"communication stage 1 and the second is the handler for communication stage 2."
)
add_recv_actor_stage1_fn = add_recv_actor_fn[0]
add_recv_actor_stage2_fn = add_recv_actor_fn[1]
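# Collect the DP ranks of the src (training) and dst (inference) models.
# Routed-expert weights are excluded here (is_except_routed_experts=True); they are presumably synchronized by a separate mapping.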
src_ranks, dst_ranks = self.get_src_and_dst_dp_ranks(is_except_routed_experts=True)
if self._debug and (src_ranks[0] is None or dst_ranks is None):
return
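# Cycle over src replicas so every dst replica gets assigned a src replica,
# even when there are fewer src replicas than dst replicas.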
replica_rank_iter = cycle(iter(src_ranks))
logger.debug(f"src_ranks: {src_ranks}")
logger.debug(f"dst_ranks: {dst_ranks}")
assert self.num_dst_tensor_parallel % self.num_src_tensor_parallel == 0, \
"currently the tensor_model_parallel_size of the dst model must be divisible by that of the src model, " + \
f"got src model {self.src_model.name} (TP={self.num_src_tensor_parallel}) and " + \
f"dst model {self.dst_model.name} (TP={self.num_dst_tensor_parallel})"
assert self.num_src_pipeline_stage % self.num_dst_pipeline_stage == 0, \
f"pipeline_model_parallel_size of src model ({self.num_src_pipeline_stage}) must be divisible by that of dst model ({self.num_dst_pipeline_stage})"
def split_ranks_by_tp_and_ep_size(ranks, tp_size, ep_size):
if ep_size > 1:
sort_ranks_on_grouped_tp = []
index = 0
tp_index = 0
for _ in range(len(ranks)):
sort_ranks_on_grouped_tp.append(index)
if tp_index < tp_size - 1:
index += 1
tp_index += 1
else:
start_index = index + 1 - tp_size
index = start_index + (ep_size * tp_size)
tp_index = 0
if index >= len(ranks):
index = (index + tp_size) % len(ranks)
else:
sort_ranks_on_grouped_tp = ranks
return [sort_ranks_on_grouped_tp[i:i + tp_size] for i in range(0, len(sort_ranks_on_grouped_tp), tp_size)]
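# A sketch of the grouping above, assuming ranks == list(range(8)), tp_size=2, ep_size=2:
# the EP-aware reordering yields [[0, 1], [4, 5], [2, 3], [6, 7]] instead of the
# plain TP split [[0, 1], [2, 3], [4, 5], [6, 7]] produced when ep_size == 1.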
pair_list = []  # stage-1 pairs: (train_rank, inference_rank)
p2p_list = []  # stage-2 pairs: (inference_rank, inference_rank)
src_replica_offset = 0  # starting pipeline-group index inside the current src replica, carried across dst replicas
lb_dst_offset_pq_dict = {}  # shared state used by get_load_balance_dst_rank to pick dst ranks in a load-balanced way
for dst_replica_ranks in dst_ranks:
src_replica_ranks = next(replica_rank_iter)
# For weights other than the routed experts, the EP dimension of the src model acts as data parallelism, so the src ranks are split with ep_size=1.
src_replica_ranks_group = split_ranks_by_tp_and_ep_size(src_replica_ranks, self.num_src_tensor_parallel, 1)
dst_replica_ranks_group = split_ranks_by_tp_and_ep_size(dst_replica_ranks, self.num_dst_tensor_parallel, self.num_dst_expert_parallel)
logger.debug(f"src_replica_ranks_group: {src_replica_ranks_group}")
logger.debug(f"dst_replica_ranks_group: {dst_replica_ranks_group}")
pipe_map_interval = self.num_src_pipeline_stage // self.num_dst_pipeline_stage
assert pipe_map_interval >= 1, \
f"dst_pp is expected to divide src_pp, got src_pp {self.num_src_pipeline_stage} and dst_pp {self.num_dst_pipeline_stage}"
# stage 1: comm pairs that broadcast params from trainer to inference model
# Each rank in trainer holds weights for tp_num_mapping ranks in inference model.
# For example: trainer_tp = 2, inference_tp = 4 => tp_num_mapping = inference_tp // trainer_tp = 2
# Weight mapping from training to inference:
# [0] -> [0', 1']
# [1] -> [2', 3']
# To avoid p2p communication on the same gpu, we only broadcast params to the first rank in each weight_mapping_group.
# Comm mapping from training to inference:
# [0] -> [0']
# [1] -> [2']
# First pass: pre-allocate dst ranks for the src ranks that hit a GPU collision;
# all other src ranks are recorded and paired in the second pass below.
uncollided_index_to_start_j = {}  # (pp group index i, tp index s_idx) -> (start, dst stage j) for deferred src ranks
for i, src_tp_group in enumerate(src_replica_ranks_group):
if i < src_replica_offset:
continue
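# Map this src pipeline stage onto its dst pipeline stage: every pipe_map_interval
# consecutive src stages fold into one dst stage.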
j = (i - src_replica_offset) // pipe_map_interval
if j == self.num_dst_pipeline_stage:
src_replica_offset = i
break
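# 'start' picks which of the tp_num_mapping candidate dst ranks this src rank should
# reach first; it walks forward (0, 1, ...) for the first tp_num_mapping pipeline groups
# and then in reverse, spreading the stage-1 receivers across the dst TP group.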
if self.tp_num_mapping == 1:
start = 0
else:
mod_i = (i - src_replica_offset) % self.tp_num_mapping
start = mod_i if (i - src_replica_offset) < self.tp_num_mapping else (self.tp_num_mapping - mod_i - 1) % self.tp_num_mapping
for s_idx, src_rank in enumerate(src_tp_group):
dst_rank, is_collide = self.get_load_balance_dst_rank(
lb_dst_offset_pq_dict,
s_idx,
start,
src_rank,
dst_replica_ranks_group,
j,
pre_allocate=True
)
if is_collide:
add_recv_actor_stage1_fn(src_rank, dst_rank)
pair_list.append((src_rank, dst_rank))
else:
assert dst_rank is None
uncollided_index_to_start_j.update({(i, s_idx) : (start, j)})
# Then, allocate dst ranks for the remaining src ranks without GPU collisions
for i, src_tp_group in enumerate(src_replica_ranks_group):
for s_idx, src_rank in enumerate(src_tp_group):
if (i, s_idx) not in uncollided_index_to_start_j:
continue
start, j = uncollided_index_to_start_j.get((i, s_idx))
dst_rank, _ = self.get_load_balance_dst_rank(
lb_dst_offset_pq_dict,
s_idx,
start,
src_rank,
dst_replica_ranks_group,
j,
pre_allocate=False
)
add_recv_actor_stage1_fn(src_rank, dst_rank)
pair_list.append((src_rank, dst_rank))
# stage 2: comm pairs that broadcast params from the first rank to the other ranks in each weight_mapping_group
# Comm mapping in each weight_mapping_group of inference:
# [0'] -> [1']
# [2'] -> [3']
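# Inference ranks that already received weights in stage 1; only these act as senders in stage 2.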
recv_ranks = [pair[1] for pair in pair_list]
def p2p_pair_grouping(tuples):
for s_idx, src_rank in enumerate(tuples):
for d_idx, dst_rank in enumerate(tuples):
if s_idx == d_idx or src_rank not in recv_ranks: # pylint: disable=cell-var-from-loop
continue
add_recv_actor_stage2_fn(src_rank, dst_rank)
p2p_list.append((src_rank, dst_rank))
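# Split each dst TP group into weight_mapping_groups of tp_num_mapping consecutive ranks;
# within each group, the stage-1 receiver forwards the weights to the remaining ranks.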
for dst_tp_group in dst_replica_ranks_group:
dst_tp_group = split_ranks_by_tp_and_ep_size(dst_tp_group, self.tp_num_mapping, 1)
for tuples in dst_tp_group:
p2p_pair_grouping(tuples)
logger.info(f"comm pair_list <train_rank, inference_rank>: {pair_list}")
logger.info(f"comm p2p_list <inference_rank, inference_rank>: {p2p_list}")