in src/relax/op/tensor/manipulate.cc [1391:1510]
StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) {
  // Infer the output StructInfo of relax.stack: all inputs must share dtype
  // and rank; the result has one extra dimension (extent = number of inputs)
  // inserted at `axis` (default 0).
  if (call->args.size() != 1) {
    ctx->ReportFatal(Diagnostic::Error(call) << "Stack op should have 1 argument");
  }
  Array<TensorStructInfo> tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]);
  if (tensor_sinfo.empty()) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "Stack op expects at least one tensor in the input Tuple. "
                     << "However, the given input Tuple is empty.");
  }
  const auto* attrs = call->attrs.as<StackAttrs>();
  ICHECK(attrs != nullptr) << "Stack must have StackAttrs";
  // Reference input rank: the first input whose rank is known.
  // Fix: previously tensor_sinfo[0]->ndim was used unconditionally; when it
  // was kUnknownNDim (-1) the output rank was computed as 0 and every
  // rank-known input was spuriously rejected by the consistency check below.
  int input_ndim = kUnknownNDim;
  for (const TensorStructInfo& sinfo : tensor_sinfo) {
    if (sinfo->ndim != kUnknownNDim) {
      input_ndim = sinfo->ndim;
      break;
    }
  }
  // Stack adds one dimension; rank stays unknown if no input rank is known.
  int output_ndim = (input_ndim == kUnknownNDim) ? kUnknownNDim : input_ndim + 1;
  DataType output_dtype = DataType::Void();
  Optional<VDevice> vdev = NullOpt;
  bool shape_unknown = false;
  bool is_void_dtype = false;
  bool vdevice_unknown = false;
  std::vector<Array<PrimExpr>> shape_values;
  shape_values.reserve(tensor_sinfo.size());
  for (TensorStructInfo sinfo : tensor_sinfo) {
    // Dtype consistency: a Void input forces the output dtype to Void;
    // otherwise all known dtypes must agree.
    if (sinfo->dtype.is_void()) {
      is_void_dtype = true;
    } else if (output_dtype.is_void()) {
      output_dtype = sinfo->dtype;
    } else if (sinfo->dtype != output_dtype) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "Stack expects all input tensors to have the same dtype. "
                       << "Found " << output_dtype << " and " << sinfo->dtype);
    }
    // Rank consistency: only rank-known inputs participate; they must all
    // match the reference rank chosen above.
    if (sinfo->ndim != kUnknownNDim && sinfo->ndim != input_ndim) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "Stack expects all input tensors to have same ndim. "
                       << "Found " << input_ndim << " and " << sinfo->ndim);
    }
    // Virtual-device consistency: a single mismatch makes the vdevice unknown.
    if (!vdevice_unknown) {
      if (sinfo->vdevice.defined()) {
        if (!vdev.defined()) {
          vdev = sinfo->vdevice.value();
        } else if (sinfo->vdevice.value() != vdev) {
          vdevice_unknown = true;
        }
      }
    }
    // Collect symbolic shape values where available, either directly from a
    // ShapeExpr or from a ShapeStructInfo carrying known values.
    const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
    if (shape_expr != nullptr) {
      shape_values.push_back(shape_expr->values);
      continue;
    }
    shape_unknown = true;
    if (!sinfo->shape.defined()) continue;
    ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(sinfo->shape.value()->struct_info_);
    if (shape_sinfo->values.defined()) {
      shape_values.push_back(shape_sinfo->values.value());
    }
  }
  if (is_void_dtype) output_dtype = DataType::Void();
  if (vdevice_unknown) vdev = NullOpt;
  // With no rank information we can neither normalize the axis nor build a
  // shape; report an unknown-rank tensor.
  if (output_ndim == kUnknownNDim) {
    if (!vdevice_unknown) {
      return TensorStructInfo(output_dtype, output_ndim, vdev);
    }
    return TensorStructInfo(output_dtype, output_ndim);
  }
  // Normalize axis (default to 0 if not specified); the valid range includes
  // the newly inserted dimension, hence output_ndim rather than input_ndim.
  int axis =
      attrs->axis.defined() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()->value) : 0;
  // Single tensor case: output shape is the input shape with an extent-1 axis
  // inserted at `axis`.
  if (tensor_sinfo.size() == 1) {
    if (shape_values.empty()) {
      if (!vdevice_unknown) {
        return TensorStructInfo(output_dtype, output_ndim, vdev);
      }
      return TensorStructInfo(output_dtype, output_ndim);
    }
    Array<PrimExpr> output_shape;
    for (int i = 0; i < axis; ++i) {
      output_shape.push_back(shape_values[0][i]);
    }
    output_shape.push_back(1);  // Stack size 1
    for (int i = axis; i < static_cast<int>(shape_values[0].size()); ++i) {
      output_shape.push_back(shape_values[0][i]);
    }
    if (!vdevice_unknown) {
      return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdev);
    }
    return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
  }
  // Multiple tensors case: shapes must agree pairwise; the helper inserts the
  // new stack dimension.
  if (shape_values.empty()) {
    if (!vdevice_unknown) {
      return TensorStructInfo(output_dtype, output_ndim, vdev);
    }
    return TensorStructInfo(output_dtype, output_ndim);
  }
  Optional<Array<PrimExpr>> output_shape = CheckStackOutputShape(call, ctx, shape_values, axis);
  if (shape_unknown || !output_shape.defined()) {
    // Some input shape was unknown (or inconsistent): fall back to rank-only.
    if (!vdevice_unknown) {
      return TensorStructInfo(output_dtype, output_ndim, vdev);
    }
    return TensorStructInfo(output_dtype, output_ndim);
  } else {
    if (!vdevice_unknown) {
      return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev);
    }
    return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype);
  }
}