in tensorflow/tensorflow/compiler/tf2tensorrt/segment/segment.cc [734:1132]
// Partitions `tf_graph` into segments (subgraphs) of TF-TRT candidate nodes and
// appends the resulting segments to `*segments`.
//
// Args:
//   tf_graph: the graph to segment.
//   graph_properties: forwarded to the shape/batch-size helper checks. Must be
//     non-null when `options.allow_dynamic_non_batch_dim` is false.
//   candidate_fn: called for every node that passed the device / exclusion /
//     shape checks, together with the set of target node names computed from
//     `options.convert_ranges`. A non-OK status excludes the node and its
//     error message is logged as the reason.
//   input_candidate_fn / output_candidate_fn: predicates evaluated on data
//     edges that cross a segment boundary; a segment node with an ineligible
//     boundary edge is removed from the segment (see Step 2).
//   options: segmenter options (batch mode, exclusion list, minimum segment
//     size, ...).
//   segments: output vector; accepted segments are appended to it.
//
// Returns:
//   errors::Internal if the option combination is invalid or if the batch
//   size / device name bookkeeping becomes inconsistent while merging; any
//   error produced while merging two clusters; Status::OK() otherwise.
Status SegmentGraph(
    const Graph* tf_graph, const grappler::GraphProperties* graph_properties,
    const std::function<Status(const Node*, const std::unordered_set<string> &target_nodes)>& candidate_fn,
    const std::function<bool(const Edge*)>& input_candidate_fn,
    const std::function<bool(const Edge*)>& output_candidate_fn,
    const SegmentOptions& options, SegmentVector* segments) {
  if (!options.use_implicit_batch && !options.allow_dynamic_non_batch_dim) {
    return errors::Internal(
        "Explicit batch mode should allow dynamic non-batch dimensions");
  }

  if (options.use_implicit_batch && !options.maximum_batch_size.has_value()) {
    return errors::Internal("Implicit batch mode requires maximum_batch_size");
  }

  if (!options.allow_dynamic_non_batch_dim && !graph_properties) {
    return errors::Internal(
        "Need graph propertities to disallow dynamic non-batch dimensions");
  }

  // Steps:
  // 1. run the segmentation algorithm to find all the segments, which uses
  //    candidate_fn to determine the candidates segment nodes;
  // 2. for each segments, remove the nodes that are inputs/outputs of the
  //    segment but are not eligible, using input/output_candidate_fn to
  //    determine the eligibilities;
  // 3. convert the segment into expected return format and return the result.

  // Collect the names of the nodes selected by the conversion ranges; they are
  // handed to candidate_fn for every node.
  std::unordered_set<string> target_nodes;
  SearchNodesWithRanges(tf_graph, options.convert_ranges, target_nodes);

  // --------------------------------- Step 1 ---------------------------------
  auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));

  // Number of excluded nodes per op type, used for the report logged below.
  std::map<string, int> unsupported_ops_map = {};

  // Getting the operations denylisted for conversion
  string tftrt_op_denylist_str;
  TF_CHECK_OK(
      ReadStringFromEnvVar("TF_TRT_OP_DENYLIST", "", &tftrt_op_denylist_str));

  auto tftrt_op_denylist = gtl::FlatSet<string>{};  // non-absl ok

  for (const auto& x : str_util::Split(tftrt_op_denylist_str, ",")) {
    tftrt_op_denylist.insert(x);
  }

  // Use a union-find to collect the nodes that belong to the same
  // segment. A node value of nullptr indicates that the node is not a candidate
  // for TRT.
  //
  // Invariant: node_segments[i] describes the node with id i. The code below
  // indexes this vector with node->id(), so exactly one entry must be appended
  // per node id, including ids that have no node.
  std::vector<UnionFind<SimpleNode*>> node_segments;
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    SimpleNode* node = graph->FindNodeId(i);

    if (!node) {
      VLOG(3) << "Node " << i << " doesn't exist in the graph";
      // Append a non-candidate placeholder (null node) so that the entries
      // that follow stay aligned with their node ids.
      DeviceNameUtils::ParsedName no_device;
      AddSegmentForNode(graph_properties, &node_segments, node, no_device,
                        options.use_implicit_batch);
      continue;
    }

    // Logs why the node is rejected, counts it in the report and marks it as a
    // non-candidate by resetting `node` to nullptr.
    auto exclude_node = [&](absl::string_view reason) {
      LOG(INFO) << "Not a TF-TRT candidate, "
                << "(Op type: " << node->tf_node()->type_string() << "), "
                << "(Op name: " << node->name() << "), "
                << "(Reason: " << reason << ")";
      unsupported_ops_map[node->tf_node()->type_string()]++;
      node = nullptr;
    };

    absl::optional<DeviceNameUtils::ParsedName> device_name =
        GetDeviceParsedName(node->tf_node());
    // GetDeviceParsedName capitalizes the device type.
    if (!device_name.has_value() ||
        (device_name->has_type && device_name->type != "GPU")) {
      exclude_node("node can't be placed on GPU");
    } else if (options.exclude_node_list.count(node->name()) != 0) {
      exclude_node("excluded by segmenter option");
    } else if (options.use_implicit_batch &&
               !OperationCanBeTranslatedToImplicitBatch(graph_properties,
                                                        node->tf_node())) {
      exclude_node(
          "implicit batch mode requires input shape with at least two "
          "dimensions");
    } else if (!options.allow_dynamic_non_batch_dim &&
               OperationHasDynamicNonBatchDimension(graph_properties,
                                                    node->tf_node())) {
      exclude_node("dynamic non-batch dimensions not allowed");
    } else {
      const Status status = candidate_fn(node->tf_node(), target_nodes);
      if (!status.ok()) {
        exclude_node(status.error_message());
      } else if (tftrt_op_denylist.count(node->tf_node()->type_string())) {
        // WARNING verbosity since the user explicitly requests this behavior.
        LOG_WARNING_WITH_PREFIX
            << "Denylisted as TF-TRT candidate, "
            << "(Op type: " << node->tf_node()->type_string() << "), "
            << "(Op name: " << node->name() << ")";
        exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST");
      } else {
        VLOG(2) << "Accepted as a TF-TRT candidate, "
                << "(Op type: " << node->tf_node()->type_string() << "), "
                << "(Op name: " << node->name() << ")";
      }
    }

    // The device name may be unavailable (that is one of the exclusion reasons
    // above), so the optional must not be dereferenced unconditionally. Fall
    // back to a default-constructed name; the node is a non-candidate then.
    DeviceNameUtils::ParsedName segment_device;
    if (device_name.has_value()) {
      segment_device = *device_name;
    }
    AddSegmentForNode(graph_properties, &node_segments, node, segment_device,
                      options.use_implicit_batch);
  }

  // Build and log a report of the excluded ops, most frequent op type first.
  string unsupported_op_report =
      StrCat("\n", string(80, '#'), "\n",
             "TensorRT unsupported/unconverted OP Report:");
  int total_unconverted_ops{0};

  // Copy key-value pairs from unsupported_ops_map to a vector of pairs so that
  // they can be sorted by count.
  std::vector<std::pair<std::string, int>> sorted_unsupported_ops;
  sorted_unsupported_ops.reserve(unsupported_ops_map.size());
  for (const auto& entry : unsupported_ops_map) {
    sorted_unsupported_ops.push_back(entry);
  }

  // Sort in descending order using the number of uses of the OP that are not
  // converted.
  std::sort(sorted_unsupported_ops.begin(), sorted_unsupported_ops.end(),
            [](const std::pair<std::string, int>& lhs,
               const std::pair<std::string, int>& rhs) -> bool {
              return lhs.second > rhs.second;
            });

  for (const auto& entry : sorted_unsupported_ops) {
    StrAppend(&unsupported_op_report, "\n\t- ", entry.first, " -> ",
              entry.second, "x");
    total_unconverted_ops += entry.second;
  }

  StrAppend(&unsupported_op_report, "\n", string(80, '-'),
            "\n\t - Total unconverted OPs: ", total_unconverted_ops,
            "\n\t - Total unconverted OP Types: ", unsupported_ops_map.size(),
            "\nFor more information see https://docs.nvidia.com/deeplearning",
            "/frameworks/tf-trt-user-guide/index.html#supported-ops.", "\n",
            string(80, '#'));
  LOG(INFO) << unsupported_op_report;

  // The segmentation algorithm below visits nodes in reverse topological order
  // and attempts to merge nodes along output edges. That means that subgraphs
  // grow from the output-side of the network towards the inputs.
  //
  // In general this is not guaranteed to produce a globally optimal
  // segmentation. For example, consider graph with node {A, B, C, D} and edges
  // {A->B, A->C, B->D, C->D}, where A, B, D are trt compatible but C is not, so
  // in theory we can choose to contract either A, B or B, D but not both, but
  // here it always choose to contract B, D.
  //
  // In the future if we have a measure of how beneficial it is to include a
  // given node in a TRT subgraph then we can revisit this algorithm to take
  // advantage of that information.
  std::vector<const SimpleNode*> order;
  order.reserve(graph->num_node_ids());
  // Nodes are recorded by the second (non-`enter`) callback of the DFS that
  // starts at the source node.
  StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
            /*enter=*/nullptr, [&order](const SimpleNode* n) {
              order.push_back(n);
              return true;
            });
  for (const SimpleNode* node : order) {
    // All output nodes of 'node' have been visited.
    VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
    // 'node' must be a TRT candidate.
    if (node_segments[node->id()].Value() == nullptr) {
      VLOG(3) << "... not a TRT candidate";
      continue;
    }
    // Contract output edges to combine 'node' with output nodes. Repeat this
    // step until no output edges can be further contracted. This is because
    // contracting an output edge may unblock new edges for contracting.
    ClusterBatchSize expected_batch_size =
        node_segments[node->id()].Property().BatchSize();
    DeviceNameUtils::ParsedName expected_device_name =
        node_segments[node->id()].Property().DeviceName();
    VLOG(3) << "batch size " << expected_batch_size;
    while (true) {
      std::set<const SimpleEdge*, SimpleEdgePtrCompare> contract_edges;
      // TODO(bixia): consider merging the loop to find the edges and the loop
      // to contract the edges.
      for (const SimpleEdge* out_edge : node->out_edges()) {
        VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
                << out_edge->dst()->id() << " <- " << node->id() << " )";
        if (out_edge->IsControlEdge()) {
          VLOG(3) << "... ... Control Edge, Skipping";
          continue;
        }
        UnionFind<SimpleNode*>* out_cluster =
            &node_segments[out_edge->dst()->id()];
        // Out node must be a TRT candidate.
        if (out_cluster->Value() == nullptr) {
          VLOG(3) << "... ... not a TRT candidate";
          continue;
        }
        // Out node must have compatible batch size.
        ClusterBatchSize out_batch_size = out_cluster->Property().BatchSize();
        ClusterBatchSize merged_batch_size = expected_batch_size;
        if (!merged_batch_size.MergeIfCompatible(out_batch_size)) {
          VLOG(3) << "... ... incompatible batch sizes "
                  << expected_batch_size.ToString() << " "
                  << out_batch_size.ToString();
          continue;
        }

        // Out node must have a compatible device name.
        const DeviceNameUtils::ParsedName& out_device_name =
            out_cluster->Property().DeviceName();
        absl::optional<DeviceNameUtils::ParsedName> merged_device_name =
            MergeIfCompatible(expected_device_name, out_device_name);
        if (!merged_device_name.has_value()) {
          VLOG(3) << "... ... incompatible device names "
                  << expected_device_name << " " << out_device_name;
          continue;
        }

        if (CanContractEdge(out_edge, graph)) {
          VLOG(3) << "... ... can contract. new batch size "
                  << merged_batch_size.ToString();
          contract_edges.insert(out_edge);
          // Only commit the merged properties once the edge is accepted.
          expected_batch_size = merged_batch_size;
          expected_device_name = *merged_device_name;
        } else {
          VLOG(3) << "... ... cannot contract, would form cycle";
        }
      }
      if (contract_edges.empty()) {
        break;
      }
      // Contract edges and collect the adjacent nodes into the same
      // segment/subgraph.
      while (!contract_edges.empty()) {
        const SimpleEdge* contract_edge = *contract_edges.begin();
        const SimpleNode* src = contract_edge->src();
        const SimpleNode* dst = contract_edge->dst();

        VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
                << src->id() << " <- " << dst->id() << ")";
        TF_RETURN_IF_ERROR(
            node_segments[src->id()].Merge(&node_segments[dst->id()]));

        // Contracting the edge leaves disconnected graph edges.
        // Remove these from the graph and from 'contract_edges' so we
        // don't visit them again.
        SimpleEdge* e = const_cast<SimpleEdge*>(contract_edge);
        std::vector<const SimpleEdge*> remove_edges;
        ContractEdge(e, graph.get(), &remove_edges);

        for (const SimpleEdge* r : remove_edges) {
          contract_edges.erase(r);
          graph->RemoveEdge(r);
        }
      }
      // Sanity check: the properties accumulated while selecting edges must
      // match the properties of the merged cluster.
      if (expected_batch_size !=
          node_segments[node->id()].Property().BatchSize()) {
        return errors::Internal(
            "expected batch size is not the same as the actual batch size");
      }
      if (!(expected_device_name ==
            node_segments[node->id()].Property().DeviceName())) {
        return errors::Internal(
            "expected device name is not the same as the actual device name");
      }
    }
  }

  // Collect the segments/subgraphs. Each subgraph is represented by a
  // set of the names of the nodes in that subgraph.

  // A map from the segment identifier (currently the name of the root node of
  // the segment tree) to the segment nodes set.
  std::map<string, Segment> sg_map;

  for (auto& u : node_segments) {
    if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
      sg_map[u.ParentValue()->name()].nodes.insert(u.Value()->tf_node());
    }
    // The entry whose parent value is itself carries the segment property.
    if ((u.Value() != nullptr) && (u.ParentValue() == u.Value())) {
      sg_map[u.Value()->name()].property = u.Property();
    }
  }

  // --------------------------------- Step 2 ---------------------------------
  // Remove ineligible input/output nodes.
  for (auto& itr : sg_map) {
    std::set<const Node*, NodePtrCompare>& segment_nodes = itr.second.nodes;
    VLOG(1) << "Segment original size: " << segment_nodes.size();
    while (true) {
      std::deque<const Node*> in_nodes_que, out_nodes_que;
      // Find an input node that is not eligible and add it to the queue.
      // Nodes that has no incoming edges should not be treated as "input",
      // as there are really no inputs to them. Similar for output nodes.
      for (auto node : segment_nodes) {
        bool added = false;
        for (const Edge* edge : node->in_edges()) {
          if (!edge->IsControlEdge() && !edge->src()->IsSource() &&
              !segment_nodes.count(edge->src())) {  // 'node' is an input node.
            if (!input_candidate_fn(edge)) {
              in_nodes_que.push_back(node);
              added = true;
              break;
            }
          }
        }
        if (added) continue;  // Only adding the node once to either queue.
        for (const Edge* edge : node->out_edges()) {
          if (!edge->dst()->IsSink() && !edge->IsControlEdge() &&
              !segment_nodes.count(edge->dst())) {  // 'node' is an output node.
            if (!output_candidate_fn(edge)) {
              out_nodes_que.push_back(node);
              break;
            }
          }
        }
      }
      if (in_nodes_que.empty() && out_nodes_que.empty()) {
        // No more ineligible input/output nodes.
        break;
      }
      // Now for each ineligible node, remove all of its inputs or outputs from
      // the subgraph.
      //
      // It can be proven that, if the original subgraph:
      // 1. is a DAG, and
      // 2. all paths between two nodes in the subgraph are all inside the
      //    subgraph
      // then after doing this operation the resulting subgraph will keep the
      // same properties 1 and 2.
      //
      // For simplicity we use heuristics: for input and const output nodes
      // remove all their inputs, and for non-const output nodes remove all
      // their outputs. In this way, for common cases the number of removed
      // nodes should be minimum.
      auto remove_nodes = [&segment_nodes](bool is_input_nodes,
                                           std::deque<const Node*>* que) {
        // Run a BFS on the queue to find all the input/output nodes.
        std::set<const Node*, NodePtrCompare> visited;
        // Nodes already reported, so that each removal is logged only once.
        std::set<const Node*, NodePtrCompare> logged(que->begin(), que->end());
        while (!que->empty()) {
          auto node = que->front();
          que->pop_front();
          if (!visited.insert(node).second) continue;
          segment_nodes.erase(node);
          for (auto in : (is_input_nodes || node->type_string() == "Const")
                             ? node->in_nodes()
                             : node->out_nodes()) {
            if (segment_nodes.count(in)) {
              que->push_back(in);
              if (VLOG_IS_ON(2)) {
                if (!logged.count(in)) {
                  VLOG(2) << "----> Need to remove node " << in->name()
                          << " because one of its "
                          << (is_input_nodes ? "output" : "input")
                          << " nodes in the graph was removed: "
                          << node->name();
                  logged.insert(in);
                }
              }
            }
          }
        }
      };
      remove_nodes(true, &in_nodes_que);
      remove_nodes(false, &out_nodes_que);
    }
    VLOG(1) << "Segment new size: " << segment_nodes.size();
  }

  // --------------------------------- Step 3 ---------------------------------
  // Convert the segments into the expected return format
  for (const auto& itr : sg_map) {
    const string& segment_root = itr.first;
    // Return format does not require set comparator.
    std::set<const Node*, NodePtrCompare> segment_nodes(
        itr.second.nodes.begin(), itr.second.nodes.end());
    if (VLOG_IS_ON(1) && !segment_nodes.empty()) {
      string s;
      for (auto node : segment_nodes) {
        StrAppend(&s, "\n[Op type: ", node->type_string(), "] ", node->name());
      }
      VLOG(1) << "Nodes in segment " << segments->size()
              << " with parent=" << segment_root << ":" << s;
    }

    // Nodes of these op types do not count towards the segment size.
    const int num_effective_nodes = std::count_if(
        segment_nodes.begin(), segment_nodes.end(), [](const Node* node) {
          static auto noops =
              new std::set<string>{"Identity", "Snapshot", "StopGradient"};
          return noops->count(node->type_string()) == 0;
        });

    // Don't use segments whose number of effective nodes is small.
    if (num_effective_nodes == 0 ||
        num_effective_nodes < options.minimum_segment_size) {
      LOG(INFO) << "Segment " << segments->size() << " has only "
                << num_effective_nodes << " effective nodes, dropping";
      continue;
    }
    segments->emplace_back(itr.second.property, segment_nodes);
  }

  return Status::OK();
}