Status GraphProperties::InferStatically()

in tensorflow/tensorflow/core/grappler/costs/graph_properties.cc [2130:2361]


Status GraphProperties::InferStatically(bool assume_valid_feeds,
                                        bool aggressive_shape_inference,
                                        bool include_input_tensor_values,
                                        bool include_output_tensor_values) {
  FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                             item_.graph.library());
  std::unordered_map<string, std::unordered_set<int>> fed_ports;
  if (!assume_valid_feeds) {
    for (const auto& feed : item_.feed) {
      SafeTensorId tensor_id = ParseTensorName(feed.first);
      fed_ports[tensor_id.node()].insert(tensor_id.index());
    }
  }

  GraphView graph_view(&item_.graph);

  // List the resources and the nodes using them. Also collect the Merge nodes,
  // fed nodes, and primary inputs.
  std::unordered_map<const NodeDef*,
                     std::pair<std::unordered_set<const NodeDef*>,
                               std::unordered_set<const NodeDef*>>>
      resources;
  std::unordered_set<const NodeDef*> merge_nodes;
  std::unordered_set<const NodeDef*> fed_nodes;
  std::unordered_set<const NodeDef*> primary_inputs;
  int num_loops = 0;
  for (const NodeDef& node : item_.graph.node()) {
    if (IsQueue(node)) {
      for (const GraphView::InputPort& fanout :
           graph_view.GetFanouts(node, false)) {
        if (IsEnter(*fanout.node)) {
          const NodeDef& enter = *fanout.node;
          for (const GraphView::InputPort& fanout :
               graph_view.GetFanouts(enter, false)) {
            if (IsEnqueue(*fanout.node)) {
              resources[&node].first.insert(fanout.node);
            } else if (IsDequeue(*fanout.node)) {
              resources[&node].second.insert(fanout.node);
            }
          }
        } else {
          if (IsEnqueue(*fanout.node)) {
            resources[&node].first.insert(fanout.node);
          } else if (IsDequeue(*fanout.node)) {
            resources[&node].second.insert(fanout.node);
          }
        }
      }
    }
    if (!HasRegularInputs(node)) {
      primary_inputs.insert(&node);
    } else if (IsMerge(node)) {
      merge_nodes.insert(&node);
    } else if (IsNextIteration(node)) {
      ++num_loops;
    }
    if (fed_ports.find(node.name()) != fed_ports.end()) {
      fed_nodes.insert(&node);
    }
  }

  std::unordered_map<const NodeDef*, const NodeDef*> resource_handles;
  std::vector<TopologicalDependency> extra_deps;
  for (const auto& resource : resources) {
    for (const NodeDef* src : resource.second.first) {
      resource_handles[src] = resource.first;
      for (const NodeDef* dst : resource.second.second) {
        // Add control edges from enqueue to dequeue nodes to ensure they are
        // processed in their logical order.
        extra_deps.emplace_back(src, dst);
      }
    }
  }

  std::vector<const NodeDef*> topo_order;
  Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order);
  if (!s.ok()) {
    if (extra_deps.empty()) {
      return s;
    } else {
      // There is a loop between queues: we'll just use the graph topological
      // order. This will make the shape inference less precise but since this
      // isn't common it's not worth to figure out where to break the loop and
      // do a proper relaxation.
      TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
    }
  }

  // Heap-allocate SymbolicShapeRefiner in order to not consume a large amount
  // of stack space.
  auto refiner = absl::make_unique<SymbolicShapeRefiner>(
      graph_view, fed_ports, aggressive_shape_inference);

  TopoQueue new_shapes(topo_order);
  // Also seed the propagation of shapes in the fanout of primary inputs.
  for (const NodeDef* node : primary_inputs) {
    new_shapes.push(node);
  }
  // Also seed the propagation of shapes in the fanout of fed nodes.
  for (const NodeDef* node : fed_nodes) {
    new_shapes.push(node);
  }
  // Propagate shapes normally.
  TF_RETURN_IF_ERROR(
      PropagateShapes(refiner.get(), &new_shapes, resource_handles, num_loops));

  // Track shapes globally across the graph.
  std::unique_ptr<SymbolicShapeManager> shape_manager =
      absl::make_unique<SymbolicShapeManager>();
  bool found_error = false;
  for (const NodeDef& node : item_.graph.node()) {
    auto node_ctx = refiner->GetContext(&node);
    if (!node_ctx) {
      continue;
    }
    // Skip any information that comes from fed nodes.
    if (fed_ports.find(node.name()) != fed_ports.end()) {
      VLOG(2) << "Skipping feed node shape: " << node.name();
      continue;
    }
    for (const auto& merged_shapes : node_ctx->MergedShapes()) {
      if (!shape_manager->Merge(merged_shapes.first, merged_shapes.second)
               .ok()) {
        found_error = true;
        break;
      }
    }
    for (const auto& merged_dims : node_ctx->MergedDims()) {
      if (!shape_manager->Merge(merged_dims.first, merged_dims.second).ok()) {
        found_error = true;
        break;
      }
    }
    if (found_error) {
      // The shapes aren't consistent, we can't infer safely: discard all the
      // information discovered so far.
      shape_manager = absl::make_unique<SymbolicShapeManager>();
      break;
    }
  }

  for (const NodeDef& node : item_.graph.node()) {
    VLOG(3) << "Filling in graph properties for node: " << node.name();
    auto ctx = refiner->GetNodeContext(&node);
    if (!ctx) {
      continue;
    }

    auto* ic = ctx->inference_context.get();

    // Fill input properties.
    {
      auto& input_properties = input_properties_[node.name()];

      // Should always be empty, node names in graph are supposed to be unique.
      CHECK_EQ(input_properties.size(), 0);

      input_properties.resize(ic->num_inputs());
      GraphView::InputPort input(&node, -1);
      for (int i = 0; i < ic->num_inputs(); ++i) {
        shape_manager->AsTensorProperties(ic->input(i), ctx->input_types[i],
                                          &input_properties[i]);
        input.port_id = i;
        GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
        if (include_input_tensor_values) {
          // Export tensor value to input_properties.value.
          if (IsConstant(*fanin.node)) {
            const TensorProto& raw_val =
                fanin.node->attr().at("value").tensor();
            *input_properties[i].mutable_value() = raw_val;
          } else if (ctx->input_tensor_protos.size() > i &&
                     ctx->input_tensor_protos[i] != nullptr) {
            *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
          } else if (ic->input_tensors_as_shapes().size() > i &&
                     IsShapeFullyDefinedIntegerVectorOrScalar(
                         ic, ic->input(i), ic->input_tensors_as_shapes()[i],
                         ctx->input_types[i])) {
            *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
                ic, ic->input(i), ic->input_tensors_as_shapes()[i],
                ctx->input_types[i]);
          }
        }
      }
    }

    // Fill output properties.
    {
      auto& output_properties = output_properties_[node.name()];

      // Should always be empty, node names in graph are supposed to be unique.
      CHECK_EQ(output_properties.size(), 0);

      output_properties.resize(ic->num_outputs());
      for (int i = 0; i < ic->num_outputs(); ++i) {
        shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
                                          &output_properties[i]);
        if (include_output_tensor_values) {
          // Export tensor value to output_properties.value.
          if (IsConstant(node)) {
            // TODO(rmlarsen): Eliminate this copy.
            const TensorProto& raw_val = node.attr().at("value").tensor();
            *output_properties[i].mutable_value() = raw_val;
          } else if (ctx->output_tensor_protos.size() > i &&
                     ctx->output_tensor_protos[i] != nullptr) {
            *output_properties[i].mutable_value() =
                *ctx->output_tensor_protos[i];
          } else if (ctx->output_tensors_as_shapes.size() > i &&
                     IsShapeFullyDefinedIntegerVectorOrScalar(
                         ic, ic->output(i), ctx->output_tensors_as_shapes[i],
                         ctx->output_types[i])) {
            *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
                ic, ic->output(i), ctx->output_tensors_as_shapes[i],
                ctx->output_types[i]);
          }
        }
      }
    }

    if (aggressive_shape_inference && ctx->shape_incompatible)
      incompatible_shape_nodes_.insert(node.name());
  }

  if (aggressive_shape_inference && !incompatible_shape_nodes_.empty())
    LOG(WARNING) << incompatible_shape_nodes_.size()
                 << " nodes have incompatible output shapes.";

  // Help trace the unknown dimensions to their origins.
  VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
                                    output_properties_);

  return Status::OK();
}