in orttraining/orttraining/core/optimizer/megatron_transformer.cc [897:1268]
Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified,
std::vector<Node*>& nodes_to_clear_shape,
std::unordered_set<Node*>& dropout_nodes_to_transform,
int32_t& counter,
NodeIndex node_index) const {
auto skip_status = common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Skip BART Attention megatron transformation");
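// Returned whenever this node does not anchor a matching BART attention sub-graph;
// the caller then moves on to the next candidate node.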
// Self/Enc-Dec Attention sub-graph.
//
// Q: MatMul->Add->Mul->Reshape->Transpose->MatMul->Reshape->Where->Reshape->Softmax->Dropout->MatMul->Transpose->Reshape->MatMul->Add->Dropout
// K: MatMul->Add->Reshape->Transpose---------^                                                  ^
// V: MatMul->Add->Reshape->Transpose------------------------------------------------------------^
auto& node = *graph.GetNode(node_index);
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", opset_v9_13) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
node.GetOutputEdgesCount() != 1) {
return skip_status;
}
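// If the producer of the Q MatMul's input is already a MegatronF, this
// attention block has been partitioned before; skip it.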
Node* q_matmul_input_node_ptr = const_cast<Node*>(graph.GetProducerNode(node.MutableInputDefs()[0]->Name()));
if (q_matmul_input_node_ptr != nullptr && q_matmul_input_node_ptr->OpType().compare("MegatronF") == 0) {
return skip_status;
}
std::vector<Node*> sub_graph_node_ptrs;
sub_graph_node_ptrs.push_back(&node);
ProviderType provider_type = node.GetExecutionProviderType();
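// Linear pattern downstream of the Q MatMul. Entries constructed with `false`
// are optional and may be matched as nullptr; the trailing comments give each
// node's offset from the end of sub_graph_node_ptrs after matching.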
std::vector<NodeInfo> linear_pattern = {
NodeInfo({add_info}),
NodeInfo({mul_info}),
NodeInfo({reshape_info}),
NodeInfo({transpose_info}),
NodeInfo({matmul_info}),
NodeInfo({add_info}, false), // -13
NodeInfo({reshape_info}),
NodeInfo({where_info}),
NodeInfo({reshape_info}),
NodeInfo({softmax_info}),
NodeInfo({dropout_info}, false), // -8
NodeInfo({matmul_info}),
NodeInfo({add_info}, false), // -6
NodeInfo({transpose_info}),
NodeInfo({reshape_info}),
NodeInfo({matmul_info}), // -3
NodeInfo({add_info}),
NodeInfo({dropout_info}, false)}; // -1
if (!MatchLinearPattern(graph, &node, provider_type, linear_pattern, sub_graph_node_ptrs)) {
return skip_status;
}
// Grab all the needed node pointers now, since the push_backs below change the
// relative indices. Other than the optional nodes in the pattern, all node
// pointers are valid if the linear pattern matched.
Node* q_biasadd_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 18];
Node* q_transpose_after_reshape_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 15];
Node* qk_matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 14];
Node* dropout_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 8];
Node* qkv_matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 7];
Node* transpose_node1_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 5];
Node& dense_matmul_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 3];
// Transpose node attribute checking.
if (!optimizer_utils::IsAttributeWithExpectedValues(*q_transpose_after_reshape_node_ptr, "perm", {1LL, 0LL, 2LL}) ||
!optimizer_utils::IsAttributeWithExpectedValues(*transpose_node1_ptr, "perm", {1LL, 0LL, 2LL})) {
return skip_status;
}
// Map from each Reshape node to the index of the shape dimension that must be divided by the parallel size.
std::unordered_map<Node*, int64_t> reshape_node_ptrs;
reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 16]] = 1;
reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12]] = 1;
reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 10]] = 0;
reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 4]] = 2;
// At this point, `node` is the Q MatMul.
std::vector<Node*> weight_transpose_node_ptrs;
std::vector<Node*> bias_add_node_ptrs;
Node* q_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(node.MutableInputDefs()[1]->Name()));
if (q_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*q_transpose_ptr, transpose_info, provider_type)) {
return skip_status;
}
weight_transpose_node_ptrs.push_back(q_transpose_ptr);
sub_graph_node_ptrs.push_back(q_transpose_ptr);
bias_add_node_ptrs.push_back(q_biasadd_node_ptr);
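// Walk the K path upstream from the QK MatMul:
// Transpose <- Reshape <- Add(bias) <- MatMul <- Transpose(weight).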
Node* k_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(qk_matmul_node_ptr->MutableInputDefs()[1]->Name()));
if (k_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*k_transpose_ptr, transpose_info, provider_type)) {
return skip_status;
}
sub_graph_node_ptrs.push_back(k_transpose_ptr);
Node* k_reshape_ptr = const_cast<Node*>(graph.GetProducerNode(k_transpose_ptr->MutableInputDefs()[0]->Name()));
if (k_reshape_ptr == nullptr || !IsExpectedOpAndProvider(*k_reshape_ptr, reshape_info, provider_type)) {
return skip_status;
}
reshape_node_ptrs[k_reshape_ptr] = 1;
sub_graph_node_ptrs.push_back(k_reshape_ptr);
Node* k_add_ptr = const_cast<Node*>(graph.GetProducerNode(k_reshape_ptr->MutableInputDefs()[0]->Name()));
if (k_add_ptr == nullptr || !IsExpectedOpAndProvider(*k_add_ptr, add_info, provider_type)) {
return skip_status;
}
sub_graph_node_ptrs.push_back(k_add_ptr);
bias_add_node_ptrs.push_back(k_add_ptr);
Node* k_matmul_ptr = const_cast<Node*>(graph.GetProducerNode(k_add_ptr->MutableInputDefs()[0]->Name()));
if (k_matmul_ptr == nullptr || !IsExpectedOpAndProvider(*k_matmul_ptr, matmul_info, provider_type)) {
return skip_status;
}
sub_graph_node_ptrs.push_back(k_matmul_ptr);
Node* k_weight_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(k_matmul_ptr->MutableInputDefs()[1]->Name()));
if (k_weight_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*k_weight_transpose_ptr, transpose_info, provider_type)) {
return skip_status;
}
sub_graph_node_ptrs.push_back(k_weight_transpose_ptr);
weight_transpose_node_ptrs.push_back(k_weight_transpose_ptr);
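// Walk the V path upstream from the QKV MatMul, mirroring the K path.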
Node* v_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(qkv_matmul_node_ptr->MutableInputDefs()[1]->Name()));
if (v_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*v_transpose_ptr, transpose_info, provider_type)) {
return skip_status;
}
sub_graph_node_ptrs.push_back(v_transpose_ptr);
Node* v_reshape_ptr = const_cast<Node*>(graph.GetProducerNode(v_transpose_ptr->MutableInputDefs()[0]->Name()));
if (v_reshape_ptr == nullptr || !IsExpectedOpAndProvider(*v_reshape_ptr, reshape_info, provider_type)) {
return skip_status;
}
reshape_node_ptrs[v_reshape_ptr] = 1;
sub_graph_node_ptrs.push_back(v_reshape_ptr);
Node* v_add_ptr = const_cast<Node*>(graph.GetProducerNode(v_reshape_ptr->MutableInputDefs()[0]->Name()));
if (v_add_ptr == nullptr || !IsExpectedOpAndProvider(*v_add_ptr, add_info, provider_type)) {
return skip_status;
}
sub_graph_node_ptrs.push_back(v_add_ptr);
bias_add_node_ptrs.push_back(v_add_ptr);
Node* v_matmul_ptr = const_cast<Node*>(graph.GetProducerNode(v_add_ptr->MutableInputDefs()[0]->Name()));
if (v_matmul_ptr == nullptr || !IsExpectedOpAndProvider(*v_matmul_ptr, matmul_info, provider_type)) {
return skip_status;
}
sub_graph_node_ptrs.push_back(v_matmul_ptr);
Node* v_weight_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(v_matmul_ptr->MutableInputDefs()[1]->Name()));
if (v_weight_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*v_weight_transpose_ptr, transpose_info, provider_type)) {
return skip_status;
}
sub_graph_node_ptrs.push_back(v_weight_transpose_ptr);
weight_transpose_node_ptrs.push_back(v_weight_transpose_ptr);
Node* q_matmul_ptr = &node;
// The K and V MatMuls must consume the same input.
if (k_matmul_ptr->MutableInputDefs()[0]->Name() != v_matmul_ptr->MutableInputDefs()[0]->Name()) {
return skip_status;
}
// Check the constant value in the Reshape nodes.
bool is_reshape_valid = true;
for (auto x : reshape_node_ptrs) {
Node* node_ptr = x.first;
int64_t idx = x.second;
auto shape_arg = node_ptr->MutableInputDefs()[1];
const ONNX_NAMESPACE::TensorProto* tensor;
if (!graph.GetInitializedTensor(shape_arg->Name(), tensor)) {
is_reshape_valid = false;
break;
}
auto data_type = tensor->data_type();
if (data_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
is_reshape_valid = false;
break;
}
// The number of values must exceed idx, and the idx-th value must be divisible by the
// parallel size, i.e., the attention head count must be divisible by the parallel size.
auto init_const = std::make_unique<Initializer>(*tensor, graph.ModelPath());
if (init_const->size() <= idx) {
is_reshape_valid = false;
break;
}
const int64_t* val = init_const->data<int64_t>();
if (val[idx] % horizontal_parallel_size_ != 0) {
LOGS_DEFAULT(WARNING) << "dim[" << idx << "]: " << val[idx]
<< " is not divisible by horizontal_parallel_size_ "
<< horizontal_parallel_size_ << ", not supported currently.";
is_reshape_valid = false;
break;
}
}
if (!is_reshape_valid) {
return skip_status;
}
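// The Q/K/V projection weights feed a Transpose, so each stored tensor is laid
// out [output, input]; partitioning the attention heads therefore splits rows.
// The 1-D biases are split along their only axis via PartitionWeightByColumn.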
// Partition weights. If any of them fails, skip transforming the rest.
std::vector<ONNX_NAMESPACE::TensorProto> qkv_weight_initializer_partitions;
for (auto trans_ptr : weight_transpose_node_ptrs) {
auto qkv_weight_arg = trans_ptr->MutableInputDefs()[0];
ONNX_NAMESPACE::TensorProto qkv_weight_initializer_partition;
if (!PartitionWeightByRow(graph, *qkv_weight_arg, qkv_weight_initializer_partition)) {
break;
}
qkv_weight_initializer_partitions.push_back(qkv_weight_initializer_partition);
}
// Partition bias. If any of them fails, skip transforming the rest.
std::vector<ONNX_NAMESPACE::TensorProto> qkv_bias_initializer_partitions;
for (auto add_ptr : bias_add_node_ptrs) {
auto qkv_bias_arg = add_ptr->MutableInputDefs()[1];
ONNX_NAMESPACE::TensorProto qkv_bias_initializer_partition;
if (!PartitionWeightByColumn(graph, *qkv_bias_arg, qkv_bias_initializer_partition)) {
break;
}
qkv_bias_initializer_partitions.push_back(qkv_bias_initializer_partition);
}
// If any weight or bias failed to partition, skip transforming this sub-graph.
if (weight_transpose_node_ptrs.size() != qkv_weight_initializer_partitions.size()) {
return skip_status;
}
if (bias_add_node_ptrs.size() != qkv_bias_initializer_partitions.size()) {
return skip_status;
}
// Partition the dense (output projection) weight. If that fails, skip transforming this sub-graph.
Node* last_transpose = const_cast<Node*>(graph.GetProducerNode(dense_matmul_node.MutableInputDefs()[1]->Name()));
if (last_transpose == nullptr || !IsExpectedOpAndProvider(*last_transpose, transpose_info, provider_type)) {
return skip_status;
}
auto dense_weight_arg = last_transpose->MutableInputDefs()[0];
ONNX_NAMESPACE::TensorProto dense_weight_initializer_partition;
if (!PartitionWeightByColumn(graph, *dense_weight_arg, dense_weight_initializer_partition)) {
return skip_status;
}
// Ready to transform the sub-graph once we reach here.
// Replace node inputs with the partitioned initializers.
size_t i = 0;
for (auto trans_ptr : weight_transpose_node_ptrs) {
auto weight_name = trans_ptr->MutableInputDefs()[0]->Name();
NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partitions[i]);
graph_utils::ReplaceNodeInput(*trans_ptr, 0, qkv_weight_partition_arg);
graph.RemoveInitializedTensor(weight_name);
updated_weight_names_.insert({weight_name, qkv_weight_partition_arg.Name()});
i++;
}
i = 0;
for (auto add_ptr : bias_add_node_ptrs) {
auto bias_name = add_ptr->MutableInputDefs()[1]->Name();
NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partitions[i]);
graph_utils::ReplaceNodeInput(*add_ptr, 1, qkv_bias_partition_arg);
graph.RemoveInitializedTensor(bias_name);
updated_weight_names_.insert({bias_name, qkv_bias_partition_arg.Name()});
i++;
}
NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition);
graph_utils::ReplaceNodeInput(*last_transpose, 0, dense_weight_partition_arg);
graph.RemoveInitializedTensor(dense_weight_arg->Name());
updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()});
// The node vector may contain nullptr entries for optional node infos that didn't match during linear pattern matching.
std::copy_if(sub_graph_node_ptrs.begin(), sub_graph_node_ptrs.end(),
std::back_inserter(nodes_to_clear_shape),
[](Node* node_ptr) { return node_ptr != nullptr; });
// Update the shape constants of the Reshape nodes: divide the recorded dimension by the parallel size.
for (auto x : reshape_node_ptrs) {
Node* node_ptr = x.first;
int64_t idx = x.second;
auto shape_arg = node_ptr->MutableInputDefs()[1];
const ONNX_NAMESPACE::TensorProto* tensor;
graph.GetInitializedTensor(shape_arg->Name(), tensor);
auto data_type = tensor->data_type();
auto init_const = std::make_unique<Initializer>(*tensor, graph.ModelPath());
const int64_t* val = init_const->data<int64_t>();
int64_t size = init_const->size();
ONNX_NAMESPACE::TensorProto tensor_partition;
tensor_partition.set_name(graph.GenerateNodeArgName("partition_" + shape_arg->Name()));
tensor_partition.set_data_type(data_type);
tensor_partition.add_dims(size);
std::vector<int64_t> val_partition;
val_partition.reserve(size);
val_partition.insert(val_partition.end(), val, val + size);
val_partition[idx] /= horizontal_parallel_size_;
tensor_partition.set_raw_data(val_partition.data(), size * sizeof(int64_t));
NodeArg& node_arg_partition = graph_utils::AddInitializer(graph, tensor_partition);
graph_utils::ReplaceNodeInput(*node_ptr, 1, node_arg_partition);
graph.RemoveInitializedTensor(shape_arg->Name());
}
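// Record the attention-probability Dropout (if present) so the caller can make
// it partition-aware in a later pass.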
if (dropout_node_ptr != nullptr) {
dropout_nodes_to_transform.insert(dropout_node_ptr);
}
// Add MegatronF before the 1st MatMul and MegatronG before the last Add.
NodeArg* prev_input_node_ptr = k_matmul_ptr->MutableInputDefs()[0];
std::vector<Node*> new_consumer_nodes;
const auto& node_consumers = graph.GetConsumerNodes(prev_input_node_ptr->Name());
for (auto& n : node_consumers) {
if (n->Index() == k_matmul_ptr->Index() || n->Index() == v_matmul_ptr->Index() || n->Index() == q_matmul_ptr->Index()) {
continue;
}
new_consumer_nodes.emplace_back(const_cast<Node*>(n));
}
bool shared_same_input = k_matmul_ptr->MutableInputDefs()[0]->Name().compare(q_matmul_ptr->MutableInputDefs()[0]->Name()) == 0;
// If Q shares its input with K and V (self-attention), the single MegatronF below serves all three;
// otherwise Q gets its own MegatronF node.
{
const std::vector<NodeArg*> sa_f_input_defs{prev_input_node_ptr};
auto sa_f_type_info = *prev_input_node_ptr->TypeAsProto();
auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(k_matmul_ptr->Name() + "BARTAttention_MegatronF_Output"), &sa_f_type_info);
Node& sa_f_node = graph.AddNode(graph.GenerateNodeName(k_matmul_ptr->Name() + "BARTAttention_MegatronF"),
"MegatronF",
k_matmul_ptr->Name() + " BARTAttention MegatronF",
sa_f_input_defs,
{&sa_f_out_arg}, {}, kMSDomain);
sa_f_node.SetExecutionProviderType(k_matmul_ptr->GetExecutionProviderType());
graph_utils::ReplaceNodeInput(*k_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0]));
graph_utils::ReplaceNodeInput(*v_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0]));
if (shared_same_input) {
graph_utils::ReplaceNodeInput(*q_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0]));
}
new_consumer_nodes.push_back(&sa_f_node);
}
graph.UpdateConsumerNodes(prev_input_node_ptr->Name(), new_consumer_nodes);
counter++;
if (!shared_same_input) {
{
NodeArg* q_prev_input_node_ptr = q_matmul_ptr->MutableInputDefs()[0];
std::vector<Node*> q_new_consumer_nodes;
const auto& q_node_consumers = graph.GetConsumerNodes(q_prev_input_node_ptr->Name());
for (auto& n : q_node_consumers) {
if (n->Index() == k_matmul_ptr->Index() || n->Index() == v_matmul_ptr->Index() || n->Index() == q_matmul_ptr->Index()) {
continue;
}
q_new_consumer_nodes.emplace_back(const_cast<Node*>(n));
}
const std::vector<NodeArg*> q_sa_f_input_defs{q_matmul_ptr->MutableInputDefs()[0]};
auto q_sa_f_type_info = *q_matmul_ptr->MutableInputDefs()[0]->TypeAsProto();
auto& q_sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(q_matmul_ptr->Name() + "BARTAttention_MegatronF_Output"), &q_sa_f_type_info);
Node& q_sa_f_node = graph.AddNode(graph.GenerateNodeName(q_matmul_ptr->Name() + "BARTAttention_MegatronF"),
"MegatronF",
q_matmul_ptr->Name() + " BARTAttention MegatronF",
q_sa_f_input_defs,
{&q_sa_f_out_arg}, {}, kMSDomain);
q_sa_f_node.SetExecutionProviderType(q_matmul_ptr->GetExecutionProviderType());
graph_utils::ReplaceNodeInput(*q_matmul_ptr, 0, *(q_sa_f_node.MutableOutputDefs()[0]));
q_new_consumer_nodes.push_back(&q_sa_f_node);
graph.UpdateConsumerNodes(q_prev_input_node_ptr->Name(), q_new_consumer_nodes);
// TODO: update the consumer nodes for the input node as well.
}
}
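// Insert MegatronG after the dense MatMul and before the bias Add. Following the
// Megatron-LM g operator, it all-reduces the partial results across ranks in the
// forward pass and is the identity in the backward pass.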
const std::vector<NodeArg*> sa_g_input_defs{dense_matmul_node.MutableOutputDefs()[0]};
auto sa_g_type_info = *dense_matmul_node.MutableOutputDefs()[0]->TypeAsProto(); // copy
auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BARTAttention_MegatronG_Output"), &sa_g_type_info);
Node& sa_g_node = graph.AddNode(graph.GenerateNodeName(k_matmul_ptr->Name() + "BARTAttention_MegatronG"),
"MegatronG",
"BARTAttention MegatronG",
sa_g_input_defs,
{&sa_g_out_arg}, {}, kMSDomain);
sa_g_node.AddAttribute("group_type", static_cast<int64_t>(training::WorkerGroupType::HorizontalParallel));
sa_g_node.SetExecutionProviderType(k_matmul_ptr->GetExecutionProviderType());
graph_utils::ReplaceDownstreamNodeInput(graph, dense_matmul_node, 0, sa_g_node, 0);
modified = true;
return Status::OK();
}