bool FuseBinaryOpsAfterUnpack()

in tensorflow/tensorflow/core/grappler/optimizers/gemm_optimizer.cc [283:595]


// Rewrites subgraphs of the form
//
//     Unpack --> { N identical binary ops } --> consumers
//
// into
//
//     Pack(lhs slices) --\
//                         +--> one batched op --> Unpack --> consumers
//     Pack(rhs slices) --/
//
// so that N per-slice binary ops execute as a single batched op.  For every
// "Unpack" node, its consumers qualify for fusion when they all share one
// binary op type (from GetBinaryOps()), take the Unpack as one of their first
// two inputs, take the same Unpack/Const node as the other, and agree on any
// extra inputs (index >= 2).  MatMul/BatchMatMul/BatchMatMulV2 groups become a
// single BatchMatMulV2; IndicatorMatMul becomes ParallelIndicatorMatMul; any
// other binary op keeps its type.
//
// Returns true iff the graph was modified.  Logs and returns false on any
// node-construction failure.
bool FuseBinaryOpsAfterUnpack(Graph* graph) {
  // Monotonic counter producing unique node-name prefixes across invocations.
  static int count = 0;
  VLOG(2) << "FuseBinaryOpsAfterUnpack";
  bool changed = false;

  // Snapshot the node list up front: the fusion below removes nodes.
  std::vector<Node*> nodes(graph->num_nodes());
  int i = 0;
  for (Node* node : graph->nodes()) {
    nodes[i++] = node;
  }

  const std::unordered_set<string> binary_op_set = GetBinaryOps();
  for (Node* node : nodes) {
    // A node may already have been removed by an earlier fusion in this pass.
    if (!graph->IsValidNode(node).ok()) continue;
    if (node->type_string() != "Unpack") continue;

    // Collect this Unpack's fusable consumers.
    std::vector<Node*> binary_ops;
    std::map<string, std::vector<Node*>> binary_ops_m;
    std::vector<std::pair<Node*, int>> common_extra_input;
    const Node* another_node = nullptr;
    bool can_fuse = true;
    string binary_type;
    for (Node* out : node->out_nodes()) {
      string out_type = out->type_string();
      if (binary_type.empty()) {
        // First consumer fixes the op type the whole group must share.
        binary_type = out->type_string();
        if (binary_op_set.find(binary_type) == binary_op_set.end()) {
          can_fuse = false;
          break;
        }
      }
      if (out->type_string() != binary_type) {
        can_fuse = false;
        break;
      }

      if (out->num_inputs() < 2) {
        can_fuse = false;
        break;
      }

      bool has_same_another_input = true;
      bool another_input_from_unpack_or_const = true;
      bool has_common_extra_input = true;
      for (int i = 0; i < out->num_inputs(); ++i) {
        const Edge* edge = nullptr;
        // Guard against a malformed input: the original code ignored this
        // status and would dereference a null edge.
        if (!out->input_edge(i, &edge).ok() || edge == nullptr) {
          can_fuse = false;
          break;
        }
        Node* n = edge->src();
        if (i < 2) {
          // One of the first two inputs is our Unpack; the other must be the
          // same Unpack/Const node across every candidate op.
          if (n != node) {
            if (another_node == nullptr) {
              string type = n->type_string();
              if (type == "Unpack" || type == "Const") {
                another_node = n;
              } else {
                another_input_from_unpack_or_const = false;
                break;
              }
            } else {
              if (n != another_node) {
                has_same_another_input = false;
              }
            }
          }
        } else {
          // Extra inputs (index >= 2) must match exactly — same source node
          // and same output slot — across all candidate ops.
          size_t vec_idx = i - 2;
          if (vec_idx >= common_extra_input.size()) {
            common_extra_input.emplace_back(std::make_pair(n, edge->src_output()));
          } else {
            if (n != common_extra_input[vec_idx].first ||
                edge->src_output() != common_extra_input[vec_idx].second) {
              has_common_extra_input = false;
            }
          }
        }
      }
      if (!can_fuse) break;

      if (has_same_another_input && another_input_from_unpack_or_const &&
          has_common_extra_input)
        binary_ops.push_back(out);
    }

    // Need at least two ops for batching to pay off.  `another_node` stays
    // null when every candidate reads both operands from this Unpack; skip
    // that pattern instead of dereferencing null below (original crashed).
    if (!can_fuse || binary_ops.size() < 2 || another_node == nullptr) continue;

    // Group ops by their "other" operand: by node name for an Unpack output,
    // or by the Const's shape signature (constants only need matching shapes
    // to be stackable, since both operand sides get packed).
    for (Node* b : binary_ops) {
      for (int i = 0; i < std::min(b->num_inputs(), 2); ++i) {
        const Node* n = nullptr;
        b->input_node(i, &n);
        if (n != node) {
          string key;
          if (another_node->type_string() == "Unpack") {
            key = n->name();
          } else {
            TensorShapeProto s = n->def().attr().at("value").
                                 tensor().tensor_shape();
            for (int i = 0; i < s.dim_size(); i++) {
              key += std::to_string(s.dim(i).size()) + ",";
            }
          }
          binary_ops_m[key].push_back(b);
          break;
        }
      }
    }
    VLOG(2) << "FuseBinaryOpsAfterUnpack: found pattern";

    // Fuse binary_ops in each group.
    // For each group, do:
    // (1) add two Pack nodes to stack inputs on both sides respectively,
    // (2) add a new node to replace old binary_ops, and
    // (3) add a Unpack node to split result.
    DataType dtype = node->output_type(0);
    std::map<string, std::vector<Node*>>::iterator iter;
    iter = binary_ops_m.begin();
    while (iter != binary_ops_m.end()) {
      std::vector<Node *> *binary_ops_group = &(iter->second);
      // BUG FIX: the original advanced the iterator here but fell through and
      // fused the singleton group anyway, then advanced again at the bottom
      // of the loop — skipping the next group and, when the singleton was the
      // last map entry, incrementing an end() iterator (UB).  Skip singleton
      // groups entirely.
      if (binary_ops_group->size() < 2) {
        ++iter;
        continue;
      }
      // Order the group by the output slot feeding input 0 so the new
      // Unpack's outputs line up positionally with the replaced ops.
      std::sort(binary_ops_group->begin(), binary_ops_group->end(),
                [](Node *a, Node *b) {
                  const Edge *e = nullptr;
                  a->input_edge(0, &e);
                  int a_src_output = e->src_output();
                  b->input_edge(0, &e);
                  int b_src_output = e->src_output();
                  return a_src_output < b_src_output;
                });
      // inputs[0] / inputs[1] collect the edges feeding input slot 0 / 1 of
      // every op in the (sorted) group.
      std::vector<const Edge *> inputs[2];
      VLOG(2) << "the following ops are fused: ";
      for (Node *b : *binary_ops_group) {
        for (int i = 0; i < 2; ++i) {
          VLOG(2) << b->name();

          const Edge *e = nullptr;
          b->input_edge(i, &e);
          inputs[e->dst_input()].push_back(e);
        }
      }

      // (1) Add two Pack nodes to stack the per-op operands on each side.
      Node *packs[2];
      string pack_names[2];
      string prefix = "GemmOptimizer/FuseBinaryOpsAfterUnpack/" +
                      std::to_string(count++);
      pack_names[0] = prefix + "/Pack_0";
      pack_names[1] = prefix + "/Pack_1";
      Status status;
      for (int i = 0; i < 2; i++) {
        std::vector<NodeDefBuilder::NodeOut> pack_inputs;
        for (const Edge *e : inputs[i]) {
          string s = e->src()->name();
          pack_inputs.emplace_back(s, e->src_output(), dtype);
        }
        NodeDefBuilder pack_builder(pack_names[i], "Pack");
        pack_builder.Input(pack_inputs);
        NodeDef pack_node;
        status =
            pack_builder
                .Attr("N", (int) inputs[i].size())
                .Attr("T", dtype)
                .Attr("axis", 0)
                .Finalize(&pack_node);
        if (!status.ok()) {
          LOG(ERROR) << "Pack node construction failed with" << status;
          return false;
        }
        // Place the Pack where its first input lives.
        pack_node.set_device(inputs[i][0]->src()->def().device());
        packs[i] = graph->AddNode(pack_node, &status);
        if (!status.ok()) {
          LOG(ERROR) << "Adding node failed " << status;
          return false;
        }
        packs[i]->set_assigned_device_name(
            inputs[i][0]->src()->assigned_device_name());
        for (unsigned int j = 0; j < inputs[i].size(); j++) {
          graph->AddEdge(inputs[i][j]->src(), inputs[i][j]->src_output(),
                         packs[i], j);
        }
      }

      // (2) Add the single batched op replacing the whole group.
      std::vector<NodeDefBuilder::NodeOut> binary_op_inputs;
      binary_op_inputs.emplace_back(pack_names[0], 0, dtype);
      binary_op_inputs.emplace_back(pack_names[1], 0, dtype);
      // Extra inputs are shared by every op in the group, so forward them
      // directly (unstacked).
      if (!common_extra_input.empty()) {
        for (size_t i = 0; i < common_extra_input.size(); ++i) {
          const Node *extra_node = common_extra_input[i].first;
          int extra_src_idx = common_extra_input[i].second;
          binary_op_inputs.emplace_back(extra_node->name(), extra_src_idx, extra_node->output_type(extra_src_idx));
        }
      }
      string binary_op_name = prefix;
      string type = (*binary_ops_group)[0]->type_string();
      string new_type;
      if (type == "MatMul" ||
          type == "BatchMatMul" ||
          type == "BatchMatMulV2") {
        binary_op_name += "/BatchMatMulV2";
        new_type = "BatchMatMulV2";
      } else if (type == "IndicatorMatMul") {
        binary_op_name += "/ParallelIndicatorMatMul";
        new_type = "ParallelIndicatorMatMul";
      } else {
        // Element-wise binary ops keep their type; packing adds a leading
        // batch dimension on both sides.
        binary_op_name += "/" + type;
        new_type = type;
      }
      NodeDefBuilder binary_op_builder(binary_op_name, new_type);
      for (size_t i = 0; i < binary_op_inputs.size(); ++i) {
        binary_op_builder.Input(binary_op_inputs[i]);
      }
      NodeDef binary_op_node;
      // Carry the transpose/adjoint attributes over from the original ops
      // (MatMul spells them transpose_*, the batch variants adj_*).
      bool transpose_a = false;
      bool transpose_b = false;
      if (type == "MatMul") {
        transpose_a = (*binary_ops_group)[0]->def().attr().at("transpose_a").b();
        transpose_b = (*binary_ops_group)[0]->def().attr().at("transpose_b").b();
      } else if (type == "BatchMatMul" || type == "BatchMatMulV2" || type == "IndicatorMatMul") {
        transpose_a = (*binary_ops_group)[0]->def().attr().at("adj_x").b();
        transpose_b = (*binary_ops_group)[0]->def().attr().at("adj_y").b();
      }

      if (type == "MatMul" ||
          type == "BatchMatMul" ||
          type == "BatchMatMulV2") {
        status =
            binary_op_builder
                .Attr("adj_x", transpose_a)
                .Attr("adj_y", transpose_b)
                .Attr("T", dtype)
                .Finalize(&binary_op_node);

      } else if (type == "IndicatorMatMul") {
        status =
            binary_op_builder
                .Attr("adj_x", transpose_a)
                .Attr("adj_y", transpose_b)
                .Attr("parallel_num", (int) (*binary_ops_group).size())
                .Attr("T", dtype)
                .Finalize(&binary_op_node);
      } else {
        status =
            binary_op_builder
                .Attr("T", dtype)
                .Finalize(&binary_op_node);
      }
      if (!status.ok()) {
        LOG(ERROR) << "BatchMatMulV2 node construction failed with" << status;
        return false;
      }
      binary_op_node.set_device((*binary_ops_group)[0]->def().device());
      Node* binary_op = graph->AddNode(binary_op_node, &status);
      if (!status.ok()) {
        LOG(ERROR) << "Adding node failed " << status;
        return false;
      }
      binary_op->set_assigned_device_name((*binary_ops_group)[0]->
                                          assigned_device_name());
      graph->AddEdge(packs[0], 0, binary_op, 0);
      graph->AddEdge(packs[1], 0, binary_op, 1);

      // (3) Add an Unpack node to split the batched result back into slices.
      string unpack_name = prefix + "/Unpack" ;
      NodeDefBuilder::NodeOut unpack_input(binary_op_name, 0, dtype);
      NodeDefBuilder unpack_builder(unpack_name, "Unpack");
      unpack_builder.Input(unpack_input);
      NodeDef unpack_node;
      status =
          unpack_builder
              .Attr("num", (int)(*binary_ops_group).size())
              .Attr("T", dtype)
              .Attr("axis", 0)
              .Finalize(&unpack_node);
      if (!status.ok()) {
        LOG(ERROR) << "Unpack node construction failed with" << status;
        return false;
      }
      unpack_node.set_device((*binary_ops_group)[0]->def().device());
      Node* unpack = graph->AddNode(unpack_node, &status);
      if (!status.ok()) {
        LOG(ERROR) << "Adding node failed " << status;
        return false;
      }
      unpack->set_assigned_device_name((*binary_ops_group)[0]->
                                       assigned_device_name());
      graph->AddEdge(binary_op, 0, unpack, 0);

      // Rewire each original op's consumers to the matching new Unpack slot
      // (index == position in the sorted group), then remove the op.  Edges
      // are snapshotted first because RemoveNode invalidates out_edges().
      int index = 0;
      for (Node* b : *binary_ops_group) {
        std::vector<Node*> dst_nodes;
        std::vector<int> dst_inputs;
        for (const Edge* e : b->out_edges()) {
          dst_nodes.push_back(e->dst());
          dst_inputs.push_back(e->dst_input());
        }
        for (unsigned int i = 0; i < dst_nodes.size(); i++) {
          graph->UpdateEdge(unpack, index, dst_nodes[i], dst_inputs[i]);
        }
        graph->RemoveNode(b);
        index++;
      }

      changed = true;
      iter++;
    }
  }

  return changed;
}