in grappler/convert/convert_graph.cc [750:887]
Status CreateNeuronGraphDef(GraphDef* new_graph_def, const GraphDef& graph_def,
                            const std::vector<std::string>& input_op_names,
                            const std::vector<std::string>& output_op_names,
                            const bool fuse_foldable_nodes,
                            const int minimum_segment_size,
                            const double prune_small_subgraphs_ratio,
                            const std::set<std::string>& supported_op_types,
                            const std::set<std::string>& no_fuse_ops,
                            const std::set<std::string>& force_fuse_ops) {
  // Segment the graph into subgraphs that can be converted to Neuron op
  tensorflow::tensorrt::segment::SegmentOptions segment_options;
  VLOG(1) << "Building Neuron Op\n";
  tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
                                             graph_def.library());
  GraphDef temp_graph_def;
  temp_graph_def.CopyFrom(graph_def);
  TF_RETURN_IF_ERROR(PreProcessingGraphDef(temp_graph_def));
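  // Convert the preprocessed GraphDef into a tensorflow::Graph so that the
  // segmenter below can operate on Node/Edge structures.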
  tensorflow::Graph graph(flib);
  TF_CHECK_OK(tensorflow::ConvertGraphDefToGraph(
      tensorflow::GraphConstructorOptions(), temp_graph_def, &graph));
  // Build output tensor names
  std::unordered_map<std::string, const Node*> op_name_to_node;
  for (const Node* node : graph.op_nodes()) {
    op_name_to_node[node->name()] = node;
  }
  std::vector<std::string> outputs;
  for (const auto& op_name : output_op_names) {
    const Node* node = op_name_to_node[op_name];
    int64 num_outputs = node->num_outputs();
    VLOG(1) << "Output " << op_name << " contains " << num_outputs
            << " outputs";
    for (int64 idx = 0; idx < num_outputs; ++idx) {
      outputs.push_back(op_name + ":" + std::to_string(idx));
    }
  }
  // Find "constant-foldable" nodes and claim them as supported
  std::unordered_set<std::string> foldable_nodes;
  if (fuse_foldable_nodes) {
    TF_RETURN_IF_ERROR(FindConstantFoldableNodes(&foldable_nodes, graph_def));
  }
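  // Map node names to Node pointers for quick lookup when excluding the
  // subgraphs that feed the input nodes below.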
  std::unordered_map<std::string, tensorflow::Node*> node_map;
  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
  segment_options.minimum_segment_size = minimum_segment_size;
  // Setup exclude_node_list
  for (Node* node : graph.nodes()) {
    bool is_source_or_sink = node->IsSink() || node->IsSource();
    bool is_supported = supported_op_types.count(node->type_string());
    bool no_fuse = no_fuse_ops.count(node->name());
    bool force_fuse = force_fuse_ops.count(node->name());
    bool is_foldable = foldable_nodes.count(node->name());
    bool supported_can_fuse = is_supported && !is_source_or_sink && !no_fuse;
    bool fuseable = supported_can_fuse || force_fuse || is_foldable;
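    // Nodes tagged as living outside of a fixed-shape context cannot be
    // fused, even if their op type is otherwise supported.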
    if (node->def().attr().count(kNeuronInFixedShapeContext)) {
      bool fixed_shape = node->def().attr().at(kNeuronInFixedShapeContext).b();
      VLOG(1) << "Node " << node->name() << " fixed_shape=" << fixed_shape;
      fuseable &= fixed_shape;
    }
    if (!fuseable) {
      segment_options.exclude_node_list.insert(node->name());
    }
  }
  // Add all input nodes to the exclude list
  for (auto node_name : input_op_names) {
    segment_options.exclude_node_list.insert(node_name);
    // Also add every node upstream of the input node to the exclude list.
    tensorflow::Node* omit_node = node_map[node_name];
    if (omit_node) {
      TF_RETURN_IF_ERROR(
          ExcludeInputNodes(omit_node, segment_options.exclude_node_list));
    }
  }
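  // Run the segmenter to collect candidate subgraphs; when manual
  // segmentation is requested via force_fuse_ops, accept every edge instead
  // of filtering through the default validators.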
  tensorflow::tensorrt::segment::SegmentNodesVector segments;
  std::function<bool(const Edge*)> input_edge_validator;
  std::function<bool(const Edge*)> output_edge_validator;
  if (force_fuse_ops.size()) {
    // Don't exclude edges if manual segmentation is specified
    input_edge_validator = [](const Edge* edge) { return true; };
    output_edge_validator = [](const Edge* edge) { return true; };
  } else {
    input_edge_validator = EdgeValidator();
    output_edge_validator = OutputEdgeValidator();
  }
  TF_RETURN_IF_ERROR(tensorflow::tensorrt::segment::SegmentGraph(
      &graph, [](const Node* node) { return Status::OK(); },
      input_edge_validator, output_edge_validator,
      segment_options, &segments));
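  // If several candidate segments were found, optionally prune everything but
  // the largest one.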
  if (segments.size() > 1) {
    VLOG(1) << "MULTIPLE Neuron candidate conversion: " << segments.size();
    if (prune_small_subgraphs_ratio < 0.0 ||
        prune_small_subgraphs_ratio > 1.0) {
      return errors::Internal("Found invalid prune_small_subgraphs_ratio ",
                              prune_small_subgraphs_ratio);
    }
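    // Keep only the largest segment when it accounts for more than
    // prune_small_subgraphs_ratio of all segmented nodes.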
    if (prune_small_subgraphs_ratio > 0.0) {
      size_t size_all_segments = 0;
      for (const auto& seg : segments) {
        size_all_segments += seg.size();
      }
      VLOG(1) << "Total size of all segments: " << size_all_segments;
      auto comp = [](const std::set<const Node*>& lhs,
                     const std::set<const Node*>& rhs) {
        return lhs.size() < rhs.size();
      };
      auto max_segment =
          *std::max_element(segments.begin(), segments.end(), comp);
      VLOG(1) << "Maximum segment size " << max_segment.size();
      if (((double)max_segment.size() / (double)size_all_segments) >
          prune_small_subgraphs_ratio) {
        VLOG(1) << "Only keep maximum segment with size " << max_segment.size();
        segments.clear();
        segments.push_back(max_segment);
      }
    }
  }
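  // Convert each remaining segment into a fused Neuron op within the graph.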
  if (segments.size()) {
    PreProcessSegmentsForResources(graph, segments);
    TF_RETURN_IF_ERROR(ProcessSegments(graph, outputs, node_map, segments));
  }
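  // Write the rewritten graph back out to the caller-provided GraphDef.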
  graph.ToGraphDef(new_graph_def);
  VLOG(2) << "new_graph_def: " << new_graph_def->DebugString();
  return tensorflow::Status::OK();
}
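// Illustrative only: a minimal sketch of how this entry point might be called
// from a conversion pass. The caller-side names and option values below are
// assumptions for illustration, not taken from this file.
//
//   GraphDef neuron_graph_def;
//   TF_RETURN_IF_ERROR(CreateNeuronGraphDef(
//       &neuron_graph_def, original_graph_def,
//       /*input_op_names=*/{"input0"}, /*output_op_names=*/{"output0"},
//       /*fuse_foldable_nodes=*/true, /*minimum_segment_size=*/2,
//       /*prune_small_subgraphs_ratio=*/0.0, supported_op_types,
//       /*no_fuse_ops=*/{}, /*force_fuse_ops=*/{}));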