StructInfo InferStructInfoConcat()

in src/relax/op/tensor/manipulate.cc [181:298]


// Infers the StructInfo of the result of a concat call.
//
// The call must carry exactly one argument, from which the per-tensor struct
// infos are extracted via GetTensorStructInfoFromTuple; an empty list is
// rejected. The function then merges dtype, ndim, virtual device and (when
// available) shape information across all inputs:
//   * dtype:   all non-void dtypes must agree; if any input dtype is void the
//              result dtype is void.
//   * ndim:    all known ndims must agree. When the attrs carry no axis, the
//              expected ndim starts out as 1.
//   * vdevice: the first defined vdevice is kept; a later, different vdevice
//              with a defined target makes the result carry no vdevice.
//   * shape:   only produced when every input has a ShapeExpr shape and
//              CheckConcatOutputShape yields a value.
//
// \param call The concat call being analysed.
// \param ctx  The block builder used to report fatal diagnostics.
// \return The inferred struct info of the concat result.
StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
  if (call->args.size() != 1) {
    ctx->ReportFatal(Diagnostic::Error(call) << "Concat op should have 1 argument");
  }

  Array<TensorStructInfo> inputs = GetTensorStructInfoFromTuple(call, ctx, call->args[0]);
  if (inputs.empty()) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "Concat op expects at least one tensor in the input Tuple. However, the "
                        "given input Tuple is empty.");
  }

  const auto* concat_attrs = call->attrs.as<ConcatAttrs>();
  const bool has_axis = concat_attrs->axis.has_value();

  // Accumulated result properties. Without an explicit axis the expected ndim
  // is fixed to 1 up front; otherwise it is learned from the inputs.
  int result_ndim = has_axis ? kUnknownNDim : 1;
  DataType result_dtype = DataType::Void();
  Optional<VDevice> result_vdevice = NullOpt;

  // Flags recording what was observed while scanning the inputs.
  bool saw_void_dtype = false;
  bool saw_non_literal_shape = false;
  bool vdevice_conflict = false;

  // Every shape value list we could recover, used for a best-effort check.
  std::vector<Array<PrimExpr>> known_shapes;
  known_shapes.reserve(inputs.size());

  for (const TensorStructInfo& input : inputs) {
    // --- dtype ---
    const DataType& in_dtype = input->dtype;
    if (in_dtype.is_void()) {
      saw_void_dtype = true;
    } else if (result_dtype.is_void()) {
      result_dtype = in_dtype;
    } else if (in_dtype != result_dtype) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "Concat expects all input tensors to have the same dtype. However, the "
                          "input contains tensors with dtype "
                       << result_dtype << " and " << in_dtype);
    }

    // --- ndim ---
    // Todo(relax-team): revisit here for better check on if the input tensor has
    // ndim 1 when the input axis is undefined.
    const int in_ndim = input->ndim;
    if (result_ndim == kUnknownNDim) {
      result_ndim = in_ndim;
    } else if (in_ndim != kUnknownNDim && in_ndim != result_ndim) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "Concat expects all input tensors to have same ndim. However, the "
                          "input contains tensors with ndim "
                       << result_ndim << " and " << in_ndim);
    }

    // --- virtual device ---
    // Once a conflict has been seen, later inputs are no longer inspected.
    if (!vdevice_conflict && input->vdevice.defined()) {
      if (!result_vdevice.defined()) {
        result_vdevice = input->vdevice.value();
      } else if (input->vdevice.value()->target.defined() &&
                 input->vdevice.value() != result_vdevice) {
        vdevice_conflict = true;
      }
    }

    // --- shape ---
    if (const auto* literal_shape = input->shape.as<ShapeExprNode>()) {
      known_shapes.push_back(literal_shape->values);
    } else {
      saw_non_literal_shape = true;
      if (input->shape.defined()) {
        // The shape is some other expression; still record its values, if its
        // struct info carries them, so they take part in the equality check.
        ShapeStructInfo shape_info =
            Downcast<ShapeStructInfo>(input->shape.value()->struct_info_);
        if (shape_info->values.defined()) {
          known_shapes.push_back(shape_info->values.value());
        }
      }
    }
  }

  if (saw_void_dtype) {
    result_dtype = DataType::Void();
  }
  if (vdevice_conflict) {
    result_vdevice = NullOpt;
  }

  const bool single_input = inputs.size() == 1;

  if (result_ndim == kUnknownNDim) {
    if (single_input) {
      return inputs[0];
    }
    return TensorStructInfo(result_dtype, result_ndim, result_vdevice);
  }

  // The axis is normalized before the single-input shortcut below, so an
  // invalid axis is still handed to NormalizeAxis in that case.
  int axis = 0;
  if (has_axis) {
    axis = NormalizeAxis(call, ctx, result_ndim, concat_attrs->axis.value());
  }

  // A single input tensor is passed through unchanged.
  if (single_input) {
    return inputs[0];
  }

  // Builds the result used whenever no concrete output shape can be given.
  auto make_without_shape = [&]() -> TensorStructInfo {
    if (!vdevice_conflict) {
      return TensorStructInfo(result_dtype, result_ndim, result_vdevice);
    }
    return TensorStructInfo(result_dtype, result_ndim);
  };

  if (known_shapes.empty()) {
    return make_without_shape();
  }

  // Whenever at least one shape value list is known, run the best-effort
  // consistency check on what we have.
  Optional<Array<PrimExpr>> concat_shape = CheckConcatOutputShape(call, ctx, known_shapes, axis);

  if (saw_non_literal_shape || !concat_shape.defined()) {
    return make_without_shape();
  }

  if (!vdevice_conflict) {
    return TensorStructInfo(ShapeExpr(concat_shape.value()), result_dtype, result_vdevice);
  }
  return TensorStructInfo(ShapeExpr(concat_shape.value()), result_dtype);
}