def broadcast_parameter_two_stage()

in chatlearn/models/base_module.py [0:0]


    def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
        """
        Arguments:
            to_rank: receive rank in mapping from trainer to inference model.
            buffer_rank: index which tensors of sync buffer to be sended in stage2.
            rank: destination rank in communication group which enumerate receive ranks.
            src_rank: source rank in communication group. always 0.
            group_name: communication group name.
            pipe_stage: pipeline stage. default 0.
            stage2: bool. whether stage2 or not. default False.
        Example: trainer_tp = 4, inference_tp = 8. pipeline_size = 1
            stage1: [(from_rank, to_rank), ...] = [(0, 8), (1, 10), (2, 12), (3, 14)]
            stage2: [(from_rank, to_rank), ...] = [(8, 9), (10, 11), (12, 13), (14, 15)]

            For stage1 pair (0, 8):
                1. call broadcast func: (0 -> 0). src_rank: 0, rank: 0.
                2. call broadcast func: (0 -> 8). src_rank: 0, rank: 1.

                After (0, 8), to_rank 8 received tensor slices of 8 and 9.

            For stage2 pair (8, 9):
                1. call broadcast func: (8 -> 8). src_rank: 0, rank: 0.
                2. call broadcast func: (8 -> 9). src_rank: 0, rank: 1.
                In (8 -> 8), we need to send tp_slice of 'to_rank' 9, so set buffer_rank 9 to fetch tensors in sync buffer.
        """
        tensor_changed = rank != src_rank
        start = time.time()
        arguments = f"{to_rank}_{buffer_rank}_{rank}_{src_rank}_{group_name}_{pipe_stage}_{stage2}"

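        # Pick the parameter list for this call: in stage2, receivers iterate the
        # per-to_rank recv list while the sender iterates its send list; in stage1,
        # the full sync list is used and the CPU sync buffer is reset the first
        # time a new destination rank is seen.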
        if stage2:
            if tensor_changed:
                parameters_to_sync = self._parameters_to_recv[to_rank]
            else:
                parameters_to_sync = self._parameters_to_send
        else:
            if rank not in self._sync_dst_rank_to_src_ranks:
                self._sync_dst_rank_to_src_ranks[rank] = [src_rank]
                del self._sync_buffer
                self._sync_buffer = defaultdict(list)
            else:
                self._sync_dst_rank_to_src_ranks[rank].append(src_rank)
            parameters_to_sync = self._parameters_to_sync

        def tensor_generator():
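            # Yields (tensor, buffer_num) pairs for broadcasting. The stage2 sender
            # replays the CPU slices cached during stage1; the stage1 source regroups
            # the live parameters by tensor-parallel rank before sending.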
            if stage2 and not tensor_changed and self._sync_buffer:  # pylint: disable=too-many-nested-blocks
                idx = 0
                for name, param in parameters_to_sync[pipe_stage]:
                    value = self._sync_buffer[buffer_rank % self.tp_num_mapping][idx].cuda()  # restore the CPU-cached slice to GPU
                    self._logger.debug(
                        f"Adding {name}({value.shape}) to sync for if branch from "
                        f"src_rank: {src_rank} to rank: {rank} in pipe_stage {pipe_stage}"
                    )
                    buffer_num = 1
                    idx += 1
                    yield value, buffer_num
                del self._sync_buffer[buffer_rank % self.tp_num_mapping]
            else:
                idx = 0
                for name, param in parameters_to_sync[pipe_stage]:
                    idx += 1
                    param_data = param.data
                    if rank and self._buffer_num and not stage2:
                        assert name in self._buffer_num, f"{name} not in self._buffer_num for rank {rank}"
                        buffer_num = self._buffer_num[name]
                    elif stage2:
                        buffer_num = 1
                    else:
                        if self._expert_sync_buffer and name in self._expert_sync_buffer:
                            param_data = self._expert_sync_buffer[name]
                            regroup_routed_experts = True # For routed experts in Qwen2vLLM
                        else:
                            regroup_routed_experts = False
                        # regroup src_tensor by tp_rank
                        param_data = self._synchronizer.regroup_params_to_sync(
                            name,
                            param_data,
                            self._tp_division[name],
                            regroup_routed_experts
                        )
                        # move self._expert_sync_buffer[name] to cpu mem to save gpu mem
                        if regroup_routed_experts and name in self._expert_sync_buffer:
                            cpu_expert = self._expert_sync_buffer[name].cpu()
                            del self._expert_sync_buffer[name]
                            self._expert_sync_buffer[name] = cpu_expert
                        buffer_num = 1
                    self._logger.debug(
                        f"Adding {name}({param_data.shape}) to sync for else branch from "
                        f"src_rank: {src_rank} to rank: {rank} in pipe_stage {pipe_stage}"
                    )
                    yield param_data, buffer_num

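        # Coalesce the yielded tensors into buckets of up to coalesced_buffer_mb
        # before broadcasting.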
        bucket_generator = bucket_tensors_two_stage_generator(
            tensor_generator, bucket_size_mb=self.runtime_args.coalesced_buffer_mb,
            stage2=stage2, tensor_changed=tensor_changed and not stage2
        )
        dense_bucket_num = 0
        sparse_bucket_num = 0
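        # Dense buckets go through the coalesced two-stage broadcast helper; on
        # stage1 receivers the resulting buffers are pinned to CPU for replay in
        # stage2. Non-dense tensors are broadcast directly.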
        for bucket_or_tensor, is_dense in bucket_generator:
            if is_dense:
                index = 0 if stage2 else (to_rank % self.tp_num_mapping)
                all_buffers = coalesced_comm_dense_two_stage(
                    bucket_or_tensor, col.broadcast, rank,
                    extra_args=(src_rank, group_name), tensor_changed=tensor_changed,
                    stage2=stage2, index=index)
                if tensor_changed and not stage2:
                    for key, value in all_buffers.items():
                        cpu_value = []
                        for tensor in value:
                            cpu_value.append(tensor.cpu().pin_memory()) # save gpu memory
                        del value
                        self._sync_buffer[key] += cpu_value
                    del all_buffers
                dense_bucket_num += 1
            else:
                col.broadcast(bucket_or_tensor, src_rank, group_name)
                sparse_bucket_num += 1

        if stage2:
            self._sync_dst_rank_to_src_ranks = {}

        self._logger.debug(f"broadcast_parameter_two_stage {arguments} done using {time.time()-start} seconds")
        debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, sparse_bucket {sparse_bucket_num}", self._logger)