in src/relax/op/tensor/ternary.cc [30:114]
/*!
 * \brief Infer the output struct info of the ternary `ewise_fma` operator (a * b + c).
 *
 * Rules applied, in order:
 *  - Rank: every argument whose rank is known must have the same rank; the output takes that
 *    rank (or stays unknown when no argument's rank is known).
 *  - DType: every argument whose dtype is known must have the same dtype. The output dtype is that
 *    dtype when all three are known, and Void (unknown) otherwise.
 *  - VDevice: a VDevice without a target is treated as a placeholder and is only used when no
 *    argument carries a VDevice with a target. All targeted VDevices must be the same object;
 *    if they differ the output VDevice is left undefined (no diagnostic is raised).
 *  - Shape: when all three shapes are ShapeExpr, each dimension must be provably equal via the
 *    context's analyzer (otherwise a fatal diagnostic is reported). When the three arguments share
 *    the very same shape expression, it is reused. Otherwise only dtype/rank are returned.
 *
 * \param call The `ewise_fma` call being inferred.
 * \param ctx The block builder, used for diagnostics and its arithmetic analyzer.
 * \return The inferred TensorStructInfo of the result.
 */
StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) {
  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
  TensorStructInfo t1 = input_sinfo[0];
  TensorStructInfo t2 = input_sinfo[1];
  TensorStructInfo t3 = input_sinfo[2];

  // Rank: all known ranks must agree; unknown ranks do not constrain the result.
  int ndim = kUnknownNDim;
  for (const TensorStructInfo& sinfo : input_sinfo) {
    if (sinfo->IsUnknownNdim()) {
      continue;
    }
    if (ndim == kUnknownNDim) {
      ndim = sinfo->ndim;
    } else if (sinfo->ndim != ndim) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "The 3 arguments of EwiseFMA must have the same number of dimensions");
    }
  }

  // DType: check every known dtype against the others (mirroring the rank check above) so that a
  // conflict between two known dtypes is reported even when the third dtype is unknown.
  DataType known_dtype;
  bool has_known_dtype = false;
  bool has_unknown_dtype = false;
  for (const TensorStructInfo& sinfo : input_sinfo) {
    if (sinfo->IsUnknownDtype()) {
      has_unknown_dtype = true;
      continue;
    }
    if (!has_known_dtype) {
      known_dtype = sinfo->dtype;
      has_known_dtype = true;
    } else if (sinfo->dtype != known_dtype) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "Data types " << t1->dtype << ", " << t2->dtype << ", and " << t3->dtype
                       << " must be equal for EwiseFMA");
    }
  }
  DataType output_dtype = has_unknown_dtype ? DataType::Void() : known_dtype;

  // VDevice: prefer a VDevice that carries a target; a target-less VDevice only fills in when no
  // targeted one exists. This makes the result independent of argument order.
  VDevice vdev = VDevice();               // first VDevice with a defined target
  VDevice targetless_vdev = VDevice();    // first VDevice without a target
  bool vdevice_mismatch = false;
  for (const TensorStructInfo& sinfo : input_sinfo) {
    if (!sinfo->vdevice.defined()) {
      continue;
    }
    VDevice cur = sinfo->vdevice.value();
    if (!cur->target.defined()) {
      if (!targetless_vdev.defined()) {
        targetless_vdev = cur;
      }
    } else if (!vdev.defined()) {
      vdev = cur;
    } else if (cur != vdev) {
      // mismatch between two targeted VDevices: drop the VDevice from the output.
      vdevice_mismatch = true;
      break;
    }
  }
  if (vdevice_mismatch) {
    vdev = VDevice();
  } else if (!vdev.defined()) {
    vdev = targetless_vdev;
  }

  auto* s1 = t1->shape.as<ShapeExprNode>();
  auto* s2 = t2->shape.as<ShapeExprNode>();
  auto* s3 = t3->shape.as<ShapeExprNode>();
  arith::Analyzer* analyzer = ctx->GetAnalyzer();
  if (s1 && s2 && s3) {
    // All shapes are explicit, so ndim is known and equal for all three (checked above).
    Array<PrimExpr> output_shape;
    for (int i = 0; i < ndim; ++i) {
      PrimExpr dim1 = s1->values[i];
      PrimExpr dim2 = s2->values[i];
      PrimExpr dim3 = s3->values[i];
      if (analyzer->CanProveEqual(dim1, dim2) && analyzer->CanProveEqual(dim2, dim3)) {
        output_shape.push_back(dim1);
      } else {
        ctx->ReportFatal(Diagnostic::Error(call)
                         << "The 3 arguments of EwiseFMA must have the same shape");
      }
    }
    if (vdev.defined()) {
      return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdev);
    }
    return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
  } else if (t1->shape.defined() && t1->shape.same_as(t2->shape) && t1->shape.same_as(t3->shape)) {
    // The three arguments share one (non-ShapeExpr) shape object, e.g. the same shape variable.
    if (vdev.defined()) {
      return TensorStructInfo(t1->shape.value(), output_dtype, vdev);
    }
    return TensorStructInfo(t1->shape.value(), output_dtype);
  }
  // Shape cannot be determined; fall back to dtype and rank only.
  if (vdev.defined()) {
    return TensorStructInfo(output_dtype, ndim, vdev);
  }
  return TensorStructInfo(output_dtype, ndim);
}