StructInfo InferStructInfoIndexTensor()

in src/relax/op/tensor/manipulate.cc [486:613]


/*!
 * \brief Infer the output struct info of `index_tensor` (advanced indexing).
 *
 * Expects exactly two arguments: the data tensor and a tuple of index tensors.
 * The index tensors are broadcast against one another, with shapes aligned from
 * the right. The output shape is the broadcast index shape followed by the data
 * dimensions that were not indexed:
 *
 *   out.shape = broadcast(indices...).shape ++ data.shape[len(indices):]
 *
 * \param call The call being inferred.
 * \param ctx  Block builder used for diagnostics and for its arithmetic analyzer.
 * \return A TensorStructInfo with data's dtype and vdevice. It carries a concrete
 *         ShapeExpr only when every index shape is a ShapeExpr that broadcasts
 *         provably and data's shape is a ShapeExpr. Otherwise only ndim is reported,
 *         and it is known whenever data's ndim and every index's ndim are known.
 */
StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) {
  if (call->args.size() != 2) {
    ctx->ReportFatal(Diagnostic::Error(call) << "Index.Tensor op should have 2 arguments");
  }

  TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
  Array<TensorStructInfo> indices_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]);

  if (indices_sinfo.empty()) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "index_tensor expects a non‑empty tuple of index tensors");
  }

  DataType output_dtype = data_sinfo->dtype;
  int n_indices = static_cast<int>(indices_sinfo.size());
  Optional<VDevice> vdev = data_sinfo->vdevice;

  // Indices must be integers; an index whose dtype is not yet known is not rejected here.
  for (int i = 0; i < n_indices; ++i) {
    const auto& s = indices_sinfo[i];
    if (!s->IsUnknownDtype() && !s->dtype.is_int()) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "index_tensor requires every index tensor to have an integer dtype; "
                       << "index " << i << " has dtype " << s->dtype);
    }
  }

  // Count of indices must be less than or equal to data.ndim
  if (!data_sinfo->IsUnknownNdim() && n_indices > data_sinfo->ndim) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "index_tensor received " << n_indices
                     << " index tensors, but data has only " << data_sinfo->ndim << " dimensions");
  }

  arith::Analyzer* analyzer = ctx->GetAnalyzer();
  // A concrete broadcast *shape* needs every index to carry a ShapeExpr. The
  // broadcast *rank* only needs every index rank, because it is the max of the
  // index ranks regardless of the individual dim values.
  bool all_index_have_shape_value = true;
  bool all_index_have_known_ndim = true;
  std::vector<Array<PrimExpr>> index_shapes;
  int max_index_ndim = 0;

  for (const auto& s : indices_sinfo) {
    const auto* shp = s->shape.as<ShapeExprNode>();
    if (shp) {
      index_shapes.push_back(shp->values);
      max_index_ndim = std::max(max_index_ndim, static_cast<int>(shp->values.size()));
    } else {
      all_index_have_shape_value = false;
    }
    if (!s->IsUnknownNdim()) {
      max_index_ndim = std::max(max_index_ndim, s->ndim);
    } else if (!shp) {
      // Neither a shape value nor a rank is available for this index.
      all_index_have_known_ndim = false;
    }
  }

  Optional<Array<PrimExpr>> broadcast_shape;

  if (all_index_have_shape_value) {
    // Initialise the broadcast result with 1s; a 1 broadcasts against anything.
    Array<PrimExpr> out_shape;
    for (int i = 0; i < max_index_ndim; ++i) {
      out_shape.push_back(IntImm(DataType::Int(64), 1));
    }

    bool broadcast_ok = true;
    for (const auto& ishape : index_shapes) {
      int cur_ndim = static_cast<int>(ishape.size());
      // Walk both shapes aligned from the right; cur_ndim <= max_index_ndim.
      for (int k = 1; k <= cur_ndim; ++k) {
        int lhs_axis = max_index_ndim - k;
        int rhs_axis = cur_ndim - k;

        PrimExpr lhs_dim = out_shape[lhs_axis];
        PrimExpr rhs_dim = ishape[rhs_axis];

        const auto* lhs_int = lhs_dim.as<IntImmNode>();
        const auto* rhs_int = rhs_dim.as<IntImmNode>();

        // Case 1: current broadcast slot is 1 -> always replace
        if (lhs_int && lhs_int->value == 1) {
          out_shape.Set(lhs_axis, rhs_dim);
          continue;
        }
        // Case 2: rhs is 1 -> keep lhs_dim unchanged
        if (rhs_int && rhs_int->value == 1) {
          continue;
        }
        // Both are non-one constants: must be equal
        if (lhs_int && rhs_int && lhs_int->value != rhs_int->value) {
          ctx->ReportFatal(Diagnostic::Error(call)
                           << "index_tensor: cannot broadcast index shapes. Mismatch at axis "
                           << lhs_axis << ": " << lhs_dim << " vs " << rhs_dim);
        }
        // Give up on a concrete shape if the dims are not provably equal
        // (the broadcast rank is still known).
        if (!analyzer->CanProveEqual(lhs_dim, rhs_dim)) {
          broadcast_ok = false;
          break;
        }
      }
      if (!broadcast_ok) break;
    }

    if (broadcast_ok) broadcast_shape = out_shape;
  }

  // Count of dimensions in output: broadcast rank + number of un-indexed data dims.
  // Previously this fell back to unknown whenever a concrete broadcast shape was
  // unavailable, even though every rank was known.
  int out_ndim = kUnknownNDim;
  if (!data_sinfo->IsUnknownNdim() && all_index_have_known_ndim) {
    int tail_ndim = data_sinfo->ndim - n_indices;
    out_ndim = max_index_ndim + tail_ndim;
  }

  // Derive a concrete output shape when both the broadcast and the data shape are known.
  if (broadcast_shape.defined()) {
    const auto* data_shape_expr = data_sinfo->shape.as<ShapeExprNode>();
    if (data_shape_expr) {
      Array<PrimExpr> result_shape = broadcast_shape.value();
      int data_ndim = static_cast<int>(data_shape_expr->values.size());
      for (int i = n_indices; i < data_ndim; ++i) {
        result_shape.push_back(data_shape_expr->values[i]);
      }
      return TensorStructInfo(ShapeExpr(result_shape), output_dtype, vdev);
    }
  }

  // Unknown output shape (ndim may still be known)
  return TensorStructInfo(output_dtype, out_ndim, vdev);
}