Status MegatronTransformer::TransformBARTAttention()

in orttraining/orttraining/core/optimizer/megatron_transformer.cc [897:1268]


Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified,
                                                   std::vector<Node*>& nodes_to_clear_shape,
                                                   std::unordered_set<Node*>& dropout_nodes_to_transform,
                                                   int32_t& counter,
                                                   NodeIndex node_index) const {
  auto skip_status = common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Skip BART Attention megatron transformation");

  // Self/Enc-Dec Attention sub-graph.
  //
  // MatMul->Add->Mul->Reshape->Transpose->MatMul->Reshape->Where->Reshape->Softmax->Dropout->MatMul->Transpose->Reshape->MatMul->Add->Dropout
  // MatMul->Add->Reshape->Transpose-------> |                                                  |
  // MatMul->Add->Reshape->Transpose----------------------------------------------------------> |
  auto& node = *graph.GetNode(node_index);

  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", opset_v9_13) ||
      !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) ||
      node.GetOutputEdgesCount() != 1) {
    return skip_status;
  }

  Node* q_matmul_input_node_ptr = const_cast<Node*>(graph.GetProducerNode(node.MutableInputDefs()[0]->Name()));
  if (q_matmul_input_node_ptr != nullptr && q_matmul_input_node_ptr->OpType().compare("MegatronF") == 0) {
    return skip_status;
  }
  std::vector<Node*> sub_graph_node_ptrs;
  sub_graph_node_ptrs.push_back(&node);
  ProviderType provider_type = node.GetExecutionProviderType();

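  // Linear pattern from the Q MatMul onwards. Each matched node is appended to
  // sub_graph_node_ptrs; the `-N` comments below record a node's offset from the
  // end of that vector once the whole pattern has matched. NodeInfos constructed
  // with `false` are optional and leave a nullptr entry when absent.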
  std::vector<NodeInfo> linear_pattern = {
      NodeInfo({add_info}),
      NodeInfo({mul_info}),
      NodeInfo({reshape_info}),
      NodeInfo({transpose_info}),
      NodeInfo({matmul_info}),
      NodeInfo({add_info}, false),  // -13
      NodeInfo({reshape_info}),
      NodeInfo({where_info}),
      NodeInfo({reshape_info}),
      NodeInfo({softmax_info}),
      NodeInfo({dropout_info}, false),  // -8
      NodeInfo({matmul_info}),
      NodeInfo({add_info}, false),  // -6
      NodeInfo({transpose_info}),
      NodeInfo({reshape_info}),
      NodeInfo({matmul_info}),  // -3
      NodeInfo({add_info}),
      NodeInfo({dropout_info}, false)};  // -1
  if (!MatchLinearPattern(graph, &node, provider_type, linear_pattern, sub_graph_node_ptrs)) {
    return skip_status;
  }
  // Cache the needed node pointers now, since the push_backs below change each
  // node's offset from the end of the vector. Apart from the optional entries in
  // the pattern, every matched pointer is valid once MatchLinearPattern succeeds.
  Node* q_biasadd_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 18];
  Node* q_transpose_after_reshape_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 15];
  Node* qk_matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 14];
  Node* dropout_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 8];
  Node* qkv_matmul_node_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 7];
  Node* transpose_node1_ptr = sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 5];
  Node& dense_matmul_node = *sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 3];

  // Transpose node attribute checking.
  if (!optimizer_utils::IsAttributeWithExpectedValues(*q_transpose_after_reshape_node_ptr, "perm", {1LL, 0LL, 2LL}) ||
      !optimizer_utils::IsAttributeWithExpectedValues(*transpose_node1_ptr, "perm", {1LL, 0LL, 2LL})) {
    return skip_status;
  }
  // Map from each Reshape node to the index of the shape dim that must be scaled down.
  std::unordered_map<Node*, int64_t> reshape_node_ptrs;
  reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 16]] = 1;
  reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 12]] = 1;
  reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 10]] = 0;
  reshape_node_ptrs[sub_graph_node_ptrs[sub_graph_node_ptrs.size() - 4]] = 2;
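  // Offsets resolve to: -16 = Q Reshape after the scaling Mul (dim 1), -12 =
  // Reshape feeding Where (dim 1), -10 = Reshape feeding Softmax (dim 0), -4 =
  // output Reshape before the dense MatMul (dim 2). Each mapped dim is the
  // head-dependent size that gets divided by horizontal_parallel_size_ below.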
  // At this point, `node` is the Q MatMul.

  std::vector<Node*> weight_transpose_node_ptrs;
  std::vector<Node*> bias_add_node_ptrs;

  Node* q_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(node.MutableInputDefs()[1]->Name()));
  if (q_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*q_transpose_ptr, transpose_info, provider_type)) {
    return skip_status;
  }
  weight_transpose_node_ptrs.push_back(q_transpose_ptr);
  sub_graph_node_ptrs.push_back(q_transpose_ptr);
  bias_add_node_ptrs.push_back(q_biasadd_node_ptr);

  Node* k_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(qk_matmul_node_ptr->MutableInputDefs()[1]->Name()));
  if (k_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*k_transpose_ptr, transpose_info, provider_type)) {
    return skip_status;
  }
  sub_graph_node_ptrs.push_back(k_transpose_ptr);

  Node* k_reshape_ptr = const_cast<Node*>(graph.GetProducerNode(k_transpose_ptr->MutableInputDefs()[0]->Name()));
  if (k_reshape_ptr == nullptr || !IsExpectedOpAndProvider(*k_reshape_ptr, reshape_info, provider_type)) {
    return skip_status;
  }
  reshape_node_ptrs[k_reshape_ptr] = 1;
  sub_graph_node_ptrs.push_back(k_reshape_ptr);

  Node* k_add_ptr = const_cast<Node*>(graph.GetProducerNode(k_reshape_ptr->MutableInputDefs()[0]->Name()));
  if (k_add_ptr == nullptr || !IsExpectedOpAndProvider(*k_add_ptr, add_info, provider_type)) {
    return skip_status;
  }
  sub_graph_node_ptrs.push_back(k_add_ptr);
  bias_add_node_ptrs.push_back(k_add_ptr);

  Node* k_matmul_ptr = const_cast<Node*>(graph.GetProducerNode(k_add_ptr->MutableInputDefs()[0]->Name()));
  if (k_matmul_ptr == nullptr || !IsExpectedOpAndProvider(*k_matmul_ptr, matmul_info, provider_type)) {
    return skip_status;
  }
  sub_graph_node_ptrs.push_back(k_matmul_ptr);

  Node* k_weight_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(k_matmul_ptr->MutableInputDefs()[1]->Name()));
  if (k_weight_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*k_weight_transpose_ptr, transpose_info, provider_type)) {
    return skip_status;
  }
  sub_graph_node_ptrs.push_back(k_weight_transpose_ptr);
  weight_transpose_node_ptrs.push_back(k_weight_transpose_ptr);

  Node* v_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(qkv_matmul_node_ptr->MutableInputDefs()[1]->Name()));
  if (v_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*v_transpose_ptr, transpose_info, provider_type)) {
    return skip_status;
  }
  sub_graph_node_ptrs.push_back(v_transpose_ptr);

  Node* v_reshape_ptr = const_cast<Node*>(graph.GetProducerNode(v_transpose_ptr->MutableInputDefs()[0]->Name()));
  if (v_reshape_ptr == nullptr || !IsExpectedOpAndProvider(*v_reshape_ptr, reshape_info, provider_type)) {
    return skip_status;
  }
  reshape_node_ptrs[v_reshape_ptr] = 1;
  sub_graph_node_ptrs.push_back(v_reshape_ptr);

  Node* v_add_ptr = const_cast<Node*>(graph.GetProducerNode(v_reshape_ptr->MutableInputDefs()[0]->Name()));
  if (v_add_ptr == nullptr || !IsExpectedOpAndProvider(*v_add_ptr, add_info, provider_type)) {
    return skip_status;
  }
  sub_graph_node_ptrs.push_back(v_add_ptr);
  bias_add_node_ptrs.push_back(v_add_ptr);

  Node* v_matmul_ptr = const_cast<Node*>(graph.GetProducerNode(v_add_ptr->MutableInputDefs()[0]->Name()));
  if (v_matmul_ptr == nullptr || !IsExpectedOpAndProvider(*v_matmul_ptr, matmul_info, provider_type)) {
    return skip_status;
  }
  sub_graph_node_ptrs.push_back(v_matmul_ptr);

  Node* v_weight_transpose_ptr = const_cast<Node*>(graph.GetProducerNode(v_matmul_ptr->MutableInputDefs()[1]->Name()));
  if (v_weight_transpose_ptr == nullptr || !IsExpectedOpAndProvider(*v_weight_transpose_ptr, transpose_info, provider_type)) {
    return skip_status;
  }
  sub_graph_node_ptrs.push_back(v_weight_transpose_ptr);
  weight_transpose_node_ptrs.push_back(v_weight_transpose_ptr);

  Node* q_matmul_ptr = &node;
  // The K and V MatMuls must consume the same input.
  if (k_matmul_ptr->MutableInputDefs()[0]->Name() != v_matmul_ptr->MutableInputDefs()[0]->Name()) {
    return skip_status;
  }
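  // (In encoder-decoder cross-attention Q may consume a different input; the
  // shared_same_input branch below handles that case.)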

  // Check the constant value in the Reshape nodes.
  bool is_reshape_valid = true;
  for (auto x : reshape_node_ptrs) {
    Node* node_ptr = x.first;
    int64_t idx = x.second;
    auto shape_arg = node_ptr->MutableInputDefs()[1];
    const ONNX_NAMESPACE::TensorProto* tensor;
    if (!graph.GetInitializedTensor(shape_arg->Name(), tensor)) {
      is_reshape_valid = false;
      break;
    }
    auto data_type = tensor->data_type();
    if (data_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
      is_reshape_valid = false;
      break;
    }
    // The shape initializer must contain more than `idx` values, and the value at
    // `idx` must be divisible by the parallel size, i.e., the attention-head
    // count must be divisible by the parallel size.
    auto init_const = std::make_unique<Initializer>(*tensor, graph.ModelPath());
    if (init_const->size() <= idx) {
      is_reshape_valid = false;
      break;
    }
    const int64_t* val = init_const->data<int64_t>();
    if (val[idx] % horizontal_parallel_size_ != 0) {
      LOGS_DEFAULT(WARNING) << "dim[" << idx << "]: " << val[idx]
                            << " is not divisible by horizontal_parallel_size_ "
                            << horizontal_parallel_size_ << ", not supported currently.";
      is_reshape_valid = false;
      break;
    }
  }

  if (!is_reshape_valid) {
    return skip_status;
  }

  // Partition weights. If any of them fails, skip transforming the rest.
  std::vector<ONNX_NAMESPACE::TensorProto> qkv_weight_initializer_partitions;
  for (auto trans_ptr : weight_transpose_node_ptrs) {
    auto qkv_weight_arg = trans_ptr->MutableInputDefs()[0];
    ONNX_NAMESPACE::TensorProto qkv_weight_initializer_partition;
    if (!PartitionWeightByRow(graph, *qkv_weight_arg, qkv_weight_initializer_partition)) {
      break;
    }
    qkv_weight_initializer_partitions.push_back(qkv_weight_initializer_partition);
  }

  // Partition bias. If any of them fails, skip transforming the rest.
  std::vector<ONNX_NAMESPACE::TensorProto> qkv_bias_initializer_partitions;
  for (auto add_ptr : bias_add_node_ptrs) {
    auto qkv_bias_arg = add_ptr->MutableInputDefs()[1];
    ONNX_NAMESPACE::TensorProto qkv_bias_initializer_partition;
    if (!PartitionWeightByColumn(graph, *qkv_bias_arg, qkv_bias_initializer_partition)) {
      break;
    }
    qkv_bias_initializer_partitions.push_back(qkv_bias_initializer_partition);
  }
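  // Note: each stored Q/K/V weight feeds a Transpose before its MatMul, so the
  // Megatron column-parallel split of the effective weight is a row split of the
  // stored initializer, while the matching 1-D bias is split along its only axis.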

  // If any weight or bias failed to partition, skip transforming this subgraph.
  if (weight_transpose_node_ptrs.size() != qkv_weight_initializer_partitions.size()) {
    return skip_status;
  }
  if (bias_add_node_ptrs.size() != qkv_bias_initializer_partitions.size()) {
    return skip_status;
  }

  // Partition the dense weight. If anything fails, skip transforming this subgraph.
  Node* last_transpose = const_cast<Node*>(graph.GetProducerNode(dense_matmul_node.MutableInputDefs()[1]->Name()));
  if (last_transpose == nullptr || !IsExpectedOpAndProvider(*last_transpose, transpose_info, provider_type)) {
    return skip_status;
  }
  auto dense_weight_arg = last_transpose->MutableInputDefs()[0];
  ONNX_NAMESPACE::TensorProto dense_weight_initializer_partition;
  if (!PartitionWeightByColumn(graph, *dense_weight_arg, dense_weight_initializer_partition)) {
    return skip_status;
  }
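  // The dense projection is row-parallel: its reduction (input) dim is split,
  // i.e., a column split of the transposed-stored weight. Its bias stays
  // unpartitioned, so the all-reduce (MegatronG) must run before the bias Add.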

  // All checks passed; ready to transform the sub-graph.
  // Replace node inputs.
  size_t i = 0;
  for (auto trans_ptr : weight_transpose_node_ptrs) {
    auto weight_name = trans_ptr->MutableInputDefs()[0]->Name();
    NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partitions[i]);
    graph_utils::ReplaceNodeInput(*trans_ptr, 0, qkv_weight_partition_arg);
    graph.RemoveInitializedTensor(weight_name);
    updated_weight_names_.insert({weight_name, qkv_weight_partition_arg.Name()});
    i++;
  }
  i = 0;
  for (auto add_ptr : bias_add_node_ptrs) {
    auto bias_name = add_ptr->MutableInputDefs()[1]->Name();
    NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partitions[i]);
    graph_utils::ReplaceNodeInput(*add_ptr, 1, qkv_bias_partition_arg);
    graph.RemoveInitializedTensor(bias_name);
    updated_weight_names_.insert({bias_name, qkv_bias_partition_arg.Name()});
    i++;
  }

  NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition);
  graph_utils::ReplaceNodeInput(*last_transpose, 0, dense_weight_partition_arg);
  graph.RemoveInitializedTensor(dense_weight_arg->Name());
  updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()});

  // The node vector may contain nullptr entries for optional nodes that were
  // absent during linear pattern matching.
  std::copy_if(sub_graph_node_ptrs.begin(), sub_graph_node_ptrs.end(),
               std::back_inserter(nodes_to_clear_shape),
               [](Node* node_ptr) { return node_ptr != nullptr; });

  // Change the constant for the reshape nodes.
  for (auto x : reshape_node_ptrs) {
    Node* node_ptr = x.first;
    int64_t idx = x.second;
    auto shape_arg = node_ptr->MutableInputDefs()[1];
    const ONNX_NAMESPACE::TensorProto* tensor;
    graph.GetInitializedTensor(shape_arg->Name(), tensor);
    auto data_type = tensor->data_type();
    auto init_const = std::make_unique<Initializer>(*tensor, graph.ModelPath());
    const int64_t* val = init_const->data<int64_t>();
    int64_t size = init_const->size();
    ONNX_NAMESPACE::TensorProto tensor_partition;
    tensor_partition.set_name(graph.GenerateNodeArgName("partition_" + shape_arg->Name()));
    tensor_partition.set_data_type(data_type);
    tensor_partition.add_dims(size);

    std::vector<int64_t> val_partition;
    val_partition.reserve(size);
    val_partition.insert(val_partition.end(), val, val + size);
    val_partition[idx] /= horizontal_parallel_size_;
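    // Hypothetical example: a stored dim value of 16 with
    // horizontal_parallel_size_ == 4 becomes 4 on every rank.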
    tensor_partition.set_raw_data(val_partition.data(), size * sizeof(int64_t));
    NodeArg& node_arg_partition = graph_utils::AddInitializer(graph, tensor_partition);
    graph_utils::ReplaceNodeInput(*node_ptr, 1, node_arg_partition);
    graph.RemoveInitializedTensor(shape_arg->Name());
  }

  if (dropout_node_ptr != nullptr) {
    dropout_nodes_to_transform.insert(dropout_node_ptr);
  }
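  // The collected Dropout nodes are transformed later, outside this function
  // (e.g., via per-rank seeds), so each model-parallel rank draws an
  // independent mask.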

  // Add MegatronF before the 1st MatMul and MegatronG before the last Add.
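  // Following the Megatron-LM f/g operators: MegatronF is an identity in the
  // forward pass and an all-reduce in the backward pass; MegatronG is the
  // opposite (all-reduce forward, identity backward).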

  NodeArg* prev_input_node_ptr = k_matmul_ptr->MutableInputDefs()[0];
  std::vector<Node*> new_consumer_nodes;
  const auto& node_consumers = graph.GetConsumerNodes(prev_input_node_ptr->Name());
  for (auto& n : node_consumers) {
    if (n->Index() == k_matmul_ptr->Index() || n->Index() == v_matmul_ptr->Index() || n->Index() == q_matmul_ptr->Index()) {
      continue;
    }
    new_consumer_nodes.emplace_back(const_cast<Node*>(n));
  }

  bool shared_same_input = k_matmul_ptr->MutableInputDefs()[0]->Name().compare(q_matmul_ptr->MutableInputDefs()[0]->Name()) == 0;

  // If Q shares its input with K and V, a single MegatronF feeds all three
  // MatMuls; otherwise Q gets its own MegatronF below.
  {
    const std::vector<NodeArg*> sa_f_input_defs{prev_input_node_ptr};
    auto sa_f_type_info = *prev_input_node_ptr->TypeAsProto();
    auto& sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(k_matmul_ptr->Name() + "BARTAttention_MegatronF_Output"), &sa_f_type_info);
    Node& sa_f_node = graph.AddNode(graph.GenerateNodeName(k_matmul_ptr->Name() + "BARTAttention_MegatronF"),
                                    "MegatronF",
                                    k_matmul_ptr->Name() + " BARTAttention MegatronF",
                                    sa_f_input_defs,
                                    {&sa_f_out_arg}, {}, kMSDomain);
    sa_f_node.SetExecutionProviderType(k_matmul_ptr->GetExecutionProviderType());
    graph_utils::ReplaceNodeInput(*k_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0]));
    graph_utils::ReplaceNodeInput(*v_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0]));
    if (shared_same_input) {
      graph_utils::ReplaceNodeInput(*q_matmul_ptr, 0, *(sa_f_node.MutableOutputDefs()[0]));
    }
    new_consumer_nodes.push_back(&sa_f_node);
  }
  graph.UpdateConsumerNodes(prev_input_node_ptr->Name(), new_consumer_nodes);
  counter++;
  if (!shared_same_input) {
    {
      NodeArg* q_prev_input_node_ptr = q_matmul_ptr->MutableInputDefs()[0];
      std::vector<Node*> q_new_consumer_nodes;
      const auto& q_node_consumers = graph.GetConsumerNodes(q_prev_input_node_ptr->Name());
      for (auto& n : q_node_consumers) {
        if (n->Index() == k_matmul_ptr->Index() || n->Index() == v_matmul_ptr->Index() || n->Index() == q_matmul_ptr->Index()) {
          continue;
        }
        q_new_consumer_nodes.emplace_back(const_cast<Node*>(n));
      }

      const std::vector<NodeArg*> q_sa_f_input_defs{q_matmul_ptr->MutableInputDefs()[0]};
      auto q_sa_f_type_info = *q_matmul_ptr->MutableInputDefs()[0]->TypeAsProto();
      auto& q_sa_f_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(q_matmul_ptr->Name() + "BARTAttention_MegatronF_Output"), &q_sa_f_type_info);
      Node& q_sa_f_node = graph.AddNode(graph.GenerateNodeName(q_matmul_ptr->Name() + "BARTAttention_MegatronF"),
                                        "MegatronF",
                                        q_matmul_ptr->Name() + " BARTAttention MegatronF",
                                        q_sa_f_input_defs,
                                        {&q_sa_f_out_arg}, {}, kMSDomain);
      q_sa_f_node.SetExecutionProviderType(q_matmul_ptr->GetExecutionProviderType());

      graph_utils::ReplaceNodeInput(*q_matmul_ptr, 0, *(q_sa_f_node.MutableOutputDefs()[0]));
      q_new_consumer_nodes.push_back(&q_sa_f_node);
      graph.UpdateConsumerNodes(q_prev_input_node_ptr->Name(), q_new_consumer_nodes);
      // TODO: update the consumer nodes for the input node as well.
    }
  }

  const std::vector<NodeArg*> sa_g_input_defs{dense_matmul_node.MutableOutputDefs()[0]};
  auto sa_g_type_info = *dense_matmul_node.MutableOutputDefs()[0]->TypeAsProto();  // copy
  auto& sa_g_out_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("BARTAttention_MegatronG_Output"), &sa_g_type_info);
  Node& sa_g_node = graph.AddNode(graph.GenerateNodeName(k_matmul_ptr->Name() + "BARTAttention_MegatronG"),
                                  "MegatronG",
                                  "BARTAttention MegatronG",
                                  sa_g_input_defs,
                                  {&sa_g_out_arg}, {}, kMSDomain);
  sa_g_node.AddAttribute("group_type", static_cast<int64_t>(training::WorkerGroupType::HorizontalParallel));
  sa_g_node.SetExecutionProviderType(k_matmul_ptr->GetExecutionProviderType());
  graph_utils::ReplaceDownstreamNodeInput(graph, dense_matmul_node, 0, sa_g_node, 0);
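  // Each rank's dense MatMul yields only a partial sum over its slice of the
  // reduction dim; MegatronG's all-reduce completes the sum before the
  // unpartitioned bias Add.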

  modified = true;

  return Status::OK();
}
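
For context, a minimal sketch of how a per-node transform like this is typically
driven. This is a hypothetical driver for illustration only; the real caller is
MegatronTransformer::ApplyImpl, whose exact signature and bookkeeping may differ:

Status MegatronTransformer::ApplyTransformSketch(Graph& graph, bool& modified) const {
  // Hypothetical locals mirroring the parameters TransformBARTAttention expects.
  std::vector<Node*> nodes_to_clear_shape;
  std::unordered_set<Node*> dropout_nodes_to_transform;
  int32_t self_attention_counter = 0;

  GraphViewer graph_viewer(graph);
  for (auto node_index : graph_viewer.GetNodesInTopologicalOrder()) {
    auto ret = TransformBARTAttention(graph, modified, nodes_to_clear_shape,
                                      dropout_nodes_to_transform,
                                      self_attention_counter, node_index);
    // NOT_IMPLEMENTED is the skip signal used above; any other failure is fatal.
    if (!ret.IsOK() && ret.Code() != common::NOT_IMPLEMENTED) {
      return ret;
    }
  }
  return Status::OK();
}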