OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands)

in lib/Dialect/mhlo/IR/hlo_ops.cc [5335:5406]


// Folder for mhlo.compare.
//
// No folding is attempted unless the result type has a static shape. The
// following rewrites are then tried in order:
//   1. Self-comparison (`x <dir> x`) for element types that are neither
//      floating point nor complex-of-floating-point folds to a splat `true`
//      for EQ/LE/GE and a splat `false` for every other direction. Float-like
//      types are excluded, presumably because NaN is not equal to itself.
//   2. For i1 operands, comparing against a splat constant that acts as the
//      identity (`false` for NE, `true` for EQ) folds to the other operand.
//      The left operand is checked first, then the right one.
//   3. If both operands are constant attributes, the comparison is delegated
//      to CompareFolder, first for FloatType/APFloat and then for
//      IntegerType/APInt elements.
//
// Returns a null OpFoldResult when none of the above applies.
OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
  auto result_type = getType().cast<ShapedType>();
  if (!result_type.hasStaticShape()) return {};

  const auto dir = comparison_direction();

  // Element types for which `x == x` might not hold.
  auto self_elem_type = getElementTypeOrSelf(lhs());
  bool may_be_nan = self_elem_type.isa<FloatType>();
  if (!may_be_nan && self_elem_type.isa<ComplexType>()) {
    may_be_nan =
        self_elem_type.cast<ComplexType>().getElementType().isa<FloatType>();
  }

  if (lhs() == rhs() && !may_be_nan) {
    const bool holds_reflexively = dir == "LE" || dir == "EQ" || dir == "GE";
    return DenseIntElementsAttr::get(result_type, {holds_reflexively});
  }

  auto operand_elem_type =
      lhs().getType().cast<ShapedType>().getElementType();

  // On i1 values, `x != false` and `x == true` are both just `x`.
  if (operand_elem_type.isInteger(1) && (dir == "NE" || dir == "EQ")) {
    const bool identity_constant = (dir == "EQ");
    auto is_identity_splat = [&](auto operand) {
      DenseIntElementsAttr splat_attr;
      if (!matchPattern(operand, m_Constant(&splat_attr))) return false;
      if (!splat_attr.isSplat()) return false;
      return splat_attr.getSplatValue<bool>() == identity_constant;
    };
    if (is_identity_splat(lhs())) return rhs();
    if (is_identity_splat(rhs())) return lhs();
  }

  // Everything below needs both operands to be known constants.
  const bool both_constant = operands[0] && operands[1];
  if (!both_constant) return {};

#define FOLD_CONSTANT_COMPARE(dir_name, Functor)                            \
  if (dir == dir_name) {                                                    \
    if (auto folded =                                                       \
            CompareFolder<CompareOp, FloatType, APFloat, Functor<APFloat>>( \
                *this, operands))                                           \
      return folded;                                                        \
    if (auto folded =                                                       \
            CompareFolder<CompareOp, IntegerType, APInt, Functor<APInt>>(   \
                *this, operands))                                           \
      return folded;                                                        \
  }

  FOLD_CONSTANT_COMPARE("EQ", std::equal_to);
  FOLD_CONSTANT_COMPARE("NE", std::not_equal_to);
  FOLD_CONSTANT_COMPARE("LT", less);
  FOLD_CONSTANT_COMPARE("LE", less_equal);
  FOLD_CONSTANT_COMPARE("GT", greater);
  FOLD_CONSTANT_COMPARE("GE", greater_equal);
#undef FOLD_CONSTANT_COMPARE

  return {};
}