Status CreateNeuronGraphDef()

in grappler/convert/convert_graph.cc [750:887]


// Rewrites `graph_def` into `new_graph_def`, fusing maximal supported
// subgraphs into Neuron ops.
//
// Pipeline: preprocess the GraphDef, convert it to a Graph, decide per-node
// fuseability (supported op type, no_fuse/force_fuse overrides, optional
// constant-foldable nodes, fixed-shape attribute), exclude graph inputs and
// everything upstream of them, run TensorRT-style segmentation, optionally
// prune all but the largest segment, and finally stamp the surviving
// segments into Neuron ops.
//
// Args:
//   new_graph_def: output; receives the rewritten graph.
//   graph_def: input graph; not modified.
//   input_op_names / output_op_names: graph boundary op names. Output ops
//     must exist in the graph; otherwise an InvalidArgument error is
//     returned.
//   fuse_foldable_nodes: if true, constant-foldable nodes are treated as
//     supported.
//   minimum_segment_size: smallest segment the segmenter may emit.
//   prune_small_subgraphs_ratio: in (0, 1]; if the largest segment holds
//     more than this fraction of all segmented nodes, keep only it. Values
//     outside [0, 1] yield an Internal error. Only checked when more than
//     one segment was found (preserves historical behavior).
//   supported_op_types: op type names eligible for fusion.
//   no_fuse_ops / force_fuse_ops: per-node-name overrides; a non-empty
//     force_fuse_ops disables edge validation (manual segmentation).
//
// Returns: Status::OK() on success, or the first error encountered.
Status CreateNeuronGraphDef(GraphDef* new_graph_def, const GraphDef& graph_def,
                            const std::vector<std::string>& input_op_names,
                            const std::vector<std::string>& output_op_names,
                            const bool fuse_foldable_nodes,
                            const int minimum_segment_size,
                            const double prune_small_subgraphs_ratio,
                            const std::set<std::string>& supported_op_types,
                            const std::set<std::string>& no_fuse_ops,
                            const std::set<std::string>& force_fuse_ops) {
  // Segment the graph into subgraphs that can be converted to Neuron op
  tensorflow::tensorrt::segment::SegmentOptions segment_options;

  VLOG(1) << "Building Neuron Op\n";

  tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
                                             graph_def.library());

  GraphDef temp_graph_def;
  temp_graph_def.CopyFrom(graph_def);
  TF_RETURN_IF_ERROR(PreProcessingGraphDef(temp_graph_def));

  tensorflow::Graph graph(flib);

  // Propagate conversion failures to the caller instead of CHECK-crashing
  // the whole process (was TF_CHECK_OK; every other fallible call in this
  // function already uses TF_RETURN_IF_ERROR).
  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
      tensorflow::GraphConstructorOptions(), temp_graph_def, &graph));

  // Build output tensor names
  std::unordered_map<std::string, const Node*> op_name_to_node;
  for (const Node* node : graph.op_nodes()) {
    op_name_to_node[node->name()] = node;
  }
  std::vector<std::string> outputs;
  for (const auto& op_name : output_op_names) {
    // Look up with find(): operator[] would silently default-insert a
    // nullptr that the code below dereferences, crashing on an unknown
    // output name.
    auto found = op_name_to_node.find(op_name);
    if (found == op_name_to_node.end() || found->second == nullptr) {
      return errors::InvalidArgument("Cannot find output operator named \"",
                                     op_name, "\" in the graph");
    }
    const Node* node = found->second;
    int64 num_outputs = node->num_outputs();
    VLOG(1) << "Output " << op_name << " contains " << num_outputs
            << " outputs";
    for (int64 idx = 0; idx < num_outputs; ++idx) {
      outputs.push_back(op_name + ":" + std::to_string(idx));
    }
  }

  // Find "constant-foldable" nodes and claim them as supported
  std::unordered_set<std::string> foldable_nodes;
  if (fuse_foldable_nodes) {
    TF_RETURN_IF_ERROR(FindConstantFoldableNodes(&foldable_nodes, graph_def));
  }

  std::unordered_map<std::string, tensorflow::Node*> node_map;
  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));

  segment_options.minimum_segment_size = minimum_segment_size;

  // Setup exclude_node_list: a node lands on the exclude list unless it is
  // a supported op (and not vetoed), force-fused, or constant-foldable.
  for (Node* node : graph.nodes()) {
    bool is_source_or_sink = node->IsSink() || node->IsSource();
    bool is_supported = supported_op_types.count(node->type_string());
    bool no_fuse = no_fuse_ops.count(node->name());
    bool force_fuse = force_fuse_ops.count(node->name());
    bool is_foldable = foldable_nodes.count(node->name());
    bool supported_can_fuse = is_supported && !is_source_or_sink && !no_fuse;
    bool fuseable = supported_can_fuse || force_fuse || is_foldable;
    if (node->def().attr().count(kNeuronInFixedShapeContext)) {
      // Nodes outside a fixed-shape context cannot be compiled for Neuron.
      bool fixed_shape = node->def().attr().at(kNeuronInFixedShapeContext).b();
      VLOG(1) << "Node " << node->name() << " fixed_shape=" << fixed_shape;
      fuseable &= fixed_shape;
    }
    if (!fuseable) {
      segment_options.exclude_node_list.insert(node->name());
    }
  }

  // All input nodes to exclude list (const ref: avoid copying each name).
  for (const auto& node_name : input_op_names) {
    segment_options.exclude_node_list.insert(node_name);

    // Adding all the nodes before the input node to exclude list. Use
    // find() so probing does not insert a spurious nullptr entry.
    auto omit = node_map.find(node_name);
    if (omit != node_map.end() && omit->second != nullptr) {
      TF_RETURN_IF_ERROR(
          ExcludeInputNodes(omit->second, segment_options.exclude_node_list));
    }
  }

  tensorflow::tensorrt::segment::SegmentNodesVector segments;
  std::function<bool(const Edge*)> input_edge_validator;
  std::function<bool(const Edge*)> output_edge_validator;
  if (force_fuse_ops.size()) {
    // Don't exclude edges if manual segmentation is specified
    input_edge_validator = [](const Edge* edge) { return true; };
    output_edge_validator = [](const Edge* edge) { return true; };
  } else {
    input_edge_validator = EdgeValidator();
    output_edge_validator = OutputEdgeValidator();
  }

  TF_RETURN_IF_ERROR(tensorflow::tensorrt::segment::SegmentGraph(
      &graph, [](const Node* node) { return Status::OK(); },
      input_edge_validator, output_edge_validator,
      segment_options, &segments));
  if (segments.size() > 1) {
    VLOG(1) << "MULTIPLE Neuron candidate conversion: " << segments.size();
    if (prune_small_subgraphs_ratio < 0.0 ||
        prune_small_subgraphs_ratio > 1.0) {
      return errors::Internal("Found invalid prune_small_subgraphs_ratio ",
                              prune_small_subgraphs_ratio);
    }
    if (prune_small_subgraphs_ratio > 0.0) {
      size_t size_all_segments = 0;
      for (const auto& seg : segments) {
        size_all_segments += seg.size();
      }
      VLOG(1) << "Total size of all segments: " << size_all_segments;
      auto comp = [](const std::set<const Node*>& lhs,
                     const std::set<const Node*>& rhs) {
        return lhs.size() < rhs.size();
      };
      // Copy (not reference): segments.clear() below would invalidate a
      // reference into the vector.
      auto max_segment =
          *std::max_element(segments.begin(), segments.end(), comp);
      VLOG(1) << "Maximum segment size " << max_segment.size();
      if ((static_cast<double>(max_segment.size()) /
           static_cast<double>(size_all_segments)) >
          prune_small_subgraphs_ratio) {
        VLOG(1) << "Only keep maximum segment with size " << max_segment.size();
        segments.clear();
        segments.push_back(std::move(max_segment));
      }
    }
  }

  if (segments.size()) {
    PreProcessSegmentsForResources(graph, segments);
    TF_RETURN_IF_ERROR(ProcessSegments(graph, outputs, node_map, segments));
  }

  graph.ToGraphDef(new_graph_def);
  VLOG(2) << "new_graph_def: " << new_graph_def->DebugString();
  return tensorflow::Status::OK();
}