Status MaybeUpdateNodeContextOutput()

in tensorflow/tensorflow/core/grappler/costs/graph_properties.cc [1455:1662]


  Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
                                      NodeContext* c) {
    // Propagate tensors and shape tensors unless the node is fed.
    // TODO(bsteiner) We should still propagate the shapes to the ports that
    // aren't fed in the case of a ShapeN node.

    InferenceContext* ic = c->inference_context.get();
    if (!is_fed) {
      if (IsConstant(node)) {
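        // Constant node: expose its value proto directly, and also as a shape
        // handle when the tensor encodes a shape vector.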
        c->output_tensor_protos.resize(1);
        const TensorProto& tensor_proto = node.attr().at("value").tensor();
        c->output_tensor_protos[0] = &tensor_proto;
        c->output_tensors_as_shapes.resize(1);
        MaybeTensorProtoToShape(ic, tensor_proto,
                                &c->output_tensors_as_shapes[0]);
      } else if (IsRank(node)) {
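        // The rank of the input is a constant as soon as the input's rank is
        // known, even if its individual dimensions are not.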
        if (ic->RankKnown(ic->input(0))) {
          // Propagate rank value.
          int32 rank = ic->Rank(ic->input(0));
          const_tensors_to_propagate_.push_back(
              MakeIntegerScalarTensorProto(DT_INT32, rank));
          c->output_tensor_protos.resize(1);
          c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
        }
      } else if (IsSize(node)) {
        DimensionHandle size = ic->NumElements(ic->input(0));
        if (ic->ValueKnown(size)) {
          // Propagate size value.
          int64 sz = ic->Value(size);
          bool valid = false;
          if (node.attr().at("out_type").type() == DT_INT32) {
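            // Only propagate the value if it fits in the requested int32
            // output type.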
            if (sz < std::numeric_limits<int32>::max()) {
              const_tensors_to_propagate_.push_back(
                  MakeIntegerScalarTensorProto(DT_INT32, sz));
              valid = true;
            }
          } else {
            const_tensors_to_propagate_.push_back(
                MakeIntegerScalarTensorProto(DT_INT64, sz));
            valid = true;
          }
          if (valid) {
            c->output_tensor_protos.resize(1);
            c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
          }
        }
      } else if (IsShape(node)) {
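        // The output of Shape is exactly the (possibly symbolic) input shape.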
        c->output_tensors_as_shapes.resize(1);
        c->output_tensors_as_shapes[0] = ic->input(0);
      } else if (IsShapeN(node)) {
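        // ShapeN emits one shape per input; forward each input shape to the
        // corresponding output port.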
        c->output_tensors_as_shapes.resize(ic->num_inputs());
        for (int i = 0; i < ic->num_inputs(); ++i) {
          c->output_tensors_as_shapes[i] = ic->input(i);
        }
      } else if (node.op() == "ConcatV2") {
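        // ConcatV2 is commonly used to assemble shape vectors. The last input
        // is the concat axis, so concatenate the shapes of all other inputs.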
        bool valid = true;
        ShapeHandle result;
        for (int i = 0; i < ic->num_inputs() - 1; ++i) {
          ShapeHandle input = ic->input_tensors_as_shapes()[i];
          if (!ic->RankKnown(input)) {
            valid = false;
            break;
          } else if (i == 0) {
            result = input;
          } else {
            TF_RETURN_IF_ERROR(ic->Concatenate(result, input, &result));
          }
        }
        if (valid) {
          c->output_tensors_as_shapes.resize(1);
          c->output_tensors_as_shapes[0] = result;
        }
      } else if (IsPack(node)) {
        // A Pack node concatenating scalars is often used to generate a shape.
        std::vector<DimensionHandle> dims;
        bool valid = true;
        for (int i = 0; i < ic->num_inputs(); ++i) {
          const Tensor* t = ic->input_tensor(i);
          if (t) {
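            // A known scalar int32/int64 input becomes one output dimension;
            // negative values denote an unknown dimension.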
            if (t->dims() != 0 ||
                (t->dtype() != DT_INT32 && t->dtype() != DT_INT64)) {
              valid = false;
              break;
            }
            int64 size = t->dtype() == DT_INT32 ? t->scalar<int32>()()
                                                : t->scalar<int64>()();
            dims.push_back(size < 0 ? ic->UnknownDim() : ic->MakeDim(size));
          } else {
            // Don't have tensor value, but use input_tensors_as_shapes, if
            // possible.
            const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
            if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
                ic->ValueKnown(ic->Dim(shape_handle, 0))) {
              dims.push_back(ic->Dim(shape_handle, 0));
            } else {
              dims.push_back(ic->UnknownDim());
            }
          }
        }
        if (valid) {
          c->output_tensors_as_shapes.resize(1);
          c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
        }
      } else if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
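        // Identity forwards both the shape-as-tensor and, if available, the
        // constant value of its input unchanged.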
        c->output_tensors_as_shapes.resize(1);
        c->output_tensors_as_shapes[0] = ic->input_tensors_as_shapes()[0];
        if (c->input_tensor_protos[0] != nullptr) {
          c->output_tensor_protos.resize(1);
          c->output_tensor_protos[0] = c->input_tensor_protos[0];
        }
      } else if (IsSlice(node)) {
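        // Slicing a shape vector with scalar constant offset and size inputs
        // can be resolved symbolically via Subshape.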
        ShapeHandle input = ic->input_tensors_as_shapes()[0];
        bool valid = ic->RankKnown(input);
        const Tensor* slice_offset = ic->input_tensor(1);
        valid &= slice_offset != nullptr && slice_offset->NumElements() == 1;
        const Tensor* slice_size = ic->input_tensor(2);
        valid &= slice_size != nullptr && slice_size->NumElements() == 1;
        if (valid) {
          int64 start = slice_offset->dtype() == DT_INT32
                            ? slice_offset->flat<int32>()(0)
                            : slice_offset->flat<int64>()(0);
          int64 size =
              (slice_size->dtype() == DT_INT32 ? slice_size->flat<int32>()(0)
                                               : slice_size->flat<int64>()(0));
          ShapeHandle result;
          if (size == -1) {
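            // A size of -1 means "everything from start to the end".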
            TF_RETURN_IF_ERROR(ic->Subshape(input, start, &result));
          } else {
            int64 end = start + size;
            TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result));
          }
          c->output_tensors_as_shapes.resize(1);
          c->output_tensors_as_shapes[0] = result;
        }
      } else if (IsStridedSlice(node)) {
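        // StridedSlice of a shape vector: only the simple case with scalar
        // begin/end/stride inputs is resolved here.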
        ShapeHandle input = ic->input_tensors_as_shapes()[0];
        bool valid = ic->RankKnown(input);
        const Tensor* slice_begin = ic->input_tensor(1);
        valid &= slice_begin != nullptr && slice_begin->NumElements() == 1;
        const Tensor* slice_end = ic->input_tensor(2);
        valid &= slice_end != nullptr && slice_end->NumElements() == 1;
        const Tensor* slice_stride = ic->input_tensor(3);
        valid &= slice_stride != nullptr && slice_stride->NumElements() == 1;

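        // Bail out if ellipsis/new_axis/shrink_axis masks are set, since they
        // change the rank of the result and are not modeled here.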
        if (node.attr().count("ellipsis_mask") > 0 &&
            node.attr().at("ellipsis_mask").i() != 0) {
          valid = false;
        }
        if (node.attr().count("new_axis_mask") > 0 &&
            node.attr().at("new_axis_mask").i() != 0) {
          valid = false;
        }
        if (node.attr().count("shrink_axis_mask") > 0 &&
            node.attr().at("shrink_axis_mask").i() != 0) {
          valid = false;
        }
        int begin_mask = 0;
        if (node.attr().count("begin_mask") > 0) {
          begin_mask = node.attr().at("begin_mask").i();
        }
        int end_mask = 0;
        if (node.attr().count("end_mask") > 0) {
          end_mask = node.attr().at("end_mask").i();
        }
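        // The input is a 1-D shape vector, so only begin/end mask values of 0
        // or 1 (i.e. only bit 0) are supported.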
        if (begin_mask < 0 || begin_mask > 1 || end_mask < 0 || end_mask > 1) {
          valid = false;
        }
        if (valid) {
          int64 begin = 0;
          if (begin_mask == 0) {
            begin = slice_begin->dtype() == DT_INT32
                        ? slice_begin->flat<int32>()(0)
                        : slice_begin->flat<int64>()(0);
          }
          int64 end = std::numeric_limits<int64>::max();
          if (end_mask == 0) {
            end =
                (slice_end->dtype() == DT_INT32 ? slice_end->flat<int32>()(0)
                                                : slice_end->flat<int64>()(0));
          }
          int64 stride = slice_stride->dtype() == DT_INT32
                             ? slice_stride->flat<int32>()(0)
                             : slice_stride->flat<int64>()(0);
          ShapeHandle result;
          TF_RETURN_IF_ERROR(ic->Subshape(input, begin, end, stride, &result));
          c->output_tensors_as_shapes.resize(1);
          c->output_tensors_as_shapes[0] = result;
        }
      }
    }

    if (aggressive_shape_inference_) {
      // Update output shapes with annotated information. This is optional.
      UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();

      // Update output tensor values using EvaluateNode() if we can.
      // Due to the cost of EvaluateNode(), we run it only for allowlisted op
      // types and small integer tensors.

      const int max_element_size = 17;  // Max up to 4x4 matrix or similar.
      if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
          !ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
        return Status::OK();
      }
      UpdateOutputShapesAndValues(node, c).IgnoreError();  // This is optional.
    }
    return Status::OK();
  }