StructInfo InferStructInfoStack()

in src/relax/op/tensor/manipulate.cc [1391:1510]


/*!
 * \brief Infer the output struct info of the stack operator.
 *
 * The single call argument is a tuple of tensors. All inputs must agree on dtype (when known)
 * and on ndim (when known). The output has one more dimension than the inputs, inserted at
 * `StackAttrs::axis` (0 when the attribute is not set).
 *
 * Information that cannot be determined degrades gracefully:
 *  - any void-dtype input makes the output dtype void;
 *  - conflicting virtual devices make the output vdevice unknown;
 *  - if no input has a known ndim, the output ndim is unknown;
 *  - if any input lacks a concrete ShapeExpr (multi-input case), only the ndim is reported.
 *
 * \param call The stack call being inferred.
 * \param ctx The block builder used for diagnostics.
 * \return The inferred TensorStructInfo of the stacked result.
 */
StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) {
  if (call->args.size() != 1) {
    ctx->ReportFatal(Diagnostic::Error(call) << "Stack op should have 1 argument");
  }

  Array<TensorStructInfo> tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]);
  if (tensor_sinfo.empty()) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "Stack op expects at least one tensor in the input Tuple. "
                     << "However, the given input Tuple is empty.");
  }

  const auto* attrs = call->attrs.as<StackAttrs>();
  ICHECK(attrs != nullptr) << "Stack must have StackAttrs";

  // Rank shared by all inputs. It is taken from the first input whose ndim is known (not
  // necessarily input 0). A leading unknown-ndim input must neither produce a bogus
  // `kUnknownNDim + 1 == 0` output rank nor trigger false mismatch errors against later inputs.
  int input_ndim = kUnknownNDim;
  DataType output_dtype = DataType::Void();
  Optional<VDevice> vdev = NullOpt;
  bool shape_unknown = false;
  bool is_void_dtype = false;
  bool vdevice_unknown = false;
  std::vector<Array<PrimExpr>> shape_values;
  shape_values.reserve(tensor_sinfo.size());

  for (TensorStructInfo sinfo : tensor_sinfo) {
    // Check dtype consistency; a void dtype anywhere makes the output dtype void (see below).
    if (sinfo->dtype.is_void()) {
      is_void_dtype = true;
    } else if (output_dtype.is_void()) {
      output_dtype = sinfo->dtype;
    } else if (sinfo->dtype != output_dtype) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "Stack expects all input tensors to have the same dtype. "
                       << "Found " << output_dtype << " and " << sinfo->dtype);
    }

    // Check ndim consistency among the inputs whose ndim is known.
    if (sinfo->ndim != kUnknownNDim) {
      if (input_ndim == kUnknownNDim) {
        input_ndim = sinfo->ndim;
      } else if (sinfo->ndim != input_ndim) {
        ctx->ReportFatal(Diagnostic::Error(call)
                         << "Stack expects all input tensors to have same ndim. "
                         << "Found " << input_ndim << " and " << sinfo->ndim);
      }
    }

    // Check virtual device consistency; any conflict makes the output vdevice unknown.
    if (!vdevice_unknown && sinfo->vdevice.defined()) {
      if (!vdev.defined()) {
        vdev = sinfo->vdevice.value();
      } else if (sinfo->vdevice.value() != vdev) {
        vdevice_unknown = true;
      }
    }

    // Collect shape values: concrete ShapeExprs first; otherwise fall back to the values
    // recorded in the shape's ShapeStructInfo (kept for the best-effort check below).
    const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
    if (shape_expr != nullptr) {
      shape_values.push_back(shape_expr->values);
      continue;
    }
    shape_unknown = true;

    if (!sinfo->shape.defined()) continue;
    ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(sinfo->shape.value()->struct_info_);
    if (shape_sinfo->values.defined()) {
      shape_values.push_back(shape_sinfo->values.value());
    }
  }

  if (is_void_dtype) output_dtype = DataType::Void();
  // After this reset `vdev` already encodes "unknown", so every return below can pass it as-is.
  if (vdevice_unknown) vdev = NullOpt;

  // Without any known input rank, neither the output rank nor the axis can be resolved.
  // Any defined shape implies a known ndim, so there is no shape information to lose here.
  if (input_ndim == kUnknownNDim) {
    return TensorStructInfo(output_dtype, kUnknownNDim, vdev);
  }
  const int output_ndim = input_ndim + 1;  // Stack adds one dimension.

  // Normalize the axis against the *output* rank (default to 0 if not specified).
  int axis =
      attrs->axis.defined() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()->value) : 0;

  if (shape_values.empty()) {
    return TensorStructInfo(output_dtype, output_ndim, vdev);
  }

  // Single tensor case: insert a size-1 dimension at `axis`.
  if (tensor_sinfo.size() == 1) {
    const Array<PrimExpr>& input_shape = shape_values[0];
    Array<PrimExpr> output_shape;
    for (int i = 0; i < axis; ++i) {
      output_shape.push_back(input_shape[i]);
    }
    output_shape.push_back(1);  // Stack size 1
    for (int i = axis; i < static_cast<int>(input_shape.size()); ++i) {
      output_shape.push_back(input_shape[i]);
    }
    return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdev);
  }

  // Multiple tensors case. CheckStackOutputShape runs even when some shapes are unknown, so
  // the known shapes still get a best-effort consistency check. Its output is used only when
  // every input had a concrete ShapeExpr. NOTE(review): assumed to report mismatches itself
  // and to return NullOpt when the stacked shape cannot be proven — confirm against its definition.
  Optional<Array<PrimExpr>> output_shape = CheckStackOutputShape(call, ctx, shape_values, axis);
  if (shape_unknown || !output_shape.defined()) {
    return TensorStructInfo(output_dtype, output_ndim, vdev);
  }
  return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev);
}