def regroup_qkv_tp_slices()

in chatlearn/synchronizer/megatron_vllm.py [0:0]


    def regroup_qkv_tp_slices(self, name, param_data, tp_division):
        param_data_shape = param_data.shape
        # Regroup QKV tensors into per-TP slices; only needed for the inference model when the vLLM backend is enabled.
        to_fix_qkv_ordering_dict = self.sync_map.to_fix_qkv_ordering_dict
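        # When set, the fused QKV head ordering of the source checkpoint must be
        # rearranged for the destination model (see the regrouping below).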
        # pylint: disable=too-many-nested-blocks
        if "attention.query_key_value" in name or \
                "self_attention.query_key_value" in name or \
                "self_attention.linear_qkv" in name:
            src_tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
            dst_tp_size = self.dst_module_args.args_dict["tensor_model_parallel_size"]
            heads = self.src_module_args.args_dict["num_attention_heads"] // src_tp_size
            hidden_size_per_head = self.src_module_args.args_dict["hidden_size"] // self.src_module_args.args_dict["num_attention_heads"]

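            # A plain multi-head attention weight has exactly 3 * heads * head_dim
            # rows per rank; if the element count matches, view it as
            # (3, heads, head_dim, ...) so whole heads can be sliced below.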
            param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
            division = reduce(operator.mul, param_shape, 1)
            num_elements = param_data.numel()
            if num_elements == division:
                if to_fix_qkv_ordering_dict is not None:
                    param_data = param_data.view(param_shape)
                    param_data_list = []
                    head_offset = heads // tp_division
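                    # Give each of the tp_division destination slices its own
                    # contiguous run of head_offset heads; concatenating the
                    # (3, head_offset, ...) chunks along dim 0 leaves every
                    # slice with a contiguous [q | k | v] block.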
                    for idx in range(tp_division):
                        start = idx * head_offset
                        end = start + head_offset
                        param_data_list.append(param_data[:, start:end])
                    param_data = torch.cat(param_data_list, dim=0).view(param_data_shape)
                    del param_data_list
            else:
                if self.src_module_args.args_dict["group_query_attention"]:
                    num_query_groups = self.src_module_args.args_dict["num_query_groups"]
                    assert num_query_groups == self.dst_module_args.args_dict["num_query_groups"], (
                        f"num_query_groups of src model ({num_query_groups}) must be equal to num_query_groups of "
                        f"dst model ({self.dst_module_args.args_dict['num_query_groups']}). Please double-check your config."
                    )
                    src_num_query_groups_per_replica = num_query_groups // src_tp_size
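                    # When the destination TP size exceeds the number of query
                    # groups, each kv head is replicated across
                    # dst_tp_size // num_query_groups destination ranks.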
                    if dst_tp_size >= num_query_groups:
                        num_dst_kv_head_replicas = dst_tp_size // num_query_groups
                    else:
                        num_dst_kv_head_replicas = 1
                else:
                    src_num_query_groups_per_replica = heads
                    num_dst_kv_head_replicas = 1

                if to_fix_qkv_ordering_dict is not None or src_num_query_groups_per_replica == 1:
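                    # The fused GQA weight is laid out as [q heads | k groups | v groups];
                    # view it as rows of head_dim vectors so q, k and v can be
                    # indexed per head below.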
                    if len(param_data_shape) == 1:
                        param_data = param_data.view((heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head))
                    else:
                        param_data = param_data.view(
                            (heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"]))
                    param_data_list = []
                    head_offset = heads // tp_division
                    for idx in range(tp_division):
                        q_start = idx * head_offset
                        q_end = q_start + head_offset
                        if num_dst_kv_head_replicas == 1:
                            if src_num_query_groups_per_replica > tp_division:
                                assert src_num_query_groups_per_replica % tp_division == 0, (
                                    f"num_query_groups per replica of src model ({src_num_query_groups_per_replica}) "
                                    f"must be divisible by tp_division ({tp_division}). Please double-check your config."
                                )
                                kv_offset = src_num_query_groups_per_replica // tp_division
                            else:
                                kv_offset = 1
                            k_start = heads + idx * kv_offset if src_num_query_groups_per_replica >= tp_division else heads
                            k_end = k_start + kv_offset
                            v_start = k_start + src_num_query_groups_per_replica
                            v_end = v_start + kv_offset
                        else:
                            k_start = heads + idx // num_dst_kv_head_replicas
                            k_end = k_start + 1
                            v_start = k_start + src_num_query_groups_per_replica
                            v_end = v_start + 1

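                        # Gather this slice's query heads and its (possibly shared)
                        # kv group(s), then fuse them back into one contiguous
                        # [q | k | v] block for the destination rank.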
                        q_proj = param_data[q_start:q_end].contiguous()
                        k_proj = param_data[k_start:k_end].contiguous()
                        v_proj = param_data[v_start:v_end].contiguous()

                        qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)

                        if len(param_data_shape) == 1:
                            qkv_proj = qkv_proj.reshape(-1).contiguous()
                        else:
                            qkv_proj = qkv_proj.reshape(-1, self.src_module_args.args_dict["hidden_size"]).contiguous()

                        param_data_list.append(qkv_proj)
                    param_data = torch.cat(param_data_list, dim=0)
                    del param_data_list
        return param_data
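
A minimal sketch of the non-GQA regrouping path (standalone and illustrative;
the tensor sizes and the `regroup` helper below are hypothetical, not part of
ChatLearn): viewing a fused (3 * heads * head_dim, hidden) weight as
(3, heads, head_dim, hidden) and concatenating per-slice head ranges along
dim 0 turns the global [q | k | v] layout into per-slice [q | k | v] blocks,
which is what the first branch above does.

    import torch

    def regroup(qkv, heads, head_dim, tp_division):
        # View the fused weight as (3, heads, head_dim, hidden) and pull out
        # each destination slice's contiguous run of heads.
        viewed = qkv.view(3, heads, head_dim, -1)
        head_offset = heads // tp_division
        chunks = [viewed[:, i * head_offset:(i + 1) * head_offset]
                  for i in range(tp_division)]
        # Concatenating along dim 0 makes each slice's q/k/v contiguous.
        return torch.cat(chunks, dim=0).view(qkv.shape)

    heads, head_dim, hidden, tp_division = 4, 8, 32, 2
    qkv = torch.randn(3 * heads * head_dim, hidden)
    out = regroup(qkv, heads, head_dim, tp_division)
    # The first half of `out` now holds q, k and v for heads 0..1 only.
    assert out[:out.shape[0] // tp_division].shape == (3 * (heads // tp_division) * head_dim, hidden)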