inline Value MapMhloOpToStdScalarOp()

in include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h [457:533]


inline Value MapMhloOpToStdScalarOp<mhlo::ConvertOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
    ValueRange args, OpBuilder* b) {
  // Lowers an element-wise mhlo.convert to the matching arith cast op(s).
  //
  // `arg_types` carries the original operand type (which may be unsigned);
  // `args` holds the operand value as passed in. The signedness of the
  // original type selects between the signed and unsigned arith casts.
  // Returns nullptr when no lowering is implemented for the type pair.
  //
  // The order of the checks matters:
  // - i1/unsigned -> float must precede the generic SIToFP case.
  // - "convert to i1" must precede the generic int -> int case.
  Type sourceType = getElementTypeOrSelf(arg_types.front());
  Type targetType = getElementTypeOrSelf(result_types.front());
  Type convertedSourceType = getElementTypeOrSelf(args.front());

  // A boolean value is considered to be unsigned when converting to
  // floating-point. Otherwise, it will become `-1`.
  if ((sourceType.isInteger(/*width=*/1) || sourceType.isUnsignedInteger()) &&
      mlir::arith::UIToFPOp::areCastCompatible(convertedSourceType,
                                               targetType)) {
    return b->create<mlir::arith::UIToFPOp>(loc, result_types, args,
                                            mlir::None);
  }
  if (mlir::arith::SIToFPOp::areCastCompatible(convertedSourceType,
                                               targetType)) {
    return b->create<mlir::arith::SIToFPOp>(loc, result_types, args,
                                            mlir::None);
  }
  if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
    FloatType src = sourceType.cast<FloatType>();
    FloatType res = targetType.cast<FloatType>();
    if (src.getWidth() > res.getWidth()) {
      return b->create<mlir::arith::TruncFOp>(loc, result_types, args,
                                              mlir::None);
    }
    if (src.getWidth() < res.getWidth()) {
      return b->create<mlir::arith::ExtFOp>(loc, result_types, args,
                                            mlir::None);
    }
    // Identical float types need no conversion.
    if (sourceType == targetType) return args.front();
    // Same width but a different format (e.g. bf16 <-> f16). Returning the
    // operand unchanged would produce a value of the wrong type. arith.extf
    // and arith.truncf require a strict width change, so round-trip through
    // f32. f32 represents every f16 and bf16 value exactly, so the only
    // rounding happens in the final truncation.
    if (src.getWidth() >= 32) {
      // f32 is not a wider pivot for such types; report as unsupported.
      return nullptr;
    }
    Type wideType = b->getF32Type();
    if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
      wideType = VectorType::get(vec_type.getShape(), wideType);
    }
    Value widened =
        b->create<mlir::arith::ExtFOp>(loc, wideType, args.front());
    return b->create<mlir::arith::TruncFOp>(loc, result_types.front(),
                                            widened);
  }
  if (targetType.isInteger(/*width=*/1)) {
    // When casting to bool, we need to compare whether the value is equal to
    // zero.
    if (sourceType.isSignlessInteger() || sourceType.isUnsignedInteger()) {
      Value zero_intval = b->create<::mlir::arith::ConstantIntOp>(
          loc, 0, sourceType.cast<IntegerType>().getWidth());
      if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
        zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
      }
      return b->create<mlir::arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero_intval);
    }
    if (sourceType.isa<FloatType>()) {
      Value zero =
          b->create<arith::ConstantOp>(loc, b->getFloatAttr(sourceType, 0.0));
      if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
        zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
      }
      // UNE is "unordered or not equal", so NaN converts to true.
      return b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }
  }
  if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) {
    IntegerType src = sourceType.cast<IntegerType>();
    IntegerType res = targetType.cast<IntegerType>();
    if (src.getWidth() > res.getWidth()) {
      return b->create<mlir::arith::TruncIOp>(loc, result_types, args,
                                              mlir::None);
    }
    if (src.getWidth() < res.getWidth()) {
      // Special case boolean values, so they get casted to `1` instead of `-1`.
      if (src.isUnsignedInteger() || src.getWidth() == 1) {
        return b->create<mlir::arith::ExtUIOp>(loc, result_types, args,
                                               mlir::None);
      }
      return b->create<mlir::arith::ExtSIOp>(loc, result_types, args,
                                             mlir::None);
    }
    // No conversion is needed for the same width integers
    return args.front();
  }
  // NOTE(review): float -> unsigned integer also takes this FPToSI path.
  // Values >= 2^(width-1) are then out of range for the signed conversion.
  // Consider arith::FPToUIOp if `targetType` still carries unsignedness
  // here; confirm against callers.
  if (mlir::arith::FPToSIOp::areCastCompatible(convertedSourceType,
                                               targetType)) {
    return b->create<mlir::arith::FPToSIOp>(loc, result_types, args,
                                            mlir::None);
  }
  return nullptr;
}