in src/relax/op/tensor/index.cc [264:424]
/*!
 * \brief Infer the output StructInfo of relax.strided_slice.
 *
 * The call carries (data, axes, begin, end[, strides]).  `axes`, `begin`,
 * `end`, and `strides` must each be tuples of int64 PrimValues.  When the
 * data shape and all index tuples are statically known, the output shape is
 * computed per-axis via topi::GetLength; otherwise the result falls back to
 * a tensor with only dtype/ndim/vdevice information.
 *
 * \param call The relax.strided_slice call being inferred.
 * \param ctx The block builder, used for axis normalization and its
 *            arithmetic analyzer.
 * \return The inferred TensorStructInfo of the sliced tensor.
 */
StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx) {
  size_t n_args = call->args.size();
  // NOTE(review): the original message said "three"/"four" arguments while
  // listing four and five names; corrected to match the actual arity check.
  CHECK(4 <= n_args && n_args <= 5)
      << "Operator " << call->op << " accepts either four arguments (data, axes, begin, end) "
      << " or five arguments (data, axes, begin, end, strides), "
      << "but received " << n_args << " in expression " << call;
  Expr data = call->args[0];
  Expr axes = call->args[1];
  Expr begin = call->args[2];
  Expr end = call->args[3];
  // Optional fifth argument; absent means unit strides on every sliced axis.
  Optional<Expr> strides = [&]() -> Optional<Expr> {
    if (n_args > 4) {
      return call->args[4];
    } else {
      return NullOpt;
    }
  }();
  // The data argument must be a tensor (any dtype, any ndim).
  CHECK(IsBaseOf(relax::TensorStructInfo(DataType::Void(), kUnknownNDim), GetStructInfo(data)))
      << "Operator " << call->op << " requires the first argument to be a tensor. "
      << "However, in expression " << call << ", the first argument " << data << " has struct info "
      << GetStructInfo(data);

  // TODO(Lunderberg): Implement this check using `IsBaseOf`. Doing
  // so will require a way to represent a `relax::TupleStructInfo` of
  // unknown length, where each element has the same `StructInfo`.
  //
  // Accepts ObjectStructInfo (fully unknown) or a tuple whose every field
  // is an int64 PrimValue.
  auto is_base_of_tuple_of_int64 = [&](const StructInfo& sinfo) -> bool {
    if (sinfo.as<ObjectStructInfoNode>()) {
      return true;
    }
    const auto* tuple = sinfo.as<TupleStructInfoNode>();
    if (!tuple) return false;
    return std::all_of(tuple->fields.begin(), tuple->fields.end(), [](const StructInfo& field) {
      return IsBaseOf(relax::PrimStructInfo(DataType::Int(64)), field);
    });
  };
  auto check_tuple = [&](const char* name, Expr expr) {
    auto sinfo = GetStructInfo(expr);
    CHECK(is_base_of_tuple_of_int64(sinfo)) << "Operator " << call->op << " requires the " << name
                                            << " argument to be a tuple of int64 PrimValues. "
                                            << "However, in expression " << call << ", the " << name
                                            << " argument " << expr << " has struct info " << sinfo;
  };
  check_tuple("axes", axes);
  check_tuple("begin", begin);
  check_tuple("end", end);
  if (strides.defined()) {
    check_tuple("strides", strides.value());
  }

  // Collect whatever is statically known about the data tensor.
  const auto* data_sinfo = data->struct_info_.as<TensorStructInfoNode>();
  DataType dtype = DataType::Void();
  Optional<VDevice> vdevice = NullOpt;
  int ndim = kUnknownNDim;
  if (data_sinfo) {
    dtype = data_sinfo->dtype;
    vdevice = data_sinfo->vdevice;
    ndim = data_sinfo->ndim;
  }

  // Try to compute the exact output shape; any NullOpt return below means
  // "not enough static information" and we fall back to dtype/ndim only.
  Optional<Expr> shape = [&]() -> Optional<Expr> {
    if (!data_sinfo) return NullOpt;
    if (!data_sinfo->shape) return NullOpt;

    auto opt_axes_tuple = UnpackTupleOfPrimValue<Integer>(axes);
    if (!opt_axes_tuple) return NullOpt;
    auto axes_tuple = opt_axes_tuple.value();

    auto opt_begin_tuple = UnpackTupleOfPrimValue(begin);
    if (!opt_begin_tuple) return NullOpt;
    auto begin_tuple = opt_begin_tuple.value();
    CHECK_EQ(axes_tuple.size(), begin_tuple.size())
        << "For operator " << call->op << ", "
        << "the number of axes provided must match the number of 'begin' indices. "
        << "However, there are " << axes_tuple.size() << " axes specified (" << axes_tuple
        << ") and " << begin_tuple.size() << " 'begin' indices specified (" << begin_tuple << ")";

    auto opt_end_tuple = UnpackTupleOfPrimValue(end);
    if (!opt_end_tuple) return NullOpt;
    auto end_tuple = opt_end_tuple.value();
    CHECK_EQ(axes_tuple.size(), end_tuple.size())
        << "For operator " << call->op << ", "
        << "the number of axes provided must match the number of 'end' indices. "
        << "However, there are " << axes_tuple.size() << " axes specified (" << axes_tuple
        << ") and " << end_tuple.size() << " 'end' indices specified (" << end_tuple << ")";

    // Missing strides default to 1 on every sliced axis.
    Array<PrimExpr> strides_tuple;
    if (strides.defined()) {
      auto opt_strides_tuple = UnpackTupleOfPrimValue(strides);
      if (!opt_strides_tuple) return NullOpt;
      strides_tuple = opt_strides_tuple.value();
    } else {
      strides_tuple = Array<PrimExpr>(axes_tuple.size(), IntImm(DataType::Int(64), 1));
    }
    CHECK_EQ(axes_tuple.size(), strides_tuple.size())
        << "For operator " << call->op << ", "
        << "when the optional 'strides' argument is provided, "
        << "the number of axes provided must match the number of strides provided. "
        << "However, there are " << axes_tuple.size() << " axes specified (" << axes_tuple
        << ") and " << strides_tuple.size() << " strides specified (" << strides_tuple << ")";

    auto opt_data_shape = data_sinfo->GetShape();
    if (axes_tuple.empty() && !opt_data_shape.defined()) {
      // Nothing is sliced; the symbolic shape expression passes through.
      return data_sinfo->shape.value();
    } else if (!opt_data_shape.defined()) {
      return NullOpt;
    }

    std::vector<int> normalized_axes = NormalizeAxes(call, ctx, data_sinfo->ndim, axes_tuple);
    auto attrs = call->attrs.as<StridedSliceAttrs>();
    // Guard before dereferencing `attrs->assume_inbound` below; the original
    // code would null-deref if the call carried the wrong attrs type.
    CHECK(attrs) << "Operator " << call->op << " requires StridedSliceAttrs, "
                 << "but received attrs of a different type in expression " << call;
    // The analyzer is loop-invariant; fetch it once.
    arith::Analyzer* analyzer = ctx->GetAnalyzer();

    Array<PrimExpr> output_shape = data_sinfo->GetShape().value();
    for (size_t i = 0; i < normalized_axes.size(); i++) {
      size_t axis = normalized_axes[i];
      PrimExpr input_dim = output_shape[axis];
      PrimExpr begin_i = begin_tuple[i];
      PrimExpr end_i = end_tuple[i];
      PrimExpr output_dim =
          topi::GetLength(begin_i, end_i, strides_tuple[i], input_dim, attrs->assume_inbound);
      // When inbound is assumed, simplify under 0 <= begin/end <= input_dim.
      std::optional<With<arith::ConstraintContext>> context;
      if (attrs->assume_inbound) {
        context.emplace(analyzer,
                        0 <= begin_i && begin_i <= input_dim && 0 <= end_i && end_i <= input_dim);
      }
      output_dim = analyzer->Simplify(output_dim);
      output_shape.Set(axis, output_dim);
    }
    return ShapeExpr(output_shape);
  }();

  if (shape.defined()) {
    return TensorStructInfo(shape.value(), dtype, vdevice);
  } else {
    return TensorStructInfo(dtype, ndim, vdevice);
  }
}