in chatlearn/synchronizer/base.py [0:0]
def regroup_params_to_sync(self, name, param_data, tp_division, regroup_routed_experts=False):
    """
    Reorder a gathered parameter so that every tensor-parallel (TP) slice
    occupies a contiguous block along dim 0, keeping the original shape.

    The rule applied is selected by substring matches on ``name``:

    * ``'self_attention.dense'`` and ``'mlp.dense_4h_to_h'`` (QWen / LLama2
      legacy), ``'mlp.linear_fc2'`` (LLama2 mcore) and
      ``'mlp.shared_experts.dense_4h_to_h'`` (QWen-MoE):
      ``[w, h * tp]`` is cut into ``tp`` column slices ``[w, h]`` which are
      stacked along dim 0.
    * ``'mlp.dense_h_to_4h'`` (QWen / LLama2 legacy), ``'mlp.linear_fc1'``
      (LLama2 mcore), ``'mlp.shared_experts.dense_h_to_4h'`` (QWen-MoE) and
      ``'mlp.w1'`` (QWen, vLLM backend; only when ``self.concat_params_dict``
      is not None): dim 0 is treated as two halves ``[w1; w2]``. Each half is
      cut into ``tp`` row slices, and the output order is
      ``[w1_0; w2_0; w1_1; w2_1; ...]``. 1-D tensors are handled the same way.
    * Only when ``regroup_routed_experts`` is set (QWen-MoE trained with
      QWen + vLLM):
      ``'mlp.experts.dense_4h_to_h'`` ``[e, h, w]`` is cut along dim 2 into
      ``tp`` slices ``[e, h, w // tp]``. ``'mlp.experts.dense_h_to_4h'``
      ``[e, w, h]`` is viewed as ``[2 * e, w // 2, h]`` and cut along dim 1.
      In both cases the slices are stacked along dim 0.

    :param name: parameter name, matched by substring to pick a rule.
    :param param_data: tensor to regroup.
    :param tp_division: number of TP slices to regroup into.
    :param regroup_routed_experts: also regroup routed-expert weights.
    :returns: a tensor with the same shape as ``param_data``; ``param_data``
        itself when no rule matches.
    :raises ValueError: when a rule matches but the dimension it splits cannot
        be cut into the required number of equal slices (including
        ``tp_division < 1``).

    :meta private:
    """
    param_data_shape = param_data.shape

    def _slices(tensor, dim, num_slices):
        # Cut `tensor` along `dim` into `num_slices` equal, ordered views.
        # Validating here gives a clear error instead of silently truncating
        # and failing later inside .view()/.reshape() with a shape mismatch.
        size = tensor.shape[dim]
        if num_slices < 1 or size % num_slices != 0:
            raise ValueError(
                f"Cannot regroup {name!r} (shape {tuple(param_data_shape)}): "
                f"dim {dim} of size {size} does not split into "
                f"{num_slices} equal slices"
            )
        step = size // num_slices
        return [tensor.narrow(dim, idx * step, step) for idx in range(num_slices)]

    if (
        "self_attention.dense" in name
        or "mlp.dense_4h_to_h" in name
        or "mlp.linear_fc2" in name
        or "mlp.shared_experts.dense_4h_to_h" in name
    ):
        # [w, h * tp] -> tp * [w, h], stacked along dim 0, then viewed back
        # to the original shape.
        param_data = torch.cat(
            _slices(param_data, 1, tp_division), dim=0
        ).contiguous().view(param_data_shape)
    if (
        "mlp.dense_h_to_4h" in name
        or "mlp.linear_fc1" in name
        or ("mlp.w1" in name and self.concat_params_dict is not None)
        or "mlp.shared_experts.dense_h_to_4h" in name
    ):
        # dim 0 is [w1 rows; w2 rows]; cut into 2 * tp row slices and pair
        # slice i of the w1 half with slice i of the w2 half.
        row_slices = _slices(param_data, 0, 2 * tp_division)
        interleaved = [
            part
            for pair in zip(row_slices[:tp_division], row_slices[tp_division:])
            for part in pair
        ]
        param_data = torch.cat(interleaved, dim=0).contiguous().view(param_data_shape)
        del row_slices, interleaved
    if regroup_routed_experts:
        if "mlp.experts.dense_4h_to_h" in name:
            # [e, h, w] -> tp * [e, h, w // tp], stacked along dim 0.
            param_data = torch.cat(
                _slices(param_data, 2, tp_division), dim=0
            ).contiguous().view(param_data_shape)
        elif "mlp.experts.dense_h_to_4h" in name:
            # [e, w, h] -> [2 * e, w // 2, h] so each expert's two halves are
            # separate rows, then cut the middle dim into tp slices.
            halves = param_data.reshape(param_data_shape[0] * 2, -1, param_data_shape[2])
            param_data = torch.cat(
                _slices(halves, 1, tp_division), dim=0
            ).contiguous().view(param_data_shape)
            del halves
    return param_data