in lib/Dialect/mhlo/IR/hlo_ops.cc [5335:5406]
// Folds mhlo.compare. Nothing is folded unless the result type is statically
// shaped. Handled patterns:
//  * `x cmp x` where x is not floating point (or complex with floating-point
//    parts) folds to a splat constant: true for LE/EQ/GE, false otherwise.
//    Floating point is excluded because NaN does not compare equal to itself.
//  * For i1 operands, `x != false` and `x == true` (with the splat constant on
//    either side) fold to `x`, provided `x` already has the op's result type.
//  * Two constant operands are folded element-wise through CompareFolder,
//    trying the floating-point path first and then the integer path.
OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
  auto result_ty = getType().cast<ShapedType>();
  if (!result_ty.hasStaticShape()) return {};

  auto direction = comparison_direction();
  // Element type of the lhs operand; used for every type check below.
  auto el_ty = getElementTypeOrSelf(lhs());

  if (lhs() == rhs() && !el_ty.isa<FloatType>() &&
      (!el_ty.isa<ComplexType>() ||
       !el_ty.cast<ComplexType>().getElementType().isa<FloatType>())) {
    if (direction == "LE" || direction == "EQ" || direction == "GE") {
      return DenseIntElementsAttr::get(result_ty, {true});
    }
    return DenseIntElementsAttr::get(result_ty, {false});
  }

  // Fold tensor<*xi1> != false and tensor<*xi1> == true to the non-constant
  // operand. The neutral constant is `false` for NE and `true` for EQ.
  if (el_ty.isInteger(1) && (direction == "NE" || direction == "EQ")) {
    const bool neutral = direction == "EQ";
    auto is_neutral_splat = [&](auto operand) {
      DenseIntElementsAttr cst_attr;
      return matchPattern(operand, m_Constant(&cst_attr)) &&
             cst_attr.isSplat() && cst_attr.getSplatValue<bool>() == neutral;
    };
    // A fold result must have exactly the op's result type. The result is
    // known to be statically shaped, but the operand may not be, so only
    // forward an operand whose type matches.
    if (is_neutral_splat(lhs()) && rhs().getType() == getType()) return rhs();
    if (is_neutral_splat(rhs()) && lhs().getType() == getType()) return lhs();
  }

  if (!operands[0] || !operands[1]) {
    return {};
  }

  // The integer path below gives the comparator only raw APInt values, which
  // carry no signedness, so one ordering functor cannot be right for both
  // signed and unsigned element types.
  // NOTE(review): this assumes the APInt specializations of `less`,
  // `less_equal`, `greater` and `greater_equal` (defined elsewhere in this
  // file) compare as signed. Confirm this.
  // Ordered comparisons on unsigned and i1 operands are therefore left
  // unfolded rather than risk a wrong constant. EQ/NE are unaffected by
  // signedness.
  const bool is_ordered = direction == "LT" || direction == "LE" ||
                          direction == "GT" || direction == "GE";
  if (is_ordered && (el_ty.isUnsignedInteger() || el_ty.isInteger(1))) {
    return {};
  }

#define COMPARE_FOLDER(Op, comparison, Func)                                \
  if (direction == comparison) {                                            \
    if (auto folded = CompareFolder<Op, FloatType, APFloat, Func<APFloat>>( \
            *this, operands))                                               \
      return folded;                                                        \
    if (auto folded = CompareFolder<Op, IntegerType, APInt, Func<APInt>>(   \
            *this, operands))                                               \
      return folded;                                                        \
  }

  COMPARE_FOLDER(CompareOp, "EQ", std::equal_to);
  COMPARE_FOLDER(CompareOp, "NE", std::not_equal_to);
  COMPARE_FOLDER(CompareOp, "LT", less);
  COMPARE_FOLDER(CompareOp, "LE", less_equal);
  COMPARE_FOLDER(CompareOp, "GT", greater);
  COMPARE_FOLDER(CompareOp, "GE", greater_equal);
#undef COMPARE_FOLDER

  return {};
}