in chatlearn/synchronizer/megatron_vllm.py [0:0]
def regroup_qkv_tp_slices(self, name, param_data, tp_division):
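    """Regroup a fused Megatron QKV parameter into ``tp_division`` tensor-parallel slices for vLLM.

    ``name`` is the Megatron parameter name, ``param_data`` holds one source rank's fused QKV
    weight or bias, and ``tp_division`` is the number of destination slices to produce.
    Parameters whose name does not match a QKV projection are returned unchanged.

    Illustrative example: with 8 query heads and 2 KV groups on the source rank and
    ``tp_division=2``, slice 0 receives q heads 0-3 plus KV group 0, slice 1 receives
    q heads 4-7 plus KV group 1, and each slice is flattened back to the fused layout.
    """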
param_data_shape = param_data.shape
    # Regroup QKV tensors into different TP slices; only needed for the inference model when the vLLM backend is enabled.
to_fix_qkv_ordering_dict = self.sync_map.to_fix_qkv_ordering_dict
# pylint: disable=too-many-nested-blocks
if "attention.query_key_value" in name or \
"self_attention.query_key_value" in name or \
"self_attention.linear_qkv" in name:
src_tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
dst_tp_size = self.dst_module_args.args_dict["tensor_model_parallel_size"]
heads = self.src_module_args.args_dict["num_attention_heads"] // src_tp_size
hidden_size_per_head = self.src_module_args.args_dict["hidden_size"] // self.src_module_args.args_dict["num_attention_heads"]
param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
division = reduce(operator.mul, param_shape, 1)
num_elements = param_data.numel()
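        # The fused parameter matches the plain MHA layout (q, k, v each with `heads` heads) only
        # when its element count equals prod((3, heads, head_dim) + trailing dims); GQA-shaped
        # parameters take the else branch below.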
if num_elements == division:
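            # MHA layout: when an ordering fix is registered, regroup heads so that each destination
            # slice's q, k and v heads are contiguous in the flattened parameter.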
if to_fix_qkv_ordering_dict is not None:
param_data = param_data.view(param_shape)
param_data_list = []
head_offset = heads // tp_division
for idx in range(tp_division):
start = idx * head_offset
end = start + head_offset
                    param_data_list.append(param_data[:, start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
else:
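            # Element count does not match the MHA layout: work out the GQA group split per source
            # rank and whether destination ranks have to replicate KV heads.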
if self.src_module_args.args_dict["group_query_attention"]:
num_query_groups = self.src_module_args.args_dict["num_query_groups"]
assert num_query_groups == self.dst_module_args.args_dict["num_query_groups"], (
f"num_query_groups of src model ({num_query_groups}) must be equal to num_query_groups of "
f"dst model ({self.dst_moduel_args.args_dict['num_query_groups']}). Please double-check your config."
)
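                # Each source rank holds num_query_groups // src_tp_size KV groups; when dst_tp_size
                # is at least num_query_groups, every KV head is shared by dst_tp_size // num_query_groups
                # destination ranks.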
src_num_query_groups_per_replica = num_query_groups // src_tp_size
if dst_tp_size >= num_query_groups:
num_dst_kv_head_replicas = dst_tp_size // num_query_groups
else:
num_dst_kv_head_replicas = 1
else:
src_num_query_groups_per_replica = heads
num_dst_kv_head_replicas = 1
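            # When an ordering fix is registered (or there is a single KV group per source rank),
            # reshape so rows are indexed by head: q heads first, then k groups, then v groups,
            # and rebuild the fused QKV block for every destination slice.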
if to_fix_qkv_ordering_dict is not None or src_num_query_groups_per_replica == 1:
if len(param_data_shape) == 1:
param_data = param_data.view((heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head))
else:
param_data = param_data.view(
(heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"]))
param_data_list = []
head_offset = heads // tp_division
for idx in range(tp_division):
q_start = idx * head_offset
q_end = q_start + head_offset
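                    # Pick this slice's k and v group rows; the indexing differs depending on whether
                    # destination ranks replicate KV heads.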
if num_dst_kv_head_replicas == 1:
if src_num_query_groups_per_replica > tp_division:
assert src_num_query_groups_per_replica % tp_division == 0, (
f"num_query_groups per replica of src model ({src_num_query_groups_per_replica}) "
f"must be divisible by tp_division ({tp_division}). Please double-check your config."
)
kv_offset = src_num_query_groups_per_replica // tp_division
else:
kv_offset = 1
                        k_start = (heads + idx) if src_num_query_groups_per_replica >= tp_division else heads
k_end = k_start + kv_offset
v_start = k_start + src_num_query_groups_per_replica
v_end = v_start + kv_offset
else:
k_start = heads + idx // num_dst_kv_head_replicas
k_end = k_start + 1
v_start = k_start + src_num_query_groups_per_replica
v_end = v_start + 1
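                    # Concatenate the selected q, k and v rows into this slice's fused QKV block and
                    # flatten it back to the parameter's original layout.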
q_proj = param_data[q_start:q_end].contiguous()
k_proj = param_data[k_start:k_end].contiguous()
v_proj = param_data[v_start:v_end].contiguous()
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)
if len(param_data_shape) == 1:
qkv_proj = qkv_proj.reshape(-1).contiguous()
else:
qkv_proj = qkv_proj.reshape(-1, self.src_module_args.args_dict["hidden_size"]).contiguous()
param_data_list.append(qkv_proj)
param_data = torch.concat(param_data_list, dim=0)
del param_data_list
return param_data