Status SegmentGraph()

in grappler/convert/segment.cc [400:720]


// Partitions `tf_graph` into segments (connected groups of candidate nodes)
// and appends every surviving segment to `*segments`.
//
// Args:
//   tf_graph: the graph to segment; it is wrapped in a SimpleGraph that is
//     mutated locally while edges are contracted.
//   candidate_fn: called once per node; a non-OK status marks the node as not
//     a candidate, and the status is logged as the reason.
//   input_candidate_fn: called on non-control edges entering a segment from
//     outside; returning false makes the receiving node ineligible.
//   output_candidate_fn: called on non-control edges leaving a segment;
//     returning false makes the producing node ineligible.
//   options: `exclude_node_list` (node names that are never candidates) and
//     `minimum_segment_size` (segments with fewer effective nodes are dropped)
//     are read here.
//   segments: output; must be non-null, it is dereferenced unconditionally.
//
// Returns Status::OK() on every path visible in this function.
Status SegmentGraph(const Graph* tf_graph,
                    const std::function<Status(const Node*)>& candidate_fn,
                    const std::function<bool(const Edge*)>& input_candidate_fn,
                    const std::function<bool(const Edge*)>& output_candidate_fn,
                    const SegmentOptions& options,
                    SegmentNodesVector* segments) {
  // Steps:
  // 1. run the segmentation algorithm to find all the segments, which uses
  //    candidate_fn to determine the candidates segment nodes;
  // 2. for each segments, remove the nodes that are inputs/outputs of the
  //    segment but are not eligible, using input/output_candidate_fn to
  //    determine the eligibilities;
  // 3. convert the segment into expected return format and return the result.

  // --------------------------------- Step 1 ---------------------------------
  auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));
  // Use a union-find to collect the nodes that belong to the same
  // segment. A node value of nullptr indicates that the node is not a candidate
  // for TRT.
  //
  // node_segments is indexed by node id below (node_segments[node->id()]), so
  // exactly one entry must be appended per id, including ids that are
  // rejected.
  std::unordered_set<string> unsupported_ops;
  int num_unsupported_ops = 0;
  std::vector<UnionFind<SimpleNode*>> node_segments;
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    SimpleNode* node = graph->FindNodeId(i);
    if (node == nullptr) {
      // NOTE(review): presumably FindNodeId() returns nullptr for ids that no
      // longer have a node (as tensorflow::Graph does) -- confirm against
      // SimpleGraph. Such ids are recorded as non-candidates so that the
      // id-based indexing of node_segments stays aligned.
      VLOG(3) << "Node " << i << " doesn't exist in the graph";
    } else if (options.exclude_node_list.count(node->name()) != 0) {
      VLOG(1) << "Not a TF-TRT candidate, "
              << "(Op type: " << node->tf_node()->type_string() << "), "
              << "(Op name: " << node->name() << "), "
              << "(Reason: excluded by segmenter option)";
      unsupported_ops.emplace(node->tf_node()->type_string());
      num_unsupported_ops++;
      node = nullptr;
    } else {
      const Status status = candidate_fn(node->tf_node());
      if (!status.ok()) {
        VLOG(1) << "Not a TF-TRT candidate, "
                << "(Op type: " << node->tf_node()->type_string() << "), "
                << "(Op name: " << node->name() << "), "
                << "(Reason: " << status << ")";
        unsupported_ops.emplace(node->tf_node()->type_string());
        num_unsupported_ops++;
        node = nullptr;
      } else {
        VLOG(2) << "Accepted as a TF-TRT candidate, "
                << "(Op type: " << node->tf_node()->type_string() << "), "
                << "(Op name: " << node->name() << ")";
      }
    }
    node_segments.emplace_back(node);
  }
  // Summarize the rejected op types in a single INFO line.
  string msg = StrCat(
      "There are ", num_unsupported_ops, " ops of ", unsupported_ops.size(),
      " different types in the graph that", " are not compiled by neuron-cc: ");
  for (const auto& elem : unsupported_ops) {
    StrAppend(&msg, elem, ", ");
  }
  LOG(INFO) << msg << "(For more information see "
            << "https://awsdocs-neuron.readthedocs-hosted.com"
            << "/en/latest/release-notes/neuron-cc-ops/"
               "neuron-cc-ops-tensorflow.html).";

  // The segmentation algorithm below visits nodes in reverse topological order
  // and attempts to merge nodes along output edges. That means that subgraphs
  // grow from the output-side of the network towards the inputs.
  //
  // In general this is not guaranteed to produce a globally optimal
  // segmentation. For example, consider graph with node {A, B, C, D} and edges
  // {A->B, A->C, B->D, C->D}, where A, B, D are trt compatible but C is not, so
  // in theory we can choose to contract either A, B or B, D but not both, but
  // here it always choose to contract B, D.
  //
  // In the future if we have a measure of how beneficial it is to include a
  // given node in a TRT subgraph then we can revisit this algorithm to take
  // advantage of that information.
  std::vector<const SimpleNode*> order;
  order.reserve(graph->num_node_ids());
  // Nodes are appended by the second (non-"enter") callback of the DFS that
  // starts at the source node.
  StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
            /*enter=*/nullptr, [&order](const SimpleNode* n) {
              order.push_back(n);
              return true;
            });
  for (const SimpleNode* node : order) {
    // All output nodes of 'node' have been visited...
    VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
    // 'node' must be a TRT candidate...
    if (node_segments[node->id()].Value() == nullptr) {
      VLOG(3) << "... not a TRT candidate";
      continue;
    }
    // Contract output edges to combine 'node' with output
    // nodes. Iterate since combining two nodes may unblock other
    // combining.
    while (true) {
      std::set<const SimpleEdge*, SimpleEdgePtrCompare> contract_edges;
      for (const SimpleEdge* out_edge : node->out_edges()) {
        VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
                << out_edge->dst()->id() << " <- " << node->id() << " )";
        if (out_edge->IsControlEdge()) {
          VLOG(3) << "... ... Control Edge, Skipping";
          continue;
        }
        // Out node must be TRT candidate...
        if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
          VLOG(3) << "... ... not a TRT candidate";
          continue;
        }
        if (CanContractEdge(out_edge, graph)) {
          VLOG(3) << "... ... can contract";
          contract_edges.insert(out_edge);
        } else {
          VLOG(3) << "... ... cannot contract, would form cycle";
        }
      }
      if (contract_edges.empty()) {
        break;
      }
      // Contract edges and collect the adjacent nodes into the same
      // segment/subgraph.
      while (!contract_edges.empty()) {
        const SimpleEdge* contract_edge = *contract_edges.begin();
        const SimpleNode* src = contract_edge->src();
        const SimpleNode* dst = contract_edge->dst();

        VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
                << src->id() << " <- " << dst->id() << ")";
        node_segments[src->id()].Merge(&node_segments[dst->id()]);

        // Contracting the edge leaves disconnected graph edges.
        // Remove these from the graph and from 'contract_edges' so we
        // don't visit them again.
        SimpleEdge* e = const_cast<SimpleEdge*>(contract_edge);
        std::vector<const SimpleEdge*> remove_edges;
        ContractEdge(e, graph.get(), &remove_edges);

        // Erase from 'contract_edges' before removing from the graph; this
        // order is kept from the original code.
        for (const SimpleEdge* r : remove_edges) {
          contract_edges.erase(r);
          graph->RemoveEdge(r);
        }
      }
    }
  }

  // Collect the segments/subgraphs. Each subgraph is represented by a
  // set of the names of the nodes in that subgraph.

  // A map from the segment identifier (currently the name of the root node of
  // the segment tree) to the segment nodes set.
  std::map<string, std::set<const Node*, NodePtrCompare>> sg_map;

  // A map from the segment identifier (currently the name of the root node of
  // the segment tree) to the device names that the nodes in the segment are
  // assigned to.
  //
  // TODO(aaroey): nodes assigned to different devices should not be merged,
  // fix this.
  std::unordered_map<string, std::set<string>> device_maps;

  for (auto& u : node_segments) {
    if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
      // Look up the segment key and the TF node once instead of per use.
      const string& segment_key = u.ParentValue()->name();
      auto tf_node = u.Value()->tf_node();
      sg_map[segment_key].insert(tf_node);
      // has_assigned_device_name() is expected to return true
      // when called from optimization pass. However, since graph
      // is converted back and forth between graph and graphdef,
      // assigned devices demoted to requested devices. If the graph
      // is passed directly to this module, assigned devices will be set.
      if (tf_node->has_assigned_device_name()) {
        device_maps[segment_key].insert(tf_node->assigned_device_name());
      } else if (!tf_node->requested_device().empty()) {
        device_maps[segment_key].insert(tf_node->requested_device());
      } else {
        VLOG(2) << "Node " << tf_node->name()
                << " has no device assigned requested device is: "
                << tf_node->requested_device();
      }
    }
  }

  // --------------------------------- Step 2 ---------------------------------
  // Remove ineligible input/output nodes.
  for (auto& itr : sg_map) {
    std::set<const Node*, NodePtrCompare>& segment_nodes = itr.second;
    VLOG(1) << "Segment original size: " << segment_nodes.size();
    // Repeat until a full scan finds no ineligible boundary node: removing
    // nodes turns their neighbours into new boundary nodes that must be
    // checked as well.
    while (true) {
      std::deque<const Node*> in_nodes_que, out_nodes_que;
      // Find an input node that is not eligible and add it to the queue.
      // Nodes that has no incoming edges should not be treated as "input",
      // as there are really no inputs to them. Similar for output nodes.
      for (auto node : segment_nodes) {
        bool added = false;
        for (const Edge* edge : node->in_edges()) {
          if (!edge->IsControlEdge() && !edge->src()->IsSource() &&
              !segment_nodes.count(edge->src())) {  // 'node' is an input node.
            if (!input_candidate_fn(edge)) {
              in_nodes_que.push_back(node);
              added = true;
              break;
            }
          }
        }
        if (added) continue;  // Only adding the node once to either queue.
        for (const Edge* edge : node->out_edges()) {
          if (!edge->dst()->IsSink() && !edge->IsControlEdge() &&
              !segment_nodes.count(edge->dst())) {  // 'node' is an output node.
            if (!output_candidate_fn(edge)) {
              out_nodes_que.push_back(node);
              break;
            }
          }
        }
      }
      if (in_nodes_que.empty() && out_nodes_que.empty()) {
        // No more ineligible input/output nodes.
        break;
      }
      // Now for each ineligible node, remove all of its inputs or outputs from
      // the subgraph.
      //
      // It can be proven that, if the original subgraph:
      // 1. is a DAG, and
      // 2. all paths between two nodes in the subgraph are all inside the
      //    subgraph
      // then after doing this operation the resulting subgraph will keep the
      // same properties 1 and 2.
      //
      // For simplicity we use heuristics: for input and const output nodes
      // remove all their inputs, and for non-const output nodes remove all
      // their outputs. In this way, for common cases the number of removed
      // nodes should be minimum.
      auto remove_nodes = [&segment_nodes](bool is_input_nodes,
                                           std::deque<const Node*>* que) {
        // Run a BFS on the queue to find all the input/output nodes.
        std::set<const Node*, NodePtrCompare> visited;
        // 'logged' only de-duplicates the VLOG(2) output below; it does not
        // influence which nodes are removed.
        std::set<const Node*, NodePtrCompare> logged(que->begin(), que->end());
        while (!que->empty()) {
          auto node = que->front();
          que->pop_front();
          if (!visited.insert(node).second) continue;
          segment_nodes.erase(node);
          for (auto in : (is_input_nodes || node->type_string() == "Const")
                             ? node->in_nodes()
                             : node->out_nodes()) {
            if (segment_nodes.count(in)) {
              que->push_back(in);
              if (VLOG_IS_ON(2)) {
                if (!logged.count(in)) {
                  VLOG(2) << "----> Need to remove node " << in->name()
                          << " because one of its "
                          << (is_input_nodes ? "output" : "input")
                          << " nodes in the graph was removed: "
                          << node->name();
                  logged.insert(in);
                }
              }
            }
          }
        }
      };
      remove_nodes(true, &in_nodes_que);
      remove_nodes(false, &out_nodes_que);
    }
    VLOG(1) << "Segment new size: " << segment_nodes.size();
  }

  // --------------------------------- Step 3 ---------------------------------
  // Convert the segments into the expected return format
  for (const auto& itr : sg_map) {
    const string& segment_root = itr.first;
    // Return format does not require set comparator.
    std::set<const Node*> segment_nodes(itr.second.begin(), itr.second.end());
    if (VLOG_IS_ON(1) && !segment_nodes.empty()) {
      string s;
      for (auto node : segment_nodes) {
        StrAppend(&s, "\n[Op type: ", node->type_string(), "] ", node->name());
      }
      VLOG(1) << "Nodes in segment " << segments->size()
              << " with parent=" << segment_root << ":" << s;
    }

    // Nodes of these op types are not counted towards the segment size.
    const int num_effective_nodes = std::count_if(
        segment_nodes.begin(), segment_nodes.end(), [](const Node* node) {
          // Allocated once on first use and never freed in this function.
          static auto noops =
              new std::set<string>{"Identity", "Snapshot", "StopGradient"};
          return noops->count(node->type_string()) == 0;
        });

    // Don't use segments whose number of effective nodes is small.
    if (num_effective_nodes < options.minimum_segment_size) {
      VLOG(1) << "Segment " << segments->size() << " has only "
              << num_effective_nodes << " effective nodes, dropping";
      continue;
    }

    // Device information is only reported here; it does not affect whether
    // the segment is returned.
    const auto dev_itr = device_maps.find(segment_root);
    if (dev_itr == device_maps.end() || dev_itr->second.empty()) {
      VLOG(1) << "No device assigned to segment " << segments->size();
    } else if (dev_itr->second.size() > 1) {
      string s = StrCat("Segment ", segments->size(),
                        " has multiple devices attached: ");
      for (const auto& dev : dev_itr->second) {
        StrAppend(&s, dev, ", ");
      }
      LOG(WARNING) << s;
    }

    segments->emplace_back(segment_nodes);
  }
  if (VLOG_IS_ON(1)) {
    for (const auto& d : device_maps) {
      string s("Segment ");
      StrAppend(&s, ": '", d.first, "' ");
      for (const auto& dd : d.second) {
        StrAppend(&s, dd, ", ");
      }
      VLOG(1) << "Devices " << s;
    }
  }
  return Status::OK();
}