in src/relax/op/tensor/manipulate.cc [486:613]
// Infers the StructInfo of the result of an index_tensor call.
//
// call->args[0] is the data tensor and call->args[1] is a tuple of index tensors.
// The result keeps the dtype and vdevice of the data tensor. Its leading dimensions are
// the right-aligned broadcast of all index shapes; they are followed by the data
// dimensions at positions [n_indices, data.ndim), i.e. those not consumed by an index.
//
// A fatal diagnostic is reported through ctx when:
//   - the call does not have exactly 2 arguments;
//   - the index tuple is empty;
//   - an index tensor has a known dtype for which dtype.is_int() is false;
//   - the data rank is known and there are more index tensors than data dimensions;
//   - two index shapes carry different constant, non-1 extents on an aligned axis.
//
// Returns:
//   - a TensorStructInfo with a full shape when the broadcast index shape could be
//     derived and the data shape is a ShapeExpr;
//   - otherwise a TensorStructInfo with only ndim, where ndim is known whenever the data
//     rank and the rank of every index tensor are known, and kUnknownNDim if not.
StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) {
  if (call->args.size() != 2) {
    ctx->ReportFatal(Diagnostic::Error(call) << "Index.Tensor op should have 2 arguments");
  }

  TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
  Array<TensorStructInfo> indices_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]);

  if (indices_sinfo.empty()) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "index_tensor expects a non‑empty tuple of index tensors");
  }

  DataType output_dtype = data_sinfo->dtype;
  int n_indices = static_cast<int>(indices_sinfo.size());
  Optional<VDevice> vdev = data_sinfo->vdevice;

  // Indices must be integers (an unknown dtype is tolerated).
  for (int i = 0; i < n_indices; ++i) {
    const auto& s = indices_sinfo[i];
    if (!s->IsUnknownDtype() && !s->dtype.is_int()) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "index_tensor requires every index tensor to have an integer dtype; "
                       << "index " << i << " has dtype " << s->dtype);
    }
  }

  // Count of indices must be less than or equal to data.ndim.
  if (!data_sinfo->IsUnknownNdim() && n_indices > data_sinfo->ndim) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "index_tensor received " << n_indices
                     << " index tensors, but data has only " << data_sinfo->ndim << " dimensions");
  }

  arith::Analyzer* analyzer = ctx->GetAnalyzer();

  // Collect per-index shape information.
  //   all_index_have_shape_value: every index exposes a concrete ShapeExpr.
  //   all_index_have_ndim:        every index has a known rank (via ShapeExpr or ndim).
  //   max_index_ndim:             largest rank seen among indices with a known rank.
  bool all_index_have_shape_value = true;
  bool all_index_have_ndim = true;
  std::vector<Array<PrimExpr>> index_shapes;
  index_shapes.reserve(indices_sinfo.size());
  int max_index_ndim = 0;

  for (const auto& s : indices_sinfo) {
    const auto* shp = s->shape.as<ShapeExprNode>();
    if (shp) {
      index_shapes.push_back(shp->values);
      max_index_ndim = std::max(max_index_ndim, static_cast<int>(shp->values.size()));
    } else {
      all_index_have_shape_value = false;
    }
    if (!s->IsUnknownNdim()) {
      max_index_ndim = std::max(max_index_ndim, s->ndim);
    } else if (!shp) {
      // Neither a ShapeExpr nor an ndim: the broadcast rank cannot be determined.
      all_index_have_ndim = false;
    }
  }

  Optional<Array<PrimExpr>> broadcast_shape;
  bool shape_unknown = !all_index_have_shape_value;

  if (all_index_have_shape_value) {
    // Initialise the broadcast result with 1's; each index shape is then folded in.
    Array<PrimExpr> out_shape;
    for (int i = 0; i < max_index_ndim; ++i) {
      out_shape.push_back(IntImm(DataType::Int(64), 1));
    }

    for (const auto& ishape : index_shapes) {
      int cur_ndim = static_cast<int>(ishape.size());
      for (int axis = 0; axis < max_index_ndim; ++axis) {
        int lhs_axis = max_index_ndim - 1 - axis;  // aligned from the right
        int rhs_axis = cur_ndim - 1 - axis;
        if (rhs_axis < 0) break;  // shorter rank: remaining leading axes are untouched

        PrimExpr lhs_dim = out_shape[lhs_axis];
        PrimExpr rhs_dim = ishape[rhs_axis];

        const auto* lhs_int = lhs_dim.as<IntImmNode>();
        const auto* rhs_int = rhs_dim.as<IntImmNode>();

        // Case 1: current broadcast slot is 1 -> always replace.
        if (lhs_int && lhs_int->value == 1) {
          out_shape.Set(lhs_axis, rhs_dim);
          continue;
        }
        // Case 2: rhs is 1 -> keep lhs_dim unchanged.
        if (rhs_int && rhs_int->value == 1) {
          continue;
        }
        // Both are non-one constants: they must be equal.
        if (lhs_int && rhs_int && lhs_int->value != rhs_int->value) {
          ctx->ReportFatal(Diagnostic::Error(call)
                           << "index_tensor: cannot broadcast index shapes. Mismatch at axis "
                           << lhs_axis << ": " << lhs_dim << " vs " << rhs_dim);
        }
        // Give up on a concrete shape if the extents are not provably equal.
        if (!analyzer->CanProveEqual(lhs_dim, rhs_dim)) {
          shape_unknown = true;
          break;
        }
      }
      if (shape_unknown) break;
    }

    if (!shape_unknown) broadcast_shape = out_shape;
  }

  // Count of dimensions in the output.
  // The rank of a broadcast is the maximum rank of its operands, so it is known as soon
  // as every index rank is known, even when the concrete broadcast shape is not. When
  // broadcast_shape is defined it has exactly max_index_ndim entries, so this single
  // formula covers that case as well.
  int out_ndim = kUnknownNDim;
  if (!data_sinfo->IsUnknownNdim() && all_index_have_ndim) {
    int tail_ndim = data_sinfo->ndim - n_indices;
    out_ndim = max_index_ndim + tail_ndim;
  }

  // Derive the full output shape when both the broadcast shape and data shape are known.
  if (broadcast_shape.defined()) {
    const auto* data_shape_expr = data_sinfo->shape.as<ShapeExprNode>();
    if (data_shape_expr) {
      Array<PrimExpr> result_shape = broadcast_shape.value();
      // Append the data dimensions that were not consumed by an index tensor.
      for (int i = n_indices; i < data_sinfo->ndim; ++i) {
        result_shape.push_back(data_shape_expr->values[i]);
      }
      return TensorStructInfo(ShapeExpr(result_shape), output_dtype, vdev);
    }
  }

  // Unknown output shape: fall back to rank-only information.
  return TensorStructInfo(output_dtype, out_ndim, vdev);
}