in grappler/convert/segment.cc [400:720]
// Partitions the nodes of `tf_graph` into segments (subgraphs) of candidate
// nodes and appends one node set per retained segment to `*segments`.
//
// Args:
//   tf_graph: the graph to segment; it is wrapped in a SimpleGraph, and the
//     edge contraction below is performed on that wrapper.
//   candidate_fn: called with every non-excluded node; a non-OK status marks
//     the node as not a candidate, and the status is logged as the reason.
//   input_candidate_fn: called for non-control in-edges of a segment node
//     whose source lies outside the segment (and is not the graph source
//     node); returning false makes the segment node ineligible.
//   output_candidate_fn: called for non-control out-edges of a segment node
//     whose destination lies outside the segment (and is not the graph sink
//     node); returning false makes the segment node ineligible.
//   options: `exclude_node_list` names nodes that are never candidates;
//     `minimum_segment_size` is the minimum number of effective (non
//     Identity/Snapshot/StopGradient) nodes a segment needs to be kept.
//   segments: output; retained segments are appended with emplace_back. The
//     vector is not cleared first.
//
// Returns Status::OK() on every path visible in this function.
Status SegmentGraph(const Graph* tf_graph,
                    const std::function<Status(const Node*)>& candidate_fn,
                    const std::function<bool(const Edge*)>& input_candidate_fn,
                    const std::function<bool(const Edge*)>& output_candidate_fn,
                    const SegmentOptions& options,
                    SegmentNodesVector* segments) {
  // Steps:
  // 1. run the segmentation algorithm to find all the segments, which uses
  //    candidate_fn to determine the candidate segment nodes;
  // 2. for each segment, remove the nodes that are inputs/outputs of the
  //    segment but are not eligible, using input/output_candidate_fn to
  //    determine the eligibility;
  // 3. convert the segments into the expected return format and return the
  //    result.

  // --------------------------------- Step 1 ---------------------------------
  auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));
  // Use a union-find to collect the nodes that belong to the same
  // segment. A node value of nullptr indicates that the node is not a candidate
  // for TRT. The vector is indexed by node id (see node_segments[node->id()]
  // below), so exactly one entry must be appended per id.
  std::unordered_set<string> unsupported_ops;
  int num_unsupported_ops = 0;
  std::vector<UnionFind<SimpleNode*>> node_segments;
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    SimpleNode* node = graph->FindNodeId(i);
    if (node == nullptr) {
      // NOTE(review): presumably FindNodeId() yields nullptr for ids that have
      // no node (e.g. nodes removed from the graph) -- confirm against
      // SimpleGraph. Record a non-candidate entry so that node_segments stays
      // aligned with node ids instead of dereferencing a null pointer.
      VLOG(3) << "Node id " << i << " has no node in the graph, skipping";
      node_segments.emplace_back(node);
      continue;
    }
    if (options.exclude_node_list.count(node->name()) != 0) {
      VLOG(1) << "Not a TF-TRT candidate, "
              << "(Op type: " << node->tf_node()->type_string() << "), "
              << "(Op name: " << node->name() << "), "
              << "(Reason: excluded by segmenter option)";
      unsupported_ops.emplace(node->tf_node()->type_string());
      num_unsupported_ops++;
      node = nullptr;
    } else {
      const Status status = candidate_fn(node->tf_node());
      if (!status.ok()) {
        VLOG(1) << "Not a TF-TRT candidate, "
                << "(Op type: " << node->tf_node()->type_string() << "), "
                << "(Op name: " << node->name() << "), "
                << "(Reason: " << status << ")";
        unsupported_ops.emplace(node->tf_node()->type_string());
        num_unsupported_ops++;
        node = nullptr;
      } else {
        // The closing parenthesis was missing from this message originally.
        VLOG(2) << "Accepted as a TF-TRT candidate, "
                << "(Op type: " << node->tf_node()->type_string() << "), "
                << "(Op name: " << node->name() << ")";
      }
    }
    node_segments.emplace_back(node);
  }
  // Summarize the rejected ops: the total count, the number of distinct op
  // types, and the list of those types.
  string msg = StrCat(
      "There are ", num_unsupported_ops, " ops of ", unsupported_ops.size(),
      " different types in the graph that", " are not compiled by neuron-cc: ");
  for (const auto& elem : unsupported_ops) {
    StrAppend(&msg, elem, ", ");
  }
  LOG(INFO) << msg << "(For more information see "
            << "https://awsdocs-neuron.readthedocs-hosted.com"
            << "/en/latest/release-notes/neuron-cc-ops/"
               "neuron-cc-ops-tensorflow.html).";

  // The segmentation algorithm below visits nodes in reverse topological order
  // and attempts to merge nodes along output edges. That means that subgraphs
  // grow from the output-side of the network towards the inputs.
  //
  // In general this is not guaranteed to produce a globally optimal
  // segmentation. For example, consider a graph with nodes {A, B, C, D} and
  // edges {A->B, A->C, B->D, C->D}, where A, B, D are TRT compatible but C is
  // not. In theory we can choose to contract either A, B or B, D but not both;
  // here it always chooses to contract B, D.
  //
  // In the future if we have a measure of how beneficial it is to include a
  // given node in a TRT subgraph then we can revisit this algorithm to take
  // advantage of that information.
  std::vector<const SimpleNode*> order;
  order.reserve(graph->num_node_ids());
  // Only the second callback is supplied; it appends each node to 'order'.
  StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
            /*enter=*/nullptr, [&order](const SimpleNode* n) {
              order.push_back(n);
              return true;
            });
  for (const SimpleNode* node : order) {
    // All output nodes of 'node' have been visited...
    VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
    // 'node' must be a TRT candidate...
    if (node_segments[node->id()].Value() == nullptr) {
      VLOG(3) << "... not a TRT candidate";
      continue;
    }
    // Contract output edges to combine 'node' with output
    // nodes. Iterate since combining two nodes may unblock other
    // combining.
    while (true) {
      std::set<const SimpleEdge*, SimpleEdgePtrCompare> contract_edges;
      for (const SimpleEdge* out_edge : node->out_edges()) {
        VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
                << out_edge->dst()->id() << " <- " << node->id() << " )";
        if (out_edge->IsControlEdge()) {
          VLOG(3) << "... ... Control Edge, Skipping";
          continue;
        }
        // Out node must be TRT candidate...
        if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
          VLOG(3) << "... ... not a TRT candidate";
          continue;
        }
        if (CanContractEdge(out_edge, graph)) {
          VLOG(3) << "... ... can contract";
          contract_edges.insert(out_edge);
        } else {
          VLOG(3) << "... ... cannot contract, would form cycle";
        }
      }
      if (contract_edges.empty()) {
        break;
      }
      // Contract edges and collect the adjacent nodes into the same
      // segment/subgraph.
      while (!contract_edges.empty()) {
        const SimpleEdge* contract_edge = *contract_edges.begin();
        const SimpleNode* src = contract_edge->src();
        const SimpleNode* dst = contract_edge->dst();
        // The closing parenthesis was missing from this message originally.
        VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
                << src->id() << " <- " << dst->id() << ")";
        node_segments[src->id()].Merge(&node_segments[dst->id()]);
        // Contracting the edge leaves disconnected graph edges.
        // Remove these from the graph and from 'contract_edges' so we
        // don't visit them again.
        SimpleEdge* e = const_cast<SimpleEdge*>(contract_edge);
        std::vector<const SimpleEdge*> remove_edges;
        ContractEdge(e, graph.get(), &remove_edges);
        for (const SimpleEdge* r : remove_edges) {
          contract_edges.erase(r);
          graph->RemoveEdge(r);
        }
      }
    }
  }

  // Collect the segments/subgraphs. Each subgraph is represented by the set
  // of `const Node*` that belong to it.
  // A map from the segment identifier (currently the name of the root node of
  // the segment tree) to the segment nodes set.
  std::map<string, std::set<const Node*, NodePtrCompare>> sg_map;
  // A map from the segment identifier (currently the name of the root node of
  // the segment tree) to the device names that the nodes in the segment are
  // assigned to. It is filled here, before Step 2 prunes nodes, and is not
  // updated afterwards.
  //
  // TODO(aaroey): nodes assigned to different devices should not be merged,
  // fix this.
  std::unordered_map<string, std::set<string>> device_maps;
  for (auto& u : node_segments) {
    if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
      sg_map[u.ParentValue()->name()].insert(u.Value()->tf_node());
      auto tf_node = u.Value()->tf_node();
      // has_assigned_device_name() is expected to return true
      // when called from optimization pass. However, since graph
      // is converted back and forth between graph and graphdef,
      // assigned devices demoted to requested devices. If the graph
      // is passed directly to this module, assigned devices will be set.
      if (tf_node->has_assigned_device_name()) {
        device_maps[u.ParentValue()->name()].insert(
            tf_node->assigned_device_name());
      } else if (!tf_node->requested_device().empty()) {
        device_maps[u.ParentValue()->name()].insert(
            tf_node->requested_device());
      } else {
        VLOG(2) << "Node " << tf_node->name()
                << " has no device assigned requested device is: "
                << tf_node->requested_device();
      }
    }
  }

  // --------------------------------- Step 2 ---------------------------------
  // Remove ineligible input/output nodes.
  for (auto& itr : sg_map) {
    std::set<const Node*, NodePtrCompare>& segment_nodes = itr.second;
    VLOG(1) << "Segment original size: " << segment_nodes.size();
    // Repeat until a full pass finds no ineligible boundary node: removing
    // nodes exposes new boundary nodes that must be checked in turn.
    while (true) {
      std::deque<const Node*> in_nodes_que, out_nodes_que;
      // Find an input node that is not eligible and add it to the queue.
      // Nodes that have no incoming edges should not be treated as "input",
      // as there are really no inputs to them. Similar for output nodes.
      for (auto node : segment_nodes) {
        bool added = false;
        for (const Edge* edge : node->in_edges()) {
          if (!edge->IsControlEdge() && !edge->src()->IsSource() &&
              !segment_nodes.count(edge->src())) {  // 'node' is an input node.
            if (!input_candidate_fn(edge)) {
              in_nodes_que.push_back(node);
              added = true;
              break;
            }
          }
        }
        if (added) continue;  // Only adding the node once to either queue.
        for (const Edge* edge : node->out_edges()) {
          if (!edge->dst()->IsSink() && !edge->IsControlEdge() &&
              !segment_nodes.count(edge->dst())) {  // 'node' is an output node.
            if (!output_candidate_fn(edge)) {
              out_nodes_que.push_back(node);
              break;
            }
          }
        }
      }
      if (in_nodes_que.empty() && out_nodes_que.empty()) {
        // No more ineligible input/output nodes.
        break;
      }
      // Now for each ineligible node, remove all of its inputs or outputs from
      // the subgraph.
      //
      // It can be proven that, if the original subgraph:
      // 1. is a DAG, and
      // 2. all paths between two nodes in the subgraph are all inside the
      //    subgraph
      // then after doing this operation the resulting subgraph will keep the
      // same properties 1 and 2.
      //
      // For simplicity we use heuristics: for input and const output nodes
      // remove all their inputs, and for non-const output nodes remove all
      // their outputs. In this way, for common cases the number of removed
      // nodes should be minimum.
      auto remove_nodes = [&segment_nodes](bool is_input_nodes,
                                           std::deque<const Node*>* que) {
        // Run a BFS on the queue to find all the input/output nodes.
        std::set<const Node*, NodePtrCompare> visited;
        // 'logged' only de-duplicates the VLOG(2) messages below.
        std::set<const Node*, NodePtrCompare> logged(que->begin(), que->end());
        while (!que->empty()) {
          auto node = que->front();
          que->pop_front();
          if (!visited.insert(node).second) continue;
          segment_nodes.erase(node);
          for (auto in : (is_input_nodes || node->type_string() == "Const")
                             ? node->in_nodes()
                             : node->out_nodes()) {
            if (segment_nodes.count(in)) {
              que->push_back(in);
              if (VLOG_IS_ON(2)) {
                if (!logged.count(in)) {
                  VLOG(2) << "----> Need to remove node " << in->name()
                          << " because one of its "
                          << (is_input_nodes ? "output" : "input")
                          << " nodes in the graph was removed: "
                          << node->name();
                  logged.insert(in);
                }
              }
            }
          }
        }
      };
      remove_nodes(true, &in_nodes_que);
      remove_nodes(false, &out_nodes_que);
    }
    VLOG(1) << "Segment new size: " << segment_nodes.size();
  }

  // --------------------------------- Step 3 ---------------------------------
  // Convert the segments into the expected return format
  for (const auto& itr : sg_map) {
    const string& segment_root = itr.first;
    // Return format does not require set comparator.
    std::set<const Node*> segment_nodes(itr.second.begin(), itr.second.end());
    if (VLOG_IS_ON(1) && !segment_nodes.empty()) {
      string s;
      for (auto node : segment_nodes) {
        StrAppend(&s, "\n[Op type: ", node->type_string(), "] ", node->name());
      }
      VLOG(1) << "Nodes in segment " << segments->size()
              << " with parent=" << segment_root << ":" << s;
    }
    // Identity, Snapshot and StopGradient nodes do not count towards the
    // minimum segment size.
    const int num_effective_nodes = std::count_if(
        segment_nodes.begin(), segment_nodes.end(), [](const Node* node) {
          static auto noops =
              new std::set<string>{"Identity", "Snapshot", "StopGradient"};
          return noops->count(node->type_string()) == 0;
        });
    // Don't use segments whose number of effective nodes is small.
    if (num_effective_nodes < options.minimum_segment_size) {
      VLOG(1) << "Segment " << segments->size() << " has only "
              << num_effective_nodes << " effective nodes, dropping";
      continue;
    }
    // Device information is only reported here; it does not affect which
    // segments are returned.
    const auto& dev_itr = device_maps.find(segment_root);
    if (dev_itr == device_maps.end() || dev_itr->second.empty()) {
      VLOG(1) << "No device assigned to segment " << segments->size();
    } else if (dev_itr->second.size() > 1) {
      string s = StrCat("Segment ", segments->size(),
                        " has multiple devices attached: ");
      for (const auto& dev : dev_itr->second) {
        StrAppend(&s, dev, ", ");
      }
      LOG(WARNING) << s;
    }
    segments->emplace_back(segment_nodes);
  }
  if (VLOG_IS_ON(1)) {
    for (const auto& d : device_maps) {
      string s("Segment ");
      StrAppend(&s, ": '", d.first, "' ");
      for (const auto& dd : d.second) {
        StrAppend(&s, dd, ", ");
      }
      VLOG(1) << "Devices " << s;
    }
  }
  return Status::OK();
}