in src/relax/op/memory/view.cc [45:290]
/*!
 * \brief Infer the output StructInfo of a `R.view` call.
 *
 * `R.view(data, shape, dtype, relative_byte_offset)` reinterprets an existing
 * tensor without copying it.  Each of the last three arguments may be a void
 * value (empty relax tuple), meaning "keep the corresponding property of the
 * input unchanged".  Beyond inferring the result type, this routine performs
 * best-effort static bounds checking: whenever enough is known statically, it
 * proves that the view cannot extend past the end of the viewed buffer, and
 * reports a fatal diagnostic if it provably does.
 *
 * \param call The `R.view` call, expected to carry exactly 4 arguments:
 *             (data, shape, dtype, relative_byte_offset).
 * \param ctx  Block builder used for diagnostics, binding lookup, and
 *             arithmetic simplification/proving.
 * \return TensorStructInfo describing the view, preserving the input's
 *         vdevice.
 */
StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
  if (call->args.size() != 4) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "Operator " << call->op << " should receive 4 arguments, "
                     << "but received " << call->args);
  }
  Expr arg_data = call->args[0];
  Expr arg_shape = call->args[1];
  Expr arg_dtype = call->args[2];
  Expr arg_relative_byte_offset = call->args[3];

  // The viewed array must be a tensor.
  TensorStructInfo data_sinfo = [&]() -> TensorStructInfo {
    StructInfo sinfo = GetStructInfo(arg_data);
    if (auto opt = sinfo.as<TensorStructInfo>()) {
      return opt.value();
    } else {
      LOG(FATAL) << "TypeError: "
                 << "Operator " << call->op << " expects first argument to be a tensor, "
                 << "but received " << arg_data << " with type " << sinfo;
    }
  }();

  // nullptr means "keep the input shape"; otherwise the ShapeStructInfo of
  // the requested output shape (which may still have unknown values).
  auto view_shape_sinfo = [&]() -> const ShapeStructInfoNode* {
    StructInfo sinfo = GetStructInfo(arg_shape);
    if (HasVoidStructInfo(arg_shape)) {
      // No shape change is applied. The input tensor's shape is
      // kept as-is.
      return nullptr;
    } else if (auto ptr = sinfo.as<ShapeStructInfoNode>()) {
      // The `R.view` operation returns a different shape.
      return ptr;
    } else {
      LOG(FATAL) << "TypeError: "
                 << "Operator " << call->op << " expects second argument to be a ShapeExpr, "
                 << "or a void-type (empty relax tuple), "
                 << "but received " << arg_shape << " with type " << sinfo;
    }
  }();

  // std::nullopt means "keep the input dtype"; DataType::Void() means "the
  // dtype changes, but to an unknown type"; any other value is the new dtype.
  auto view_dtype = [&]() -> std::optional<DataType> {
    StructInfo sinfo = GetStructInfo(arg_dtype);

    if (HasVoidStructInfo(arg_dtype)) {
      // No datatype change is applied. The input tensor's dtype is
      // kept as-is.
      return std::nullopt;
    }

    // Chase variable bindings so that a bound `R.dtype(...)` literal is
    // still recognized below.
    Expr arg_value = arg_dtype;
    while (auto arg_var = arg_value.as<Var>()) {
      if (auto bound_value = ctx->LookupBinding(arg_var.value())) {
        arg_value = bound_value.value();
      } else {
        break;
      }
    }

    // In general, StructInfo inference should only depend on the
    // StructInfo of the arguments, and not on the arguments
    // themselves. However, `relax::DataTypeImm` uses
    // `ObjectStructInfo`, so we need to inspect the argument itself
    // in this case.
    if (auto dtype_imm = arg_value.as<DataTypeImmNode>()) {
      // We know the datatype for the view.
      return dtype_imm->value;
    } else if (sinfo.as<ObjectStructInfoNode>()) {
      // The view changes the datatype, but we don't know what it is
      // being changed into.
      return DataType::Void();
    } else {
      LOG(FATAL) << "TypeError: "
                 << "Operator " << call->op
                 << " expects the dtype argument to be a relax::DataTypeImm, "
                 << "but received " << arg_dtype << " with type " << sinfo;
    }
  }();

  // NullOpt means "an offset of unknown value is applied"; a defined value is
  // the (possibly dynamic) byte offset, with 0 meaning "no offset".
  auto view_relative_byte_offset = [&]() -> Optional<PrimExpr> {
    StructInfo sinfo = GetStructInfo(arg_relative_byte_offset);

    if (HasVoidStructInfo(arg_relative_byte_offset)) {
      // No byte offset is specified, so no change is applied.
      return IntImm(DataType::Int(64), 0);
    } else if (auto prim_sinfo = sinfo.as<PrimStructInfoNode>()) {
      CHECK_EQ(prim_sinfo->dtype, DataType::Int(64))
          << "TypeError: "
          << "Operator " << call->op
          << " expects the relative_byte_offset to be a 64-bit integer, but received "
          << arg_relative_byte_offset << ", which has type " << sinfo;
      if (prim_sinfo->value.defined()) {
        // An offset of known value is applied. The known value may
        // be dynamic.
        return prim_sinfo->value.value();
      } else {
        // An offset of unknown value is applied.
        return NullOpt;
      }
    } else {
      LOG(FATAL) << "TypeError: "
                 << "Operator " << call->op << " expects the relative_byte_offset argument "
                 << "to be a Relax PrimValue. "
                 << "However, expression " << call << " provides relative_byte_offset of "
                 << arg_relative_byte_offset << ", which has type " << sinfo;
    }
  }();

  // Determine the output shape (or, failing that, its ndim), falling back to
  // the input's shape/ndim when no shape change was requested.
  Optional<Array<PrimExpr>> input_shape = data_sinfo->GetShape();
  Optional<Array<PrimExpr>> output_shape = NullOpt;
  int output_ndim = kUnknownNDim;
  if (view_shape_sinfo && view_shape_sinfo->values.defined()) {
    output_shape = view_shape_sinfo->values.value();
  } else if (view_shape_sinfo) {
    output_ndim = view_shape_sinfo->ndim;
  } else if (input_shape) {
    output_shape = input_shape;
  } else {
    output_ndim = data_sinfo->ndim;
  }

  DataType output_dtype = view_dtype.value_or(data_sinfo->dtype);

  // Helper function, returns the number of bytes per vectorized
  // element. Cannot use `DataType::bytes`, as it returns the
  // number of bytes per scalar element.
  auto get_size_bytes = [](const DataType& dtype) -> Optional<IntImm> {
    if (dtype.is_void()) {
      return NullOpt;
    } else {
      auto size_bits = dtype.bits() * dtype.lanes();
      // Round up to whole bytes for sub-byte dtypes.
      return IntImm(DataType::Int(64), (size_bits + 7) / 8);
    }
  };

  // Helper function, returns the number of elements in an array,
  // given the shape of that array.
  auto get_num_elements = [&ctx](const Optional<Array<PrimExpr>>& shape) -> Optional<PrimExpr> {
    if (!shape.defined()) {
      return NullOpt;
    }

    PrimExpr num_elements = Integer(1);
    for (const auto& dim : shape.value()) {
      num_elements *= dim;
    }
    return ctx->GetAnalyzer()->Simplify(num_elements);
  };

  Optional<PrimExpr> input_nelements = get_num_elements(input_shape);
  Optional<PrimExpr> output_nelements = get_num_elements(output_shape);
  Optional<IntImm> input_element_size = get_size_bytes(data_sinfo->dtype);
  Optional<IntImm> output_element_size = get_size_bytes(output_dtype);

  // Best-effort static bounds checks, from most to least information
  // available.  Each branch only fires when the overrun can be *proven*;
  // unknown quantities defer the check to runtime.
  if (input_nelements && output_nelements && input_element_size && output_element_size &&
      view_relative_byte_offset) {
    // The shapes and dtype of input and output are known. We know
    // the byte_offset that is applied, and can verify that the view
    // does not overrun the bounds of the original array.

    PrimExpr input_nbytes = input_nelements.value() * input_element_size.value();
    PrimExpr output_nbytes = output_nelements.value() * output_element_size.value();
    PrimExpr view_end = output_nbytes + view_relative_byte_offset.value();
    if (ctx->GetAnalyzer()->CanProve(view_end > input_nbytes)) {
      LOG(FATAL) << "ValueError: "
                 << "Views into an array must not exceed the bounds of the array being viewed. "
                 << "However, expression " << call << " attempted to create view of type "
                 << TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype)
                 << " with relative byte offset " << view_relative_byte_offset
                 << ", viewing into the array " << arg_data << " of type " << data_sinfo << ". "
                 << "The end of the view would occur at byte " << view_end
                 << ", relative to the start of array " << arg_data << ", but " << arg_data
                 << " is only " << input_nbytes << " long.";
    }
  } else if (input_nelements && output_nelements && input_element_size && output_element_size) {
    // The shapes and dtype of input and output are known. However,
    // we don't know if the `byte_offset` is being adjusted. We can
    // still check validate using the size of the view. If the view
    // is larger than the original array, then it would overrun its
    // bounds regardless of the `relative_byte_offset` being applied.
    PrimExpr input_nbytes = input_nelements.value() * input_element_size.value();
    PrimExpr output_nbytes = output_nelements.value() * output_element_size.value();
    if (ctx->GetAnalyzer()->CanProve(output_nbytes > input_nbytes)) {
      LOG(FATAL) << "ValueError: "
                 << "Views into an array must not exceed the bounds of the array being viewed. "
                 << "However, expression " << call << " attempted to create view of type "
                 << TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype)
                 << " from input array of type " << data_sinfo << ". "
                 << "This view would increase the size from " << input_nbytes << " bytes to "
                 << output_nbytes << " bytes.";
    }
  } else if (input_element_size && output_element_size && !view_shape_sinfo) {
    // The output view has a known dtype, which is different from the
    // known dtype of the input array. Because the view's shape is
    // the same as the original array, when counted in number of
    // elements, an increase to the per-element size would cause the
    // view to be larger than the original array.
    CHECK_GE(input_element_size.value()->value, output_element_size.value()->value)
        << "ValueError: "
        << "Operator " << call->op
        << " may not produce a view that exceeds the bounds of the original array. "
        << "In expression " << call << " the data type is changed from " << data_sinfo->dtype
        << " to " << view_dtype.value() << ", increasing the size per element from "
        << input_element_size << " bytes to " << output_element_size << " bytes. "
        << "Consider providing a new shape for the R.view.";
  } else if (input_nelements && output_nelements && !view_dtype) {
    // The shape is being updated, while keeping the datatype the
    // same. Even though we don't know the size of each element, we
    // know it must be the same for the input and output arrays. An
    // increase to the number of elements would cause the view to be
    // larger than the original array, regardless of the size of each
    // individual element.
    if (ctx->GetAnalyzer()->CanProve(output_nelements.value() > input_nelements.value())) {
      LOG(FATAL) << "ValueError: "
                 << "Views into an array must not exceed the bounds of the array being viewed. "
                 << "However, expression " << call << " attempted to view array " << arg_data
                 << " (shape = " << input_shape << ", " << input_nelements << " elements) as shape "
                 << output_shape << " with " << output_nelements << " elements.";
    }
  } else if (view_relative_byte_offset && !view_shape_sinfo && !view_dtype) {
    // The byte_offset is being updated, but neither the shape nor the
    // dtype is changing. Any non-zero offset will cause the view to
    // overrun the bounds of the original array.
    if (ctx->GetAnalyzer()->CanProve(view_relative_byte_offset.value() > 0)) {
      LOG(FATAL) << "ValueError: "
                 << "Views into an array must not exceed the bounds of the array being viewed. "
                 << "However, expression " << call << " attempted to offset the view by "
                 << view_relative_byte_offset << " bytes, "
                 << "without reducing either the number of elements in the view "
                 << "or the size of each element.";
    }
  }

  // Prefer a fully-known symbolic shape when available; otherwise fall back
  // to a dtype + ndim description.  The vdevice always follows the input.
  if (output_shape.defined()) {
    return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, data_sinfo->vdevice);
  } else {
    return TensorStructInfo(output_dtype, output_ndim, data_sinfo->vdevice);
  }
}