Status SegmentGraph()

in tensorflow/tensorflow/compiler/tf2tensorrt/segment/segment.cc [734:1132]


// Partitions `tf_graph` into TF-TRT segments and appends them to `segments`.
//
// Steps:
// 1. Run the segmentation algorithm to find all the segments, using
//    `candidate_fn` (plus device / batch-mode / denylist checks) to decide
//    which nodes may belong to a segment, then greedily contracting edges
//    between compatible candidate nodes.
// 2. For each segment, repeatedly remove the input/output nodes that are not
//    eligible according to `input_candidate_fn` / `output_candidate_fn`.
// 3. Convert the segments into the return format, dropping segments whose
//    number of effective (non Identity/Snapshot/StopGradient) nodes is below
//    `options.minimum_segment_size`.
//
// Args:
//   tf_graph: graph to segment. Segmentation runs on a SimpleGraph built from
//     it.
//   graph_properties: shape information; must be non-null when
//     `options.allow_dynamic_non_batch_dim` is false.
//   candidate_fn: returns OK if a node can be converted; a non-OK status
//     excludes the node, and its error message is logged as the reason.
//   input_candidate_fn / output_candidate_fn: decide whether a data edge
//     entering / leaving a segment is acceptable. When an edge is rejected,
//     the segment node on it is removed, together with the nodes reachable
//     from it inside the segment.
//   options: segmentation options (batch mode, exclusions, minimum size, ...).
//   segments: output; accepted segments are appended to it.
//
// Returns an Internal error for inconsistent options or when merging clusters
// yields an unexpected batch size / device name. Errors from reading the
// TF_TRT_OP_DENYLIST environment variable and from cluster merging are
// propagated.
Status SegmentGraph(
    const Graph* tf_graph, const grappler::GraphProperties* graph_properties,
    const std::function<Status(const Node*,
                               const std::unordered_set<string>& target_nodes)>&
        candidate_fn,
    const std::function<bool(const Edge*)>& input_candidate_fn,
    const std::function<bool(const Edge*)>& output_candidate_fn,
    const SegmentOptions& options, SegmentVector* segments) {
  if (!options.use_implicit_batch && !options.allow_dynamic_non_batch_dim) {
    return errors::Internal(
        "Explicit batch mode should allow dynamic non-batch dimensions");
  }

  if (options.use_implicit_batch && !options.maximum_batch_size.has_value()) {
    return errors::Internal("Implicit batch mode requires maximum_batch_size");
  }

  if (!options.allow_dynamic_non_batch_dim && !graph_properties) {
    return errors::Internal(
        "Need graph properties to disallow dynamic non-batch dimensions");
  }

  // Names of the nodes selected by options.convert_ranges; forwarded to
  // candidate_fn.
  std::unordered_set<string> target_nodes;
  SearchNodesWithRanges(tf_graph, options.convert_ranges, target_nodes);

  // --------------------------------- Step 1 ---------------------------------
  auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));

  // Number of excluded nodes per op type, reported after classification.
  std::map<string, int> unsupported_ops_map;

  // Op types the user explicitly asked not to convert, given as a
  // comma-separated list in the TF_TRT_OP_DENYLIST environment variable.
  string tftrt_op_denylist_str;
  TF_RETURN_IF_ERROR(
      ReadStringFromEnvVar("TF_TRT_OP_DENYLIST", "", &tftrt_op_denylist_str));

  auto tftrt_op_denylist = gtl::FlatSet<string>{};  // non-absl ok
  for (const auto& op_type : str_util::Split(tftrt_op_denylist_str, ",")) {
    tftrt_op_denylist.insert(op_type);
  }

  // Use a union-find to collect the nodes that belong to the same segment.
  // The vector is indexed by node id (SimpleNode::id()); a node value of
  // nullptr indicates that the node is not a candidate for TRT.
  std::vector<UnionFind<SimpleNode*>> node_segments;
  node_segments.reserve(graph->num_node_ids());
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    SimpleNode* node = graph->FindNodeId(i);
    if (!node) {
      VLOG(3) << "Node " << i << " doesn't exist in the graph";
      // Keep node_segments aligned with node ids: without this placeholder
      // every later entry would be shifted and node_segments[node->id()]
      // would refer to the wrong (or a nonexistent) cluster.
      AddSegmentForNode(graph_properties, &node_segments, /*node=*/nullptr,
                        DeviceNameUtils::ParsedName(),
                        options.use_implicit_batch);
      continue;
    }
    // Logs the reason, counts the op type and marks the node as a
    // non-candidate by clearing `node`.
    auto exclude_node = [&](absl::string_view reason) {
      LOG(INFO) << "Not a TF-TRT candidate, "
                << "(Op type: " << node->tf_node()->type_string() << "), "
                << "(Op name: " << node->name() << "), "
                << "(Reason: " << reason << ")";
      unsupported_ops_map[node->tf_node()->type_string()]++;
      node = nullptr;
    };
    absl::optional<DeviceNameUtils::ParsedName> device_name =
        GetDeviceParsedName(node->tf_node());
    // GetDeviceParseName capitalizes the device type.
    if (!device_name.has_value() ||
        (device_name->has_type && device_name->type != "GPU")) {
      exclude_node("node can't be placed on GPU");
    } else if (options.exclude_node_list.count(node->name()) != 0) {
      exclude_node("excluded by segmenter option");
    } else if (options.use_implicit_batch &&
               !OperationCanBeTranslatedToImplicitBatch(graph_properties,
                                                        node->tf_node())) {
      exclude_node(
          "implicit batch mode requires input shape with at least two "
          "dimensions");
    } else if (!options.allow_dynamic_non_batch_dim &&
               OperationHasDynamicNonBatchDimension(graph_properties,
                                                    node->tf_node())) {
      exclude_node("dynamic non-batch dimensions not allowed");
    } else {
      const Status status = candidate_fn(node->tf_node(), target_nodes);
      if (!status.ok()) {
        exclude_node(status.error_message());
      } else if (tftrt_op_denylist.count(node->tf_node()->type_string())) {
        // WARNING verbosity since the user explicitly requests this behavior.
        LOG_WARNING_WITH_PREFIX
            << "Denylisted as TF-TRT candidate, "
            << "(Op type: " << node->tf_node()->type_string() << "), "
            << "(Op name: " << node->name() << ")";
        exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST");
      } else {
        VLOG(2) << "Accepted as a TF-TRT candidate, "
                << "(Op type: " << node->tf_node()->type_string() << "), "
                << "(Op name: " << node->name() << ")";
      }
    }
    // An unparsable device name leaves `device_name` empty; that node was
    // excluded above, so record a default name instead of dereferencing the
    // empty optional (undefined behavior).
    AddSegmentForNode(graph_properties, &node_segments, node,
                      device_name.value_or(DeviceNameUtils::ParsedName()),
                      options.use_implicit_batch);
  }

  // Report the excluded op types, most frequent first. stable_sort keeps ties
  // in the map's (alphabetical) order so the report is deterministic.
  std::vector<std::pair<string, int>> sorted_unsupported_ops(
      unsupported_ops_map.begin(), unsupported_ops_map.end());
  std::stable_sort(sorted_unsupported_ops.begin(), sorted_unsupported_ops.end(),
                   [](const std::pair<string, int>& a,
                      const std::pair<string, int>& b) {
                     return a.second > b.second;
                   });

  string unsupported_op_report =
      StrCat("\n", string(80, '#'), "\n",
             "TensorRT unsupported/unconverted OP Report:");
  int total_unconverted_ops = 0;
  for (const auto& op_and_count : sorted_unsupported_ops) {
    // Append in place; rebuilding the string on every iteration is quadratic.
    StrAppend(&unsupported_op_report, "\n\t- ", op_and_count.first, " -> ",
              op_and_count.second, "x");
    total_unconverted_ops += op_and_count.second;
  }
  StrAppend(&unsupported_op_report, "\n", string(80, '-'),
            "\n\t - Total unconverted OPs: ", total_unconverted_ops,
            "\n\t - Total unconverted OP Types: ", unsupported_ops_map.size(),
            "\nFor more information see https://docs.nvidia.com/deeplearning",
            "/frameworks/tf-trt-user-guide/index.html#supported-ops.", "\n",
            string(80, '#'));
  LOG(INFO) << unsupported_op_report;

  // The segmentation algorithm below visits nodes in reverse topological order
  // and attempts to merge nodes along output edges. That means that subgraphs
  // grow from the output-side of the network towards the inputs.
  //
  // In general this is not guaranteed to produce a globally optimal
  // segmentation. For example, consider graph with node {A, B, C, D} and edges
  // {A->B, A->C, B->D, C->D), where A, B, D are trt compatible but C is not, so
  // in theory we can choose to contract either A, B or B, D but not both, but
  // here it always choose to contract B, D.
  //
  // In the future if we have a measure of how beneficial it is to include a
  // given node in a TRT subgraph then we can revisit this algorithm to take
  // advantage of that information.
  std::vector<const SimpleNode*> order;
  order.reserve(graph->num_node_ids());
  StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
            /*enter=*/nullptr, [&order](const SimpleNode* n) {
              order.push_back(n);
              return true;
            });
  for (const SimpleNode* node : order) {
    // All output nodes of 'node' have been visited.
    VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
    // 'node' must be a TRT candidate.
    if (node_segments[node->id()].Value() == nullptr) {
      VLOG(3) << "... not a TRT candidate";
      continue;
    }
    // Contract output edges to combine 'node' with output nodes. Repeat this
    // step until no output edges can be further contracted. This is because
    // contracting an output edge may unblock new edges for contracting.
    ClusterBatchSize expected_batch_size =
        node_segments[node->id()].Property().BatchSize();
    DeviceNameUtils::ParsedName expected_device_name =
        node_segments[node->id()].Property().DeviceName();
    VLOG(3) << "batch size " << expected_batch_size;
    while (true) {
      std::set<const SimpleEdge*, SimpleEdgePtrCompare> contract_edges;
      // TODO(bixia): consider merging the loop to find the edges and the loop
      // to contract the edges.
      for (const SimpleEdge* out_edge : node->out_edges()) {
        VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
                << out_edge->dst()->id() << " <- " << node->id() << " )";
        if (out_edge->IsControlEdge()) {
          VLOG(3) << "... ... Control Edge, Skipping";
          continue;
        }
        UnionFind<SimpleNode*>* out_cluster =
            &node_segments[out_edge->dst()->id()];
        // Out node must be a TRT candidate.
        if (out_cluster->Value() == nullptr) {
          VLOG(3) << "... ... not a TRT candidate";
          continue;
        }
        // Out node must have compatible batch size.
        ClusterBatchSize out_batch_size = out_cluster->Property().BatchSize();
        ClusterBatchSize merged_batch_size = expected_batch_size;
        if (!merged_batch_size.MergeIfCompatible(out_batch_size)) {
          VLOG(3) << "... ... incompatible batch sizes "
                  << expected_batch_size.ToString() << " "
                  << out_batch_size.ToString();
          continue;
        }

        // Out node must have a compatible device name.
        const DeviceNameUtils::ParsedName& out_device_name =
            out_cluster->Property().DeviceName();
        absl::optional<DeviceNameUtils::ParsedName> merged_device_name =
            MergeIfCompatible(expected_device_name, out_device_name);
        if (!merged_device_name.has_value()) {
          VLOG(3) << "... ... incompatible device names "
                  << expected_device_name << " " << out_device_name;
          continue;
        }

        if (CanContractEdge(out_edge, graph)) {
          VLOG(3) << "... ... can contract. new batch size "
                  << merged_batch_size.ToString();
          contract_edges.insert(out_edge);
          expected_batch_size = merged_batch_size;
          expected_device_name = *merged_device_name;
        } else {
          VLOG(3) << "... ... cannot contract, would form cycle";
        }
      }
      if (contract_edges.empty()) {
        break;
      }
      // Contract edges and collect the adjacent nodes into the same
      // segment/subgraph.
      while (!contract_edges.empty()) {
        const SimpleEdge* contract_edge = *contract_edges.begin();
        const SimpleNode* src = contract_edge->src();
        const SimpleNode* dst = contract_edge->dst();

        VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
                << src->id() << " <- " << dst->id() << ")";
        TF_RETURN_IF_ERROR(
            node_segments[src->id()].Merge(&node_segments[dst->id()]));

        // Contracting the edge leaves disconnected graph edges.
        // Remove these from the graph and from 'contract_edges' so we
        // don't visit them again.
        SimpleEdge* e = const_cast<SimpleEdge*>(contract_edge);
        std::vector<const SimpleEdge*> remove_edges;
        ContractEdge(e, graph.get(), &remove_edges);

        for (const SimpleEdge* r : remove_edges) {
          contract_edges.erase(r);
          graph->RemoveEdge(r);
        }
      }
      // Sanity check: the merged cluster must carry the property predicted
      // while selecting the edges.
      if (expected_batch_size !=
          node_segments[node->id()].Property().BatchSize()) {
        return errors::Internal(
            "expected batch size is not the same as the actual batch size");
      }
      if (!(expected_device_name ==
            node_segments[node->id()].Property().DeviceName())) {
        return errors::Internal(
            "expected device name is not the same as the actual device name");
      }
    }
  }

  // Collect the segments/subgraphs. Each subgraph is represented by a
  // set of the nodes in that subgraph.

  // A map from the segment identifier (currently the name of the root node of
  // the segment tree) to the segment.
  std::map<string, Segment> sg_map;

  for (auto& u : node_segments) {
    if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
      sg_map[u.ParentValue()->name()].nodes.insert(u.Value()->tf_node());
    }
    if ((u.Value() != nullptr) && (u.ParentValue() == u.Value())) {
      sg_map[u.Value()->name()].property = u.Property();
    }
  }

  // --------------------------------- Step 2 ---------------------------------
  // Remove ineligible input/output nodes.
  for (auto& itr : sg_map) {
    std::set<const Node*, NodePtrCompare>& segment_nodes = itr.second.nodes;
    VLOG(1) << "Segment original size: " << segment_nodes.size();
    while (true) {
      std::deque<const Node*> in_nodes_que, out_nodes_que;
      // Find an input node that is not eligible and add it to the queue.
      // Nodes that has no incoming edges should not be treated as "input",
      // as there are really no inputs to them. Similar for output nodes.
      for (auto node : segment_nodes) {
        bool added = false;
        for (const Edge* edge : node->in_edges()) {
          if (!edge->IsControlEdge() && !edge->src()->IsSource() &&
              !segment_nodes.count(edge->src())) {  // 'node' is an input node.
            if (!input_candidate_fn(edge)) {
              in_nodes_que.push_back(node);
              added = true;
              break;
            }
          }
        }
        if (added) continue;  // Only adding the node once to either queue.
        for (const Edge* edge : node->out_edges()) {
          if (!edge->dst()->IsSink() && !edge->IsControlEdge() &&
              !segment_nodes.count(edge->dst())) {  // 'node' is an output node.
            if (!output_candidate_fn(edge)) {
              out_nodes_que.push_back(node);
              break;
            }
          }
        }
      }
      if (in_nodes_que.empty() && out_nodes_que.empty()) {
        // No more ineligible input/output nodes.
        break;
      }
      // Now for each ineligible node, remove all of its inputs or outputs from
      // the subgraph.
      //
      // It can be proven that, if the original subgraph:
      // 1. is a DAG, and
      // 2. all paths between two nodes in the subgraph are all inside the
      //    subgraph
      // then after doing this operation the resulting subgraph will keep the
      // same properties 1 and 2.
      //
      // For simplicity we use heuristics: for input and const output nodes
      // remove all their inputs, and for non-const output nodes remove all
      // their outputs. In this way, for common cases the number of removed
      // nodes should be minimum.
      auto remove_nodes = [&segment_nodes](bool is_input_nodes,
                                           std::deque<const Node*>* que) {
        // Run a BFS on the queue to find all the input/output nodes.
        std::set<const Node*, NodePtrCompare> visited;
        std::set<const Node*, NodePtrCompare> logged(que->begin(), que->end());
        while (!que->empty()) {
          auto node = que->front();
          que->pop_front();
          if (!visited.insert(node).second) continue;
          segment_nodes.erase(node);
          for (auto neighbor : (is_input_nodes || node->type_string() == "Const")
                                   ? node->in_nodes()
                                   : node->out_nodes()) {
            if (segment_nodes.count(neighbor)) {
              que->push_back(neighbor);
              if (VLOG_IS_ON(2)) {
                if (!logged.count(neighbor)) {
                  VLOG(2) << "----> Need to remove node " << neighbor->name()
                          << " because one of its "
                          << (is_input_nodes ? "output" : "input")
                          << " nodes in the graph was removed: "
                          << node->name();
                  logged.insert(neighbor);
                }
              }
            }
          }
        }
      };
      remove_nodes(true, &in_nodes_que);
      remove_nodes(false, &out_nodes_que);
    }
    VLOG(1) << "Segment new size: " << segment_nodes.size();
  }

  // --------------------------------- Step 3 ---------------------------------
  // Convert the segments into the expected return format.
  for (const auto& itr : sg_map) {
    const string& segment_root = itr.first;
    // Borrow the node set; it is copied once, into the output segment.
    const std::set<const Node*, NodePtrCompare>& segment_nodes =
        itr.second.nodes;
    if (VLOG_IS_ON(1) && !segment_nodes.empty()) {
      string s;
      for (auto node : segment_nodes) {
        StrAppend(&s, "\n[Op type: ", node->type_string(), "] ", node->name());
      }
      VLOG(1) << "Nodes in segment " << segments->size()
              << " with parent=" << segment_root << ":" << s;
    }

    // Identity-like ops do no real work, so they do not count toward the
    // segment size.
    const int num_effective_nodes = std::count_if(
        segment_nodes.begin(), segment_nodes.end(), [](const Node* node) {
          static auto noops =
              new std::set<string>{"Identity", "Snapshot", "StopGradient"};
          return noops->count(node->type_string()) == 0;
        });

    // Don't use segments whose number of effective nodes is small.
    if (num_effective_nodes == 0 ||
        num_effective_nodes < options.minimum_segment_size) {
      LOG(INFO) << "Segment " << segments->size() << " has only "
                << num_effective_nodes << " effective nodes, dropping";
      continue;
    }
    segments->emplace_back(itr.second.property, segment_nodes);
  }

  return Status::OK();
}