Status PropagateShapes()

in grappler/shape_inference.cc [47:329]


Status PropagateShapes(Graph* graph,
                       const std::map<int, InferredShape>& arg_shapes,
                       const std::vector<BackEdgeHelper::BackEdge>& back_edges,
                       ShapeRefiner* shape_refiner) {
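  // Record, for each Merge node, the NextIteration node that reached it
  // through a (removed) back edge; used below to detect loop-invariant
  // resource inputs.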
  std::map<const Node*, const Node*> merge_to_next_iteration;
  for (const auto& e : back_edges) {
    if (e.src->IsNextIteration() && e.dst->IsMerge()) {
      merge_to_next_iteration[e.dst] = e.src;
    }
  }

  // Visits the nodes in topological order (reverse post-order), inferring
  // shapes.
  // TODO(phawkins): handle cyclic graphs.
  std::vector<Node*> order;
  GetReversePostOrder(*graph, &order);

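  // Side tables for best-effort constant folding alongside shape inference:
  // scalar int32 constants, fully defined shapes, shape vectors, and
  // TensorArray/Range sizes. Together they let us resolve output shapes
  // (e.g. for TensorArrayGatherV3) that shape functions alone cannot infer.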
  std::unordered_map<std::string, int64> resolved_ints;
  std::unordered_map<std::string, TensorShape> resolved_shapes;
  using GtlInt64Vector = gtl::InlinedVector<int64, 4>;
  std::unordered_map<std::string, GtlInt64Vector> resolved_vectors;
  std::unordered_map<std::string, int64> resolved_tensor_array_sizes;
  std::unordered_map<std::string, int64> resolved_range_sizes;
  for (Node* n : order) {
    // Tolerate a failing status from the shape_refiner: we want best-effort
    // shapes, even if no shape function is registered for a node.
    Status status = shape_refiner->AddNode(n);
    if (!status.ok()) {
      VLOG(1) << "Shape inference failed for node " << n->name() << ": "
              << status;
    } else {
      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
      for (int i = 0; i < n->num_outputs(); i++) {
        shape_inference::ShapeHandle handle = context->output(i);
        VLOG(4) << "Output " << i << " for node " << n->name() << ": "
                << context->DebugString(handle);
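        // Shapes recorded in the kNeuronInferredShapes attribute override the
        // refiner's result. Fully defined shapes are also cached under their
        // GraphDef tensor name (output 0 is referenced by the bare node name)
        // for the Shape-folding logic below.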
        auto& attr = n->def().attr();
        if (attr.count(kNeuronInferredShapes)) {
          auto& shape_list = attr.at(kNeuronInferredShapes).list().shape();
          if (i < shape_list.size()) {
            PartialTensorShape shape(shape_list[i]);
            TF_RETURN_IF_ERROR(
                context->MakeShapeFromPartialTensorShape(shape, &handle));
            context->set_output(i, handle);
            if (shape.IsFullyDefined()) {
              std::string tensor_name = n->name() + ":" + std::to_string(i);
              std::string gd_tensor_name = i == 0 ? n->name() : tensor_name;
              resolved_shapes[gd_tensor_name] = shape_list[i];
              VLOG(1) << "Set fully defined shape for node " << n->name()
                      << " at output port " << i;
            }
          }
        }
      }
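      // If the gather indices come from a Range of known size, the gather
      // output shape is [range_size] + element_shape.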
      if (n->type_string() == "TensorArrayGatherV3") {
        std::string input1_name = n->def().input(1);
        if (resolved_range_sizes.count(input1_name)) {
          PartialTensorShape shape(n->def().attr().at("element_shape").shape());
          if (shape.IsFullyDefined()) {
            shape.InsertDim(0, resolved_range_sizes[input1_name]);
            shape_inference::ShapeHandle handle = context->output(0);
            TF_RETURN_IF_ERROR(
                context->MakeShapeFromPartialTensorShape(shape, &handle));
            context->set_output(0, handle);
            VLOG(1) << "Inferred fully defined shape of " << n->name()
                    << " using TensorArray operator shape inference mechanism";
          }
        }
      }
    }

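    // Record scalar int32 Const values (those stored in int_val) for the
    // folding logic below.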
    if (n->type_string() == "Const") {
      const auto& tensor = n->def().attr().at("value").tensor();
      if (tensor.dtype() == DT_INT32 && tensor.int_val_size() == 1) {
        int64 int_val = tensor.int_val(0);
        resolved_ints[n->name()] = int_val;
        VLOG(2) << "filled resolved_ints[" << n->name() << "] with " << int_val;
      }
    }

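    // The Shape of a tensor whose shape is fully known folds to a constant
    // vector.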
    if (n->type_string() == "Shape") {
      std::string input0_name = n->def().input(0);
      if (resolved_shapes.count(input0_name)) {
        auto& shape = resolved_shapes[input0_name];
        resolved_vectors[n->name()] = shape.dim_sizes();
        VLOG(2) << "filled resolved_vectors[" << n->name() << "] with vector "
                << shape;
      }
    }

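    // Recognize the pattern vec[i]: a single-element StridedSlice with only
    // shrink_axis_mask set, applied to a resolved shape vector with resolved
    // scalar begin/end/strides inputs.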
    if (n->type_string() == "StridedSlice") {
      const auto& n_def = n->def();
      const auto& attr = n_def.attr();
      if (attr.at("Index").type() == DT_INT32 &&
          attr.at("T").type() == DT_INT32 && attr.at("begin_mask").i() == 0 &&
          attr.at("ellipsis_mask").i() == 0 && attr.at("end_mask").i() == 0 &&
          attr.at("new_axis_mask").i() == 0 &&
          attr.at("shrink_axis_mask").i() == 1 &&
          resolved_vectors.count(n_def.input(0)) &&
          resolved_ints.count(n_def.input(1)) &&
          resolved_ints.count(n_def.input(2)) &&
          resolved_ints.count(n_def.input(3))) {
        int64 start = resolved_ints[n_def.input(1)];
        int64 end = resolved_ints[n_def.input(2)];
        int64 step = resolved_ints[n_def.input(3)];
        auto& vector = resolved_vectors[n_def.input(0)];
        if (end - start == 1 && step == 1 && start >= 0 &&
            start < static_cast<int64>(vector.size())) {
          int64 int_val = vector[start];
          resolved_ints[n->name()] = int_val;
          VLOG(2) << "filled resolved_ints[" << n->name() << "] with "
                  << int_val;
        }
      }
    }

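    // A fixed-size TensorArrayV3 whose size input is a resolved constant.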
    if (n->type_string() == "TensorArrayV3") {
      const auto& n_def = n->def();
      const auto& attr = n_def.attr();
      if (!attr.at("dynamic_size").b() && resolved_ints.count(n_def.input(0))) {
        int64 int_val = resolved_ints[n_def.input(0)];
        resolved_tensor_array_sizes[n->name()] = int_val;
        VLOG(2) << "filled resolved_tensor_array_sizes[" << n->name()
                << "] with " << int_val;
      }
    }

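    // TensorArraySizeV3 forwards a resolved TensorArray size under its own
    // name, which is what the Range handling below looks up for its limit
    // input.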
    if (n->type_string() == "TensorArraySizeV3") {
      const auto& n_def = n->def();
      if (resolved_tensor_array_sizes.count(n_def.input(0))) {
        int64 int_val = resolved_tensor_array_sizes[n_def.input(0)];
        resolved_tensor_array_sizes[n->name()] = int_val;
        VLOG(2) << "filled resolved_tensor_array_sizes[" << n->name()
                << "] with " << int_val;
      }
    }

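    // A Range whose start and delta are resolved ints and whose limit is a
    // resolved TensorArray size has a computable element count, consumed by
    // the TensorArrayGatherV3 handling above.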
    if (n->type_string() == "Range") {
      const auto& n_def = n->def();
      const auto& attr = n_def.attr();
      if (attr.at("Tidx").type() == DT_INT32 &&
          resolved_ints.count(n_def.input(0)) &&
          resolved_tensor_array_sizes.count(n_def.input(1)) &&
          resolved_ints.count(n_def.input(2))) {
        int64 start = resolved_ints[n_def.input(0)];
        int64 end = resolved_tensor_array_sizes[n_def.input(1)];
        int64 step = resolved_ints[n_def.input(2)];
        int64 diff = end - start;
        // Range(start, limit, step) produces ceil(diff / step) elements for a
        // positive step.
        if (step > 0) {
          int64 num_elements = (diff + step - 1) / step;
          resolved_range_sizes[n->name()] = num_elements;
          VLOG(2) << "filled resolved_range_sizes[" << n->name() << "] with "
                  << num_elements;
        }
      }
    }

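    // _Arg nodes take their shapes (and, for resources, the handle shape and
    // type) from the caller-provided arg_shapes map.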
    if (n->type_string() == "_Arg") {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      auto it = arg_shapes.find(index);
      if (it != arg_shapes.end()) {
        const InferredShape& arg_shape = it->second;
        shape_inference::InferenceContext* context =
            shape_refiner->GetContext(n);

        if (arg_shape.handle_type != DT_INVALID) {
          shape_inference::ShapeHandle handle;
          TF_RETURN_IF_ERROR(context->MakeShapeFromPartialTensorShape(
              arg_shape.handle_shape, &handle));

          // Sets the shape and type of the variable's value.
          context->set_output_handle_shapes_and_types(
              0, std::vector<shape_inference::ShapeAndType>{
                     {handle, arg_shape.handle_type}});
        }

        shape_inference::ShapeHandle handle;
        TF_RETURN_IF_ERROR(
            context->MakeShapeFromPartialTensorShape(arg_shape.shape, &handle));
        TF_RETURN_IF_ERROR(shape_refiner->SetShape(n, 0, handle));
      }
    }

    // Sometimes we have VariableShape nodes inside a while loop (after Enter
    // nodes). They won't be constant-folded, because TensorFlow constant
    // folding does not handle Enter nodes (and thus does not handle any nodes
    // downstream of them). We try to replace such VariableShape nodes with
    // Const nodes here.
    if (n->type_string() == "VariableShape") {
      shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
      auto handle_shapes_and_types = context->input_handle_shapes_and_types(0);
      if (handle_shapes_and_types && !handle_shapes_and_types->empty()) {
        shape_inference::ShapeHandle handle =
            handle_shapes_and_types->at(0).shape;
        TensorShapeProto shape_proto;
        context->ShapeHandleToProto(handle, &shape_proto);
        if (!shape_proto.unknown_rank()) {
          NodeDef const_def;
          const_def.set_op("Const");
          Node* var_node;
          TF_RETURN_IF_ERROR(n->input_node(0, &var_node));
          const_def.set_name(
              graph->NewName(absl::StrCat("var_shape_", var_node->name())));
          DataType dtype = n->output_type(0);
          AddNodeAttr("dtype", dtype, &const_def);
          TensorProto value;
          value.set_dtype(dtype);
          value.mutable_tensor_shape()->add_dim()->set_size(
              shape_proto.dim_size());
          for (const auto& dim : shape_proto.dim()) {
            if (dtype == DT_INT32) {
              value.add_int_val(dim.size());
            } else {
              value.add_int64_val(dim.size());
            }
          }
          AddNodeAttr("value", value, &const_def);
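          // Carry over internal ("_"-prefixed) attributes to the new Const
          // node.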
          for (auto const& attr : n->attrs()) {
            if (*attr.first.begin() == '_') {
              AddNodeAttr(attr.first, attr.second, &const_def);
            }
          }

          Status s;
          Node* const_node = graph->AddNode(const_def, &s);
          TF_RETURN_IF_ERROR(s);

          graph->AddControlEdge(var_node, const_node);
          std::vector<const Edge*> out_edges(n->out_edges().begin(),
                                             n->out_edges().end());
          for (const Edge* e : out_edges) {
            if (e->IsControlEdge()) {
              graph->AddControlEdge(const_node, e->dst());
              graph->RemoveEdge(e);
            } else {
              Node* dst = e->dst();
              int dst_input = e->dst_input();
              graph->RemoveEdge(e);
              graph->AddEdge(const_node, 0, dst, dst_input);
            }
          }
        }
      }
    }

    // A Merge node creates a loop, so we remove the NextIteration->Merge edge
    // before performing shape inference. But removing those edges also
    // prevents us from inferring the output shape of the Merge node (we need
    // the shapes of all its inputs).
    // For a loop-invariant resource input's Merge node, we set the output
    // resource shape to the Enter node's resource shape.
    // TODO(b/129367850): clean this up.
    if (n->IsMerge() && n->output_type(0) == DT_RESOURCE) {
      // Check whether this is a loop-invariant input's Merge node: the
      // corresponding NextIteration node must be fed, possibly through
      // Identity nodes, by a Switch whose data input is this Merge node.
      auto iter = merge_to_next_iteration.find(n);
      if (iter != merge_to_next_iteration.end()) {
        const Node* node = iter->second;
        do {
          TF_RETURN_IF_ERROR(node->input_node(0, &node));
        } while (node->IsIdentity());
        const Node* switch_input;
        bool is_loop_invariant = node->IsSwitch() &&
                                 node->input_node(0, &switch_input).ok() &&
                                 switch_input == n;
        if (is_loop_invariant) {
          shape_inference::InferenceContext* context =
              shape_refiner->GetContext(n);
          for (int i = 0; i < n->num_inputs(); i++) {
            const Node* input_node;
            if (n->input_node(i, &input_node).ok()) {
              auto shapes_and_types = context->input_handle_shapes_and_types(i);
              if (shapes_and_types) {
                context->set_output_handle_shapes_and_types(0,
                                                            *shapes_and_types);
              }
              break;
            }
          }
        }
      }
    }
  }
  return Status::OK();
}
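
A minimal sketch of a call site, assuming the caller has already removed the
loop back edges (e.g. with BackEdgeHelper) and collected arg_shapes; the
wrapper name RunShapeInference and the refiner setup are illustrative, not
taken from this file:

Status RunShapeInference(
    Graph* graph, const std::map<int, InferredShape>& arg_shapes,
    const std::vector<BackEdgeHelper::BackEdge>& back_edges) {
  // Build a refiner at the graph's producer version so op shape functions
  // match the graph's semantics.
  ShapeRefiner refiner(graph->versions().producer(), graph->op_registry());
  // PropagateShapes already treats refiner failures as best-effort, so do not
  // require every op to carry a registered shape function.
  refiner.set_require_shape_inference_fns(false);
  return PropagateShapes(graph, arg_shapes, back_edges, &refiner);
}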