StructInfo InferStructInfoView()

in src/relax/op/memory/view.cc [45:290]


StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
  if (call->args.size() != 4) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "Operator " << call->op << " should receive 4 arguments, "
                     << "but received " << call->args);
  }
  Expr arg_data = call->args[0];
  Expr arg_shape = call->args[1];
  Expr arg_dtype = call->args[2];
  Expr arg_relative_byte_offset = call->args[3];

  TensorStructInfo data_sinfo = [&]() -> TensorStructInfo {
    StructInfo sinfo = GetStructInfo(arg_data);
    if (auto opt = sinfo.as<TensorStructInfo>()) {
      return opt.value();
    } else {
      LOG(FATAL) << "TypeError: "
                 << "Operator " << call->op << " expects first argument to be a tensor, "
                 << "but received " << arg_data << " with type " << sinfo;
    }
  }();
  auto view_shape_sinfo = [&]() -> const ShapeStructInfoNode* {
    StructInfo sinfo = GetStructInfo(arg_shape);
    if (HasVoidStructInfo(arg_shape)) {
      // No shape change is applied.  The input tensor's shape is
      // kept as-is.
      return nullptr;
    } else if (auto ptr = sinfo.as<ShapeStructInfoNode>()) {
      // The `R.view` operation returns a different shape.
      return ptr;
    } else {
      LOG(FATAL) << "TypeError: "
                 << "Operator " << call->op << " expects second argument to be a ShapeExpr, "
                 << "or a void-type (empty relax tuple), "
                 << "but received " << arg_shape << " with type " << sinfo;
    }
  }();

  auto view_dtype = [&]() -> std::optional<DataType> {
    StructInfo sinfo = GetStructInfo(arg_dtype);

    if (HasVoidStructInfo(arg_dtype)) {
      // No datatype change is applied.  The input tensor's dtype is
      // kept as-is.
      return std::nullopt;
    }

    Expr arg_value = arg_dtype;
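    // Follow variable bindings so that a dtype supplied through a
    // bound variable can still be recognized as a `relax::DataTypeImm`
    // below.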
    while (auto arg_var = arg_value.as<Var>()) {
      if (auto bound_value = ctx->LookupBinding(arg_var.value())) {
        arg_value = bound_value.value();
      } else {
        break;
      }
    }

    // In general, StructInfo inference should only depend on the
    // StructInfo of the arguments, and not on the arguments
    // themselves.  However, `relax::DataTypeImm` uses
    // `ObjectStructInfo`, so we need to inspect the argument itself
    // in this case.
    if (auto dtype_imm = arg_value.as<DataTypeImmNode>()) {
      // We know the datatype for the view.
      return dtype_imm->value;
    } else if (sinfo.as<ObjectStructInfoNode>()) {
      // The view changes the datatype, but we don't know what it is
      // being changed into.
      return DataType::Void();
    } else {
      LOG(FATAL) << "TypeError: "
                 << "Operator " << call->op
                 << " expects the dtype argument to be a relax::DataTypeImm, "
                 << "but received " << arg_dtype << " with type " << sinfo;
    }
  }();

  auto view_relative_byte_offset = [&]() -> Optional<PrimExpr> {
    StructInfo sinfo = GetStructInfo(arg_relative_byte_offset);

    if (HasVoidStructInfo(arg_relative_byte_offset)) {
      // No byte offset is specified, so no change is applied.
      return IntImm(DataType::Int(64), 0);
    } else if (auto prim_sinfo = sinfo.as<PrimStructInfoNode>()) {
      CHECK_EQ(prim_sinfo->dtype, DataType::Int(64))
          << "TypeError: "
          << "Operator " << call->op
          << " expects the relative_byte_offset to be a 64-bit integer, but received "
          << arg_relative_byte_offset << ", which has type " << sinfo;
      if (prim_sinfo->value.defined()) {
        // An offset of known value is applied.  The known value may
        // be dynamic.
        return prim_sinfo->value.value();
      } else {
        // An offset of unknown value is applied.
        return NullOpt;
      }
    } else {
      LOG(FATAL) << "TypeError: "
                 << "Operator " << call->op << " expects the relative_byte_offset argument "
                 << "to be a Relax PrimValue.  "
                 << "However, expression " << call << " provides relative_byte_offset of "
                 << arg_relative_byte_offset << ", which has type " << sinfo;
    }
  }();

  Optional<Array<PrimExpr>> input_shape = data_sinfo->GetShape();

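  // Determine the output shape, or at least its dimensionality,
  // preferring the shape passed to `R.view` and falling back to the
  // shape (or ndim) of the input tensor.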
  Optional<Array<PrimExpr>> output_shape = NullOpt;
  int output_ndim = kUnknownNDim;
  if (view_shape_sinfo && view_shape_sinfo->values.defined()) {
    output_shape = view_shape_sinfo->values.value();
  } else if (view_shape_sinfo) {
    output_ndim = view_shape_sinfo->ndim;
  } else if (input_shape) {
    output_shape = input_shape;
  } else {
    output_ndim = data_sinfo->ndim;
  }

  DataType output_dtype = view_dtype.value_or(data_sinfo->dtype);

  // Helper function, returns the number of bytes per vectorized
  // element.  Cannot use `DataType::bytes`, as it returns the
  // number of bytes per scalar element.
  auto get_size_bytes = [](const DataType& dtype) -> Optional<IntImm> {
    if (dtype.is_void()) {
      return NullOpt;
    } else {
      auto size_bits = dtype.bits() * dtype.lanes();
      return IntImm(DataType::Int(64), (size_bits + 7) / 8);
    }
  };

  // Helper function, returns the number of elements in an array,
  // given the shape of that array.
  auto get_num_elements = [&ctx](const Optional<Array<PrimExpr>>& shape) -> Optional<PrimExpr> {
    if (!shape.defined()) {
      return NullOpt;
    }

    PrimExpr num_elements = Integer(1);
    for (const auto& dim : shape.value()) {
      num_elements *= dim;
    }
    return ctx->GetAnalyzer()->Simplify(num_elements);
  };

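  // Gather the quantities needed for the bounds checks below.  Any of
  // them may be undefined if the corresponding shape or dtype is
  // unknown.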
  Optional<PrimExpr> input_nelements = get_num_elements(input_shape);
  Optional<PrimExpr> output_nelements = get_num_elements(output_shape);

  Optional<IntImm> input_element_size = get_size_bytes(data_sinfo->dtype);
  Optional<IntImm> output_element_size = get_size_bytes(output_dtype);

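  // Best-effort static validation: depending on which of the above
  // quantities are known, attempt to prove that the view would exceed
  // the bounds of the original array, and report an error if so.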
  if (input_nelements && output_nelements && input_element_size && output_element_size &&
      view_relative_byte_offset) {
    // The shapes and dtypes of the input and output are known.  We know
    // the byte_offset that is applied, and can verify that the view
    // does not overrun the bounds of the original array.

    PrimExpr input_nbytes = input_nelements.value() * input_element_size.value();
    PrimExpr output_nbytes = output_nelements.value() * output_element_size.value();
    PrimExpr view_end = output_nbytes + view_relative_byte_offset.value();

    if (ctx->GetAnalyzer()->CanProve(view_end > input_nbytes)) {
      LOG(FATAL) << "ValueError: "
                 << "Views into an array must not exceed the bounds of the array being viewed.  "
                 << "However, expression " << call << " attempted to create view of type "
                 << TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype)
                 << " with relative byte offset " << view_relative_byte_offset
                 << ", viewing into the array " << arg_data << " of type " << data_sinfo << ".  "
                 << "The end of the view would occur at byte " << view_end
                 << ", relative to the start of array " << arg_data << ", but " << arg_data
                 << " is only " << input_nbytes << " long.";
    }

  } else if (input_nelements && output_nelements && input_element_size && output_element_size) {
    // The shapes and dtypes of the input and output are known.
    // However, we don't know if the `byte_offset` is being adjusted.
    // We can still validate using the size of the view.  If the view
    // is larger than the original array, then it would overrun its
    // bounds regardless of the `relative_byte_offset` being applied.

    PrimExpr input_nbytes = input_nelements.value() * input_element_size.value();
    PrimExpr output_nbytes = output_nelements.value() * output_element_size.value();

    if (ctx->GetAnalyzer()->CanProve(output_nbytes > input_nbytes)) {
      LOG(FATAL) << "ValueError: "
                 << "Views into an array must not exceed the bounds of the array being viewed.  "
                 << "However, expression " << call << " attempted to create view of type "
                 << TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype)
                 << " from input array of type " << data_sinfo << ".  "
                 << "This view would increase the size from " << output_nbytes << " bytes to "
                 << output_nbytes << " bytes.";
    }

  } else if (input_element_size && output_element_size && !view_shape_sinfo) {
    // The per-element sizes of both the input and the output are
    // known, and the view keeps the input's shape.  Because the
    // number of elements is unchanged, an increase to the per-element
    // size would cause the view to be larger than the original array.

    CHECK_GE(input_element_size.value()->value, output_element_size.value()->value)
        << "ValueError: "
        << "Operator " << call->op
        << " may not produce a view that exceeds the bounds of the original array.  "
        << "In expression " << call << " the data type is changed from " << data_sinfo->dtype
        << " to " << view_dtype.value() << ", increasing the size per element from "
        << input_element_size << " bytes to " << output_element_size << " bytes.  "
        << "Consider providing a new shape for the R.view.";
  } else if (input_nelements && output_nelements && !view_dtype) {
    // The shape is being updated, while keeping the datatype the
    // same.  Even though we don't know the size of each element, we
    // know it must be the same for the input and output arrays.  An
    // increase to the number of elements would cause the view to be
    // larger than the original array, regardless of the size of each
    // individual element.

    if (ctx->GetAnalyzer()->CanProve(output_nelements.value() > input_nelements.value())) {
      LOG(FATAL) << "ValueError: "
                 << "Views into an array must not exceed the bounds of the array being viewed.  "
                 << "However, expression " << call << " attempted to view array " << arg_data
                 << " (shape = " << input_shape << ", " << input_nelements << " elements) as shape "
                 << output_shape << " with " << output_nelements << " elements.";
    }
  } else if (view_relative_byte_offset && !view_shape_sinfo && !view_dtype) {
    // The byte_offset is being updated, but neither the shape nor the
    // dtype is changing.  Any non-zero offset will cause the view to
    // overrun the bounds of the original array.
    if (ctx->GetAnalyzer()->CanProve(view_relative_byte_offset.value() > 0)) {
      LOG(FATAL) << "ValueError: "
                 << "Views into an array must not exceed the bounds of the array being viewed.  "
                 << "However, expression " << call << " attempted to offset the view by "
                 << view_relative_byte_offset << " bytes, "
                 << "without reducing either the number of elements in the view "
                 << "or the size of each element.";
    }
  }

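  // Build the inferred StructInfo, preserving the virtual device of
  // the input tensor.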
  if (output_shape.defined()) {
    return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, data_sinfo->vdevice);
  } else {
    return TensorStructInfo(output_dtype, output_ndim, data_sinfo->vdevice);
  }
}
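
For intuition, the overrun condition that the checks above try to prove can be written as plain arithmetic.  The following standalone sketch is independent of TVM; `ViewStaysInBounds` and its parameters are illustrative names, not part of this file.  It mirrors `get_size_bytes` and `get_num_elements` for the case where shapes, dtypes, and the byte offset are all fully known:

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Illustrative helper (not TVM code): returns true when a view of
// `output_shape` elements of `output_dtype_bits` bits each, starting
// `relative_byte_offset` bytes into the source buffer, ends within
// that buffer.
bool ViewStaysInBounds(const std::vector<int64_t>& input_shape, int input_dtype_bits,
                       const std::vector<int64_t>& output_shape, int output_dtype_bits,
                       int64_t relative_byte_offset) {
  auto num_elements = [](const std::vector<int64_t>& shape) {
    return std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
  };
  // Round the per-element size up to whole bytes, as `get_size_bytes` does.
  int64_t input_nbytes = num_elements(input_shape) * ((input_dtype_bits + 7) / 8);
  int64_t output_nbytes = num_elements(output_shape) * ((output_dtype_bits + 7) / 8);
  return relative_byte_offset + output_nbytes <= input_nbytes;
}

int main() {
  // A 16-element float32 buffer holds 64 bytes.  Viewing 8 float16
  // values (16 bytes) at byte offset 32 stays in bounds; the same view
  // at byte offset 56 would end at byte 72, past the end of the buffer.
  assert(ViewStaysInBounds({16}, 32, {8}, 16, 32));
  assert(!ViewStaysInBounds({16}, 32, {8}, 16, 56));
  return 0;
}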