in step6_data_parallel_bucket/tensor_parallel.py [0:0]
def apply_tensor_parallel(model):
    """Replace the model's projections and embedding with tensor-parallel layers.

    For every layer in ``model.decoder_layers`` the attention projections
    (``q_proj``, ``k_proj``, ``v_proj``, ``out_proj``) and the MLP projections
    (``up_proj``, ``gate_proj``, ``down_proj``) are swapped for
    ``ColumnParallelLinear`` / ``RowParallelLinear`` instances. Afterwards
    ``model.embedding`` becomes a ``VocabParallelEmbedding`` and
    ``model.final_proj`` becomes a ``ColumnParallelLinear`` constructed with
    ``gather_output=True``.

    The replacement layers are newly constructed from the original layer's
    dimensions only; this function does not copy the original parameters.

    Args:
        model: Object exposing ``decoder_layers`` (iterable of layers that each
            have ``attention`` and ``mlp`` attributes), ``embedding`` and
            ``final_proj``.

    Returns:
        The same ``model`` object, modified in place.
    """

    def _replace_module(_module, _linear_proj_name, _style, args=None):
        """Swap ``_module.<_linear_proj_name>`` for its parallel counterpart.

        Args:
            _module: Parent object holding the layer to replace.
            _linear_proj_name: Attribute name of the layer on ``_module``.
            _style: One of ``"column"``, ``"row"`` or ``"vocab"``.
            args: Optional dict of extra options; only ``"gather_output"``
                is read, and only for the ``"column"`` style.

        Raises:
            ValueError: If ``_style`` is not a supported style.
        """
        # Explicit raise instead of ``assert``: asserts vanish under ``-O``,
        # which would let an unknown style fall through to the vocab branch.
        if _style not in ("column", "row", "vocab"):
            raise ValueError(
                f"unsupported parallel style {_style!r}; "
                "expected 'column', 'row' or 'vocab'"
            )
        # ``None`` sentinel avoids sharing one mutable dict between calls.
        options = {} if args is None else args

        linear_layer = getattr(_module, _linear_proj_name)
        if _style == "column":
            new_linear_layer = ColumnParallelLinear(
                in_features=linear_layer.in_features,
                out_features=linear_layer.out_features,
                bias=linear_layer.bias is not None,
                gather_output=options.get("gather_output", False),
            )
        elif _style == "row":
            new_linear_layer = RowParallelLinear(
                in_features=linear_layer.in_features,
                out_features=linear_layer.out_features,
                bias=linear_layer.bias is not None,
            )
        else:
            # "vocab": the existing layer is read as an embedding
            # (``num_embeddings`` / ``embedding_dim``), not a linear layer.
            new_linear_layer = VocabParallelEmbedding(
                num_embeddings=linear_layer.num_embeddings,
                embedding_dim=linear_layer.embedding_dim,
            )
        setattr(_module, _linear_proj_name, new_linear_layer)

    # (sub-module attribute on the decoder layer, projection name, style)
    module_linear_name_style_mapping = (
        ("attention", "q_proj", "column"),
        ("attention", "k_proj", "column"),
        ("attention", "v_proj", "column"),
        ("attention", "out_proj", "row"),
        ("mlp", "up_proj", "column"),
        ("mlp", "gate_proj", "column"),
        ("mlp", "down_proj", "row"),
    )

    for layer in model.decoder_layers:
        for module_name, linear_proj_name, style in module_linear_name_style_mapping:
            _replace_module(getattr(layer, module_name), linear_proj_name, style)

    # Model-level layers are replaced once, outside the per-layer loop.
    _replace_module(model, "embedding", "vocab")
    _replace_module(model, "final_proj", "column", args={"gather_output": True})

    return model