PrimExpr RewriteSimplifier::Impl::VisitExpr_()

in src/arith/rewrite_simplify.cc [2270:2348]


PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
  // Rewrite-simplify a call expression.
  //
  // IRMutatorWithAnalyzer::VisitExpr_ is applied first; the resulting call is
  // then folded further when it is one of:
  //   - likely(c)         -> c if c is an integer constant, otherwise the
  //                          result of TryMatchLiteralConstraint(c) if any.
  //   - shift_right/left  -> constant-folded when both operands are IntImm.
  //   - tir.ceil          -> folded for an IntImm / FloatImm argument and for
  //                          ceil(log2(<FloatImm>)).
  //   - tir.clz           -> folded for an IntImm argument.
  //   - if_then_else      -> a nested if_then_else in the then-branch whose
  //                          else-branch is the same constant as the outer
  //                          else-branch is flattened into one if_then_else.
  // Anything else is returned as produced by the base visitor.

  // Op::Get resolves an operator by name through the op registry; resolve the
  // ops this visitor dispatches on once instead of on every visited call.
  static const Op& ceil_op = Op::Get("tir.ceil");
  static const Op& log2_op = Op::Get("tir.log2");
  static const Op& clz_op = Op::Get("tir.clz");

  // add condition context to if_then_else
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  op = ret.as<CallNode>();
  if (op == nullptr) return ret;

  if (op->op.same_as(tir::builtin::likely())) {
    const PrimExpr& likely_arg = op->args[0];
    if (is_const_int(likely_arg)) {
      return likely_arg;
    }
    // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
    if (auto match = TryMatchLiteralConstraint(likely_arg)) {
      return match.value();
    }
  } else if (op->op.same_as(tir::builtin::shift_right())) {
    if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
      // the operator overload will eagerly constant fold.
      return op->args[0] >> op->args[1];
    }
  } else if (op->op.same_as(tir::builtin::shift_left())) {
    if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
      // the operator overload will eagerly constant fold.
      return op->args[0] << op->args[1];
    }
  } else if (op->op.same_as(ceil_op)) {
    const PrimExpr& ceil_arg = op->args[0];
    if (ceil_arg.as<IntImmNode>()) {
      // ceil of an integer constant is the constant itself.
      return cast(op->dtype, ceil_arg);
    } else if (const auto* arg_float = ceil_arg.as<FloatImmNode>()) {
      return cast(op->dtype, FloatImm(arg_float->dtype, std::ceil(arg_float->value)));
    } else if (const auto* arg_call = ceil_arg.as<CallNode>()) {
      // ceil(log2(cast(n,"float64"))) is used as the implementation of
      // topi.math.ceil_log2, and appears in iteration bounds.
      if (arg_call->op.same_as(log2_op)) {
        PrimExpr log_arg = arg_call->args[0];
        if (const auto* as_float = log_arg.as<FloatImmNode>()) {
          // ceil(log2(n)) can be simplified, and should produce the
          // same integer result regardless of the target's rounding
          // conventions.
          // Build the constant in the operand's (float) dtype and cast to the
          // call's dtype, like the sibling branches above; cast() is a no-op
          // when the dtypes already match.
          return cast(op->dtype,
                      FloatImm(as_float->dtype, std::ceil(std::log2(as_float->value))));
        }
      }
    }
  } else if (op->op.same_as(clz_op)) {
    if (const auto* arg_int = op->args[0].as<IntImmNode>()) {
      const int bits = arg_int->dtype.bits();
      if (arg_int->value == 0) return make_const(op->dtype, bits);
      // Scan from the top bit of the argument's width downward; the first set
      // bit at position i leaves (bits - i - 1) leading zeros above it.
      for (int i = bits - 1; i >= 0; --i) {
        if ((int64_t(1) << i) & arg_int->value) {
          return make_const(op->dtype, bits - i - 1);
        }
      }
      LOG(FATAL) << "Should not reach here";
    }
  } else if (op->op.same_as(tir::builtin::if_then_else())) {
    // Simplify nested if_then_else
    // if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr } } else { else_expr }
    // => if (cond && inner_cond) { inner_then_expr } else { else_expr }
    const PrimExpr& cond = op->args[0];
    const PrimExpr& then_expr = op->args[1];
    const PrimExpr& else_expr = op->args[2];
    const CallNode* inner_call = then_expr.as<CallNode>();
    if (inner_call != nullptr && inner_call->op.same_as(tir::builtin::if_then_else())) {
      const PrimExpr& inner_cond = inner_call->args[0];
      const PrimExpr& inner_then_expr = inner_call->args[1];
      const PrimExpr& inner_else_expr = inner_call->args[2];
      // Only check constant cases to avoid recursion
      if (is_const_number(inner_else_expr) && is_const_number(else_expr) &&
          analyzer_->CanProve(inner_else_expr == else_expr)) {
        return if_then_else(cond && inner_cond, inner_then_expr, else_expr);
      }
    }
  }

  return ret;
}