in src/relax/op/tensor/manipulate.cc [181:298]
StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
  if (call->args.size() != 1) {
    ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument");
  }
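  // Concat takes its inputs bundled as a single Tuple argument; unpack the
  // per-tensor struct info from it before inference.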
  Array<TensorStructInfo> tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]);
  if (tensor_sinfo.empty()) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "Concat op expects at least one tensor in the input Tuple. However, the "
                        "given input Tuple is empty.");
  }
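
  // When `axis` is absent the output is fixed to 1-D (the inputs are expected to
  // be 1-D; see the TODO in the loop below); with an explicit axis the output
  // ndim is inferred from the inputs instead.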
  const auto* attrs = call->attrs.as<ConcatAttrs>();
  int output_ndim = attrs->axis.has_value() ? kUnknownNDim : 1;
  DataType output_dtype = DataType::Void();
  Optional<VDevice> vdev = NullOpt;
  bool shape_unknown = false;
  bool is_void_dtype = false;
  bool vdevice_unknown = false;
  std::vector<Array<PrimExpr>> shape_values;
  shape_values.reserve(tensor_sinfo.size());
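
  // A single pass over the inputs folds together the output dtype, ndim,
  // virtual device, and any statically known shape values.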
  for (TensorStructInfo sinfo : tensor_sinfo) {
    // Update the output dtype.
    if (sinfo->dtype.is_void()) {
      is_void_dtype = true;
    } else if (output_dtype.is_void()) {
      output_dtype = sinfo->dtype;
    } else if (sinfo->dtype != output_dtype) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "Concat expects all input tensors to have the same dtype. However, the "
                          "input contains tensors with dtype "
                       << output_dtype << " and " << sinfo->dtype);
    }
    // Update the output ndim.
    // TODO(relax-team): revisit this for a better check on whether the input
    // tensor has ndim 1 when the input axis is undefined.
    if (output_ndim == kUnknownNDim) {
      output_ndim = sinfo->ndim;
    } else if (sinfo->ndim != kUnknownNDim && sinfo->ndim != output_ndim) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "Concat expects all input tensors to have the same ndim. However, the "
                          "input contains tensors with ndim "
                       << output_ndim << " and " << sinfo->ndim);
    }
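    // Note: unlike the dtype/ndim mismatches above, a vdevice conflict is not
    // fatal; it only drops the vdevice from the inferred result (see the
    // finalization after the loop).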
    // Update the virtual device.
    if (!vdevice_unknown) {
      if (sinfo->vdevice.defined()) {
        if (!vdev.defined()) {
          vdev = sinfo->vdevice.value();
        } else if (sinfo->vdevice.value()->target.defined()) {
          // Two inputs carry different concrete vdevices: mark the result unknown.
          if (sinfo->vdevice.value() != vdev) {
            vdevice_unknown = true;
          }
        }
      }
    }
    // Record the shape values for the best-effort shape check.
    const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
    if (shape_expr != nullptr) {
      shape_values.push_back(shape_expr->values);
      continue;
    }
    shape_unknown = true;
    if (!sinfo->shape.defined()) {
      continue;
    }
    // Keep the shape value around for the equality check.
    ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(sinfo->shape.value()->struct_info_);
    if (shape_sinfo->values.defined()) {
      shape_values.push_back(shape_sinfo->values.value());
    }
  }
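
  // Finalize the folded properties: any void-dtype input makes the output dtype
  // void, and any vdevice conflict erases the vdevice.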
  if (is_void_dtype) {
    output_dtype = DataType::Void();
  }
  if (vdevice_unknown) {
    vdev = NullOpt;
  }
  if (output_ndim == kUnknownNDim) {
    return tensor_sinfo.size() == 1 ? tensor_sinfo[0]
                                    : TensorStructInfo(output_dtype, output_ndim, vdev);
  }
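  // NormalizeAxis maps a (possibly negative) axis into [0, output_ndim) and
  // reports an error when it is out of range.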
  int axis =
      attrs->axis.has_value() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()) : 0;
  // If there is only one input tensor, no action is needed.
  if (tensor_sinfo.size() == 1) {
    return tensor_sinfo[0];
  }
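  // No input carried a statically known shape: return dtype/ndim-level info,
  // keeping the vdevice when it is consistent.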
  if (shape_values.empty()) {
    if (!vdevice_unknown) {
      return TensorStructInfo(output_dtype, output_ndim, vdev);
    }
    return TensorStructInfo(output_dtype, output_ndim);
  }
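  // Illustration (not from this file): concatenating tensors of shape (2, 3) and
  // (4, 3) on axis 0 should infer shape (6, 3); with a symbolic dim, shapes
  // (n, 3) and (4, 3) should infer (n + 4, 3).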
  // As long as there is a known shape value, do a best-effort check to ensure safety.
  Optional<Array<PrimExpr>> output_shape = CheckConcatOutputShape(call, ctx, shape_values, axis);
  if (shape_unknown || !output_shape.defined()) {
    if (!vdevice_unknown) {
      return TensorStructInfo(output_dtype, output_ndim, vdev);
    }
    return TensorStructInfo(output_dtype, output_ndim);
  } else {
    if (!vdevice_unknown) {
      return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev);
    }
    return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype);
  }
}