StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx)

Defined in src/relax/op/tensor/index.cc, lines 264-424.


/*!
 * \brief Infer the output StructInfo of a `relax.strided_slice` call.
 *
 * The call takes either four arguments `(data, axes, begin, end)` or
 * five arguments `(data, axes, begin, end, strides)`.  `data` must be a
 * tensor.  Each remaining argument must be a tuple whose fields are int64
 * PrimValues; an argument with ObjectStructInfo is also accepted without
 * further checking.
 *
 * If `data` has a known shape and every index tuple can be unpacked into
 * PrimExprs, the output shape is computed axis-by-axis with
 * `topi::GetLength` and simplified with the block builder's analyzer.
 * Otherwise, only the dtype, ndim and vdevice of `data` are propagated.
 *
 * \param call The `relax.strided_slice` call being inferred.
 * \param ctx The block builder; it supplies axis normalization context and
 *            the arithmetic analyzer.
 * \return The inferred TensorStructInfo of the result.
 */
StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx) {
  size_t n_args = call->args.size();
  CHECK(4 <= n_args && n_args <= 5)
      << "Operator " << call->op << " accepts either four arguments (data, axes, begin, end) "
      << "or five arguments (data, axes, begin, end, strides), "
      << "but received " << n_args << " in expression " << call;

  Expr data = call->args[0];
  Expr axes = call->args[1];
  Expr begin = call->args[2];
  Expr end = call->args[3];
  // `strides` is optional.  When it is absent, a stride of 1 is used for
  // every sliced axis (see the shape computation below).
  Optional<Expr> strides = NullOpt;
  if (n_args > 4) {
    strides = call->args[4];
  }

  CHECK(IsBaseOf(relax::TensorStructInfo(DataType::Void(), kUnknownNDim), GetStructInfo(data)))
      << "Operator " << call->op << " requires the first argument to be a tensor.  "
      << "However, in expression " << call << ", the first argument " << data << " has struct info "
      << GetStructInfo(data);

  // TODO(Lunderberg): Implement this check using `IsBaseOf`.  Doing
  // so will require a way to represent a `relax::TupleStructInfo` of
  // unknown length, where each element has the same `StructInfo`.
  auto is_base_of_tuple_of_int64 = [&](const StructInfo& sinfo) -> bool {
    if (sinfo.as<ObjectStructInfoNode>()) {
      return true;
    }

    const auto* tuple = sinfo.as<TupleStructInfoNode>();
    if (!tuple) return false;

    return std::all_of(tuple->fields.begin(), tuple->fields.end(), [](const StructInfo& field) {
      return IsBaseOf(relax::PrimStructInfo(DataType::Int(64)), field);
    });
  };
  auto check_tuple = [&](const char* name, const Expr& expr) {
    auto sinfo = GetStructInfo(expr);

    CHECK(is_base_of_tuple_of_int64(sinfo)) << "Operator " << call->op << " requires the " << name
                                            << " argument to be a tuple of int64 PrimValues.  "
                                            << "However, in expression " << call << ", the " << name
                                            << " argument " << expr << " has struct info " << sinfo;
  };
  check_tuple("axes", axes);
  check_tuple("begin", begin);
  check_tuple("end", end);
  if (strides.defined()) {
    check_tuple("strides", strides.value());
  }

  const auto* data_sinfo = data->struct_info_.as<TensorStructInfoNode>();

  DataType dtype = DataType::Void();
  Optional<VDevice> vdevice = NullOpt;
  int ndim = kUnknownNDim;
  if (data_sinfo) {
    dtype = data_sinfo->dtype;
    vdevice = data_sinfo->vdevice;
    ndim = data_sinfo->ndim;
  }

  // Compute the output shape where possible; NullOpt means "shape unknown".
  Optional<Expr> shape = [&]() -> Optional<Expr> {
    if (!data_sinfo) return NullOpt;
    if (!data_sinfo->shape) return NullOpt;

    auto opt_axes_tuple = UnpackTupleOfPrimValue<Integer>(axes);
    if (!opt_axes_tuple) return NullOpt;
    auto axes_tuple = opt_axes_tuple.value();

    auto opt_begin_tuple = UnpackTupleOfPrimValue(begin);
    if (!opt_begin_tuple) return NullOpt;
    auto begin_tuple = opt_begin_tuple.value();

    CHECK_EQ(axes_tuple.size(), begin_tuple.size())
        << "For operator " << call->op << ", "
        << "the number of axes provided must match the number of 'begin' indices.  "
        << "However, there are " << axes_tuple.size() << " axes specified (" << axes_tuple
        << ") and " << begin_tuple.size() << " 'begin' indices specified (" << begin_tuple << ")";

    auto opt_end_tuple = UnpackTupleOfPrimValue(end);
    if (!opt_end_tuple) return NullOpt;
    auto end_tuple = opt_end_tuple.value();

    CHECK_EQ(axes_tuple.size(), end_tuple.size())
        << "For operator " << call->op << ", "
        << "the number of axes provided must match the number of 'end' indices.  "
        << "However, there are " << axes_tuple.size() << " axes specified (" << axes_tuple
        << ") and " << end_tuple.size() << " 'end' indices specified (" << end_tuple << ")";

    Array<PrimExpr> strides_tuple;
    if (strides.defined()) {
      auto opt_strides_tuple = UnpackTupleOfPrimValue(strides);
      if (!opt_strides_tuple) return NullOpt;

      strides_tuple = opt_strides_tuple.value();
    } else {
      // Default to unit stride on every sliced axis.
      strides_tuple = Array<PrimExpr>(axes_tuple.size(), IntImm(DataType::Int(64), 1));
    }

    CHECK_EQ(axes_tuple.size(), strides_tuple.size())
        << "For operator " << call->op << ", "
        << "when the optional 'strides' argument is provided, "
        << "the number of axes provided must match the number of strides provided.  "
        << "However, there are " << axes_tuple.size() << " axes specified (" << axes_tuple
        << ") and " << strides_tuple.size() << " strides specified (" << strides_tuple << ")";

    auto opt_data_shape = data_sinfo->GetShape();
    if (!opt_data_shape.defined()) {
      // `data` has a shape, but its individual dimensions are not available
      // (presumably it is not a literal ShapeExpr).  A slice over zero axes
      // leaves the shape untouched; otherwise the output shape is unknown.
      if (axes_tuple.empty()) return data_sinfo->shape.value();
      return NullOpt;
    }

    std::vector<int> normalized_axes = NormalizeAxes(call, ctx, data_sinfo->ndim, axes_tuple);
    const auto* attrs = call->attrs.as<StridedSliceAttrs>();
    CHECK(attrs) << "Operator " << call->op << " requires StridedSliceAttrs, "
                 << "but expression " << call << " does not provide them";

    arith::Analyzer* analyzer = ctx->GetAnalyzer();
    Array<PrimExpr> output_shape = opt_data_shape.value();
    for (size_t i = 0; i < normalized_axes.size(); i++) {
      size_t axis = normalized_axes[i];
      PrimExpr input_dim = output_shape[axis];
      PrimExpr begin_index = begin_tuple[i];
      PrimExpr end_index = end_tuple[i];

      PrimExpr output_dim = topi::GetLength(begin_index, end_index, strides_tuple[i], input_dim,
                                            attrs->assume_inbound);

      // When `assume_inbound` is set, make the in-bounds assumption on
      // begin/end visible to the analyzer for the duration of the simplify.
      std::optional<With<arith::ConstraintContext>> context;
      if (attrs->assume_inbound) {
        context.emplace(analyzer, 0 <= begin_index && begin_index <= input_dim &&
                                      0 <= end_index && end_index <= input_dim);
      }

      output_dim = analyzer->Simplify(output_dim);

      output_shape.Set(axis, output_dim);
    }
    return ShapeExpr(output_shape);
  }();

  if (shape.defined()) {
    return TensorStructInfo(shape.value(), dtype, vdevice);
  } else {
    return TensorStructInfo(dtype, ndim, vdevice);
  }
}