in include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h [457:533]
// Maps the element-wise computation of mhlo::ConvertOp onto arith dialect ops.
//
// The source element type is read from `arg_types.front()` and the target
// element type from `result_types.front()`. The element type of the actual
// operand value (`args.front()`) is queried separately and used for the
// cast-compatibility checks.
// NOTE(review): presumably the operand has already been converted to a
// signless type while `arg_types` still carries the original signedness --
// confirm against the caller.
//
// Returns the converted value, the unchanged operand when source and target
// are the same-width integer or identical float type, or a null Value when no
// lowering applies.
inline Value MapMhloOpToStdScalarOp<mhlo::ConvertOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Type> arg_types,
    ValueRange args, OpBuilder* b) {
  Type sourceType = getElementTypeOrSelf(arg_types.front());
  Type targetType = getElementTypeOrSelf(result_types.front());
  Type convertedSourceType = getElementTypeOrSelf(args.front());

  // Integer -> floating point.
  // A boolean value is considered to be unsigned when converting to
  // floating-point. Otherwise, it will become `-1`.
  if ((sourceType.isInteger(/*width=*/1) || sourceType.isUnsignedInteger()) &&
      mlir::arith::UIToFPOp::areCastCompatible(convertedSourceType,
                                               targetType)) {
    return b->create<mlir::arith::UIToFPOp>(loc, result_types, args,
                                            mlir::None);
  }
  if (mlir::arith::SIToFPOp::areCastCompatible(convertedSourceType,
                                               targetType)) {
    return b->create<mlir::arith::SIToFPOp>(loc, result_types, args,
                                            mlir::None);
  }

  // Floating point -> floating point.
  if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
    FloatType src = sourceType.cast<FloatType>();
    FloatType res = targetType.cast<FloatType>();
    if (src.getWidth() > res.getWidth()) {
      return b->create<mlir::arith::TruncFOp>(loc, result_types, args,
                                              mlir::None);
    }
    if (src.getWidth() < res.getWidth()) {
      return b->create<mlir::arith::ExtFOp>(loc, result_types, args,
                                            mlir::None);
    }
    // f16 and bf16 have the same width but different formats, and there is no
    // direct cast between them. Go through f32: the extension is exact, so the
    // value is rounded only once, by the final truncation.
    const bool is_f16_bf16_pair = (src.isF16() && res.isBF16()) ||
                                  (src.isBF16() && res.isF16());
    if (is_f16_bf16_pair) {
      Type wide_type = b->getF32Type();
      // Mirror the operand's vector shape, as the bool conversions below do.
      if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
        wide_type = VectorType::get(vec_type.getShape(), wide_type);
      }
      Value widened =
          b->create<mlir::arith::ExtFOp>(loc, wide_type, args.front());
      return b->create<mlir::arith::TruncFOp>(loc, result_types.front(),
                                              widened);
    }
    // No conversion is needed for identical float types.
    return args.front();
  }

  // Anything -> bool.
  if (targetType.isInteger(/*width=*/1)) {
    // When casting to bool, we need to compare whether the value is equal to
    // zero.
    if (sourceType.isSignlessInteger() || sourceType.isUnsignedInteger()) {
      Value zero_intval = b->create<::mlir::arith::ConstantIntOp>(
          loc, 0, sourceType.cast<IntegerType>().getWidth());
      // A vector operand needs a vector of zeros to compare against.
      if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
        zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
      }
      return b->create<mlir::arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero_intval);
    }
    if (sourceType.isa<FloatType>()) {
      Value zero =
          b->create<arith::ConstantOp>(loc, b->getFloatAttr(sourceType, 0.0));
      if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
        zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
      }
      // Unordered-or-not-equal: NaN compares as "not zero", i.e. true.
      return b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }
    // Other source types fall through to the remaining cases.
  }

  // Integer -> integer.
  if (sourceType.isa<IntegerType>() && targetType.isa<IntegerType>()) {
    IntegerType src = sourceType.cast<IntegerType>();
    IntegerType res = targetType.cast<IntegerType>();
    if (src.getWidth() > res.getWidth()) {
      return b->create<mlir::arith::TruncIOp>(loc, result_types, args,
                                              mlir::None);
    }
    if (src.getWidth() < res.getWidth()) {
      // Special case boolean values, so they get casted to `1` instead of `-1`.
      if (src.isUnsignedInteger() || src.getWidth() == 1) {
        return b->create<mlir::arith::ExtUIOp>(loc, result_types, args,
                                               mlir::None);
      }
      return b->create<mlir::arith::ExtSIOp>(loc, result_types, args,
                                             mlir::None);
    }
    // No conversion is needed for the same width integers
    return args.front();
  }

  // Floating point -> integer.
  // NOTE(review): this always emits a signed conversion; a conversion to an
  // unsigned target likely wants FPToUIOp, but whether `result_types` still
  // carries signedness at this point cannot be determined here -- confirm.
  if (mlir::arith::FPToSIOp::areCastCompatible(convertedSourceType,
                                               targetType)) {
    return b->create<mlir::arith::FPToSIOp>(loc, result_types, args,
                                            mlir::None);
  }

  // No lowering is known for this pair of types.
  return nullptr;
}