in src/arith/rewrite_simplify.cc [2270:2348]
// Simplify a call expression after its arguments have been simplified.
//
// The arguments are rewritten first by the base-class visitor. If the result
// is still a call, this override applies the following rules:
//   - likely(c)                -> c, when c is a constant integer.
//   - shift_right / shift_left -> the constant-folded shift, when both
//                                 operands are IntImm.
//   - tir.ceil(IntImm)         -> cast(dtype, IntImm)  (already integral).
//   - tir.ceil(FloatImm)       -> cast(dtype, FloatImm(std::ceil(value))).
//   - tir.ceil(tir.log2(FloatImm))
//                              -> FloatImm(dtype, ceil(log2(value))).
//   - tir.clz(IntImm)          -> count of leading zero bits within the
//                                 operand's bit width (the full width for 0).
//   - likely(cond)             -> the value produced by
//                                 TryMatchLiteralConstraint, when it yields one.
//   - if_then_else(c, if_then_else(ic, t, e1), e2), with e1 and e2 constant
//     and provably equal       -> if_then_else(c && ic, t, e2).
//
// Parameters:
//   op - the call node being visited.
// Returns:
//   The simplified expression. If no rule applies, returns the call with its
//   simplified arguments. The result may not be a call at all when the
//   base-class visit already rewrote it into something else.
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
  // Simplify the arguments first. The base-class visitor is the step that
  // adds condition context to if_then_else branches.
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  op = ret.as<CallNode>();
  if (op == nullptr) return ret;
  if (op->op.same_as(tir::builtin::likely()) && is_const_int(op->args[0])) {
    return op->args[0];
  } else if (op->op.same_as(tir::builtin::shift_right())) {
    if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
      // the operator overload will eagerly constant fold.
      return op->args[0] >> op->args[1];
    }
  } else if (op->op.same_as(tir::builtin::shift_left())) {
    if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
      // the operator overload will eagerly constant fold.
      return op->args[0] << op->args[1];
    }
  } else if (op->op.same_as(Op::Get("tir.ceil"))) {
    PrimExpr ceil_arg = op->args[0];
    if (auto arg_int = ceil_arg.as<IntImmNode>()) {
      // An integer is its own ceiling; only the result dtype may differ.
      return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value));
    } else if (auto arg_float = ceil_arg.as<FloatImmNode>()) {
      return cast(op->dtype, FloatImm(arg_float->dtype, std::ceil(arg_float->value)));
    } else if (auto arg_call = ceil_arg.as<CallNode>()) {
      // ceil(log2(cast(n,"float64"))) is used as the implementation of
      // topi.math.ceil_log2, and appears in iteration bounds.
      if (arg_call->op.same_as(Op::Get("tir.log2"))) {
        PrimExpr log_arg = arg_call->args[0];
        if (auto as_float = log_arg.as<FloatImmNode>()) {
          // ceil(log2(n)) can be simplified, and should produce the
          // same integer result regardless of the target's rounding
          // conventions.
          return FloatImm(op->dtype, std::ceil(std::log2(as_float->value)));
        }
      }
    }
  } else if (op->op.same_as(Op::Get("tir.clz"))) {
    if (const auto* arg_int = op->args[0].as<IntImmNode>()) {
      const int bits = arg_int->dtype.bits();
      // Scan the bits in unsigned arithmetic. For a 64-bit operand the probe
      // mask reaches bit 63, and shifting a *signed* 1 into the sign bit is
      // an implementation-defined conversion before C++20. An unsigned mask
      // tests exactly the same bits with no such caveat.
      const uint64_t value = static_cast<uint64_t>(arg_int->value);
      if (value == 0) return make_const(op->dtype, bits);
      for (int i = bits - 1; i >= 0; --i) {
        if ((uint64_t{1} << i) & value) {
          // Highest set bit within the dtype width is i, so bits-1-i zeros
          // precede it.
          return IntImm(op->dtype, bits - i - 1);
        }
      }
      // NOTE(review): reachable only if a nonzero value has no set bit inside
      // its dtype's low `bits` bits; presumably IntImm's range check rules
      // this out -- confirm.
      LOG(FATAL) << "Should not reach here";
    }
  }
  if (op->op.same_as(tir::builtin::likely())) {
    // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
    if (auto match = TryMatchLiteralConstraint(op->args[0])) {
      return match.value();
    }
  }
  if (op->op.same_as(tir::builtin::if_then_else())) {
    // Simplify nested if_then_else
    // if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr } } else { else_expr }
    // => if (cond && inner_cond) { inner_then_expr } else { else_expr }
    // This holds only when inner_else_expr == else_expr, because both "else"
    // paths must then yield the same value.
    const PrimExpr& cond = op->args[0];
    const PrimExpr& then_expr = op->args[1];
    const PrimExpr& else_expr = op->args[2];
    const CallNode* inner_call = then_expr.as<CallNode>();
    if (inner_call != nullptr && inner_call->op.same_as(tir::builtin::if_then_else())) {
      const PrimExpr& inner_cond = inner_call->args[0];
      const PrimExpr& inner_then_expr = inner_call->args[1];
      const PrimExpr& inner_else_expr = inner_call->args[2];
      // Only check constant cases to avoid recursion
      if (is_const_number(inner_else_expr) && is_const_number(else_expr) &&
          analyzer_->CanProve(inner_else_expr == else_expr)) {
        // NOTE(review): in the original nesting, inner_cond was guarded by
        // cond. Whether `cond && inner_cond` still short-circuits depends on
        // how each backend lowers And. Confirm that inner_cond is safe to
        // evaluate when cond is false (e.g. it contains no guarded loads).
        return if_then_else(cond && inner_cond, inner_then_expr, else_expr);
      }
    }
  }
  return ret;
}