StructInfo InferStructInfoEwiseFMA()

in src/relax/op/tensor/ternary.cc [30:114]


/*!
 * \brief Infer the result struct info of EwiseFMA from its three tensor operands.
 *
 * Each field of the result is inferred independently:
 *  - ndim: every operand whose ndim is known must agree; the result takes the
 *    agreed value, or kUnknownNDim when no operand knows its ndim.
 *  - dtype: DataType::Void() if any operand's dtype is unknown; otherwise all
 *    three dtypes must be equal and the result takes that dtype.
 *  - vdevice: the first defined vdevice is adopted. If a later defined vdevice
 *    has a defined target and compares unequal (`!=`) to the adopted one, the
 *    result carries no vdevice at all.
 *  - shape: when all three shapes are ShapeExpr, every axis must be provably
 *    equal according to the block builder's analyzer, and the first operand's
 *    dims are used. When all three operands hold the very same shape expression
 *    (`same_as`), that expression is reused. Otherwise only ndim is recorded.
 *
 * \param call The EwiseFMA call whose result is being inferred.
 * \param ctx The block builder, used for diagnostics and arithmetic analysis.
 * \return A TensorStructInfo describing the result of the call.
 * \note Differing known ndims, differing known dtypes, and axes that cannot be
 *       proven equal are reported through ctx->ReportFatal.
 */
StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) {
  const Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
  const TensorStructInfo operands[3] = {input_sinfo[0], input_sinfo[1], input_sinfo[2]};
  const TensorStructInfo& first = operands[0];
  const TensorStructInfo& second = operands[1];
  const TensorStructInfo& third = operands[2];

  // Unify ndim: operands with an unknown ndim impose no constraint.
  int ndim = kUnknownNDim;
  for (const TensorStructInfo& operand : operands) {
    if (operand->IsUnknownNdim()) {
      continue;
    }
    if (ndim == kUnknownNDim) {
      ndim = operand->ndim;
    } else if (operand->ndim != ndim) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "The 3 arguments of EwiseFMA must have the same number of dimensions");
    }
  }

  // A single unknown dtype makes the result dtype unknown (Void).
  DataType output_dtype = DataType::Void();
  if (!first->IsUnknownDtype() && !second->IsUnknownDtype() && !third->IsUnknownDtype()) {
    if (first->dtype == second->dtype && second->dtype == third->dtype) {
      output_dtype = first->dtype;
    } else {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "Data types " << first->dtype << ", " << second->dtype << ", and "
                       << third->dtype << " must be equal for EwiseFMA");
    }
  }

  // Adopt the first defined vdevice; a conflicting targeted vdevice drops it entirely.
  VDevice vdev = VDevice();
  for (const TensorStructInfo& operand : operands) {
    if (!operand->vdevice.defined()) {
      continue;
    }
    const VDevice candidate = operand->vdevice.value();
    if (!vdev.defined()) {
      vdev = candidate;
    } else if (candidate->target.defined() && candidate != vdev) {
      vdev = VDevice();
      break;
    }
  }

  const auto* shape1 = first->shape.as<ShapeExprNode>();
  const auto* shape2 = second->shape.as<ShapeExprNode>();
  const auto* shape3 = third->shape.as<ShapeExprNode>();
  arith::Analyzer* analyzer = ctx->GetAnalyzer();

  if (shape1 != nullptr && shape2 != nullptr && shape3 != nullptr) {
    // All shapes are explicit: require axis-by-axis provable equality.
    Array<PrimExpr> out_dims;
    for (int axis = 0; axis < ndim; ++axis) {
      const PrimExpr dim1 = shape1->values[axis];
      const PrimExpr dim2 = shape2->values[axis];
      const PrimExpr dim3 = shape3->values[axis];
      const bool provably_equal =
          analyzer->CanProveEqual(dim1, dim2) && analyzer->CanProveEqual(dim2, dim3);
      if (!provably_equal) {
        ctx->ReportFatal(Diagnostic::Error(call)
                         << "The 3 arguments of EwiseFMA must have the same shape");
      } else {
        out_dims.push_back(dim1);
      }
    }
    const ShapeExpr out_shape(out_dims);
    return vdev.defined() ? TensorStructInfo(out_shape, output_dtype, vdev)
                          : TensorStructInfo(out_shape, output_dtype);
  }

  // Non-explicit shapes can still be propagated when all operands share one shape object.
  const bool shares_shape = first->shape.defined() && first->shape.same_as(second->shape) &&
                            first->shape.same_as(third->shape);
  if (shares_shape) {
    const auto shared_shape = first->shape.value();
    return vdev.defined() ? TensorStructInfo(shared_shape, output_dtype, vdev)
                          : TensorStructInfo(shared_shape, output_dtype);
  }

  // Nothing more is known about the shape; record only the unified ndim.
  return vdev.defined() ? TensorStructInfo(output_dtype, ndim, vdev)
                        : TensorStructInfo(output_dtype, ndim);
}