in src/relax/ir/dataflow_matcher.cc [194:302]
bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr0) {
auto expr = UnwrapBindings(expr0, var2val_);
// utilities
auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
if (op) {
if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
return expr_pattern->expr.as<OpNode>();
}
}
return nullptr;
};
auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) {
if (const auto* op_node = get_op_node(op)) {
if (op_node->name == op_type) {
return true;
}
}
return false;
};
auto is_expr_op = [](const Expr& expr, std::string op_type) {
if (const auto* call_node = expr.as<CallNode>()) {
if (const auto* op_node = call_node->op.as<OpNode>()) {
if (op_node->name == op_type) {
return true;
}
}
}
return false;
};
// logic
auto watermark = matched_nodes_.size();
if (const auto* call_node = expr.as<CallNode>()) {
auto matches_op = VisitDFPattern(op->op, call_node->op);
if (matches_op) {
auto watermark2 = matched_nodes_.size();
auto match_args = [this, &watermark2](const Array<DFPattern>& pattern_args, auto expr_begin,
auto expr_end) {
bool matches = true;
auto pattern_it = pattern_args.begin();
auto expr_it = expr_begin;
if (pattern_args.defined()) {
while (matches && pattern_it != pattern_args.end())
matches &= VisitDFPattern(*(pattern_it++), *(expr_it++));
}
if (!matches) ClearMap(watermark2);
return matches;
};
const size_t n_arg_pattern = op->args.size();
const size_t n_arg_expr = call_node->args.size();
// if allow variable args, #pattern must >= #expr.
if (op->varg_default_wildcard && n_arg_expr < n_arg_pattern) return false;
// if variable args are not allowed, #pattern must == #expr.
if (!op->varg_default_wildcard && n_arg_expr != n_arg_pattern) return false;
// Standard case
if (match_args(op->args, call_node->args.begin(), call_node->args.end())) return true;
// Commutative Matching.
if (const OpNode* op_node = call_node->op.as<OpNode>()) {
if ((op_node->name == "relax.add") || (op_node->name == "relax.multiply")) {
if (match_args(op->args, call_node->args.rbegin(), call_node->args.rend())) {
return true;
}
}
}
} else {
ClearMap(watermark);
// associate divide/multiply
if (is_pattern_op(op, "relax.divide")) {
if (const auto* arg_node = op->args[0].as<CallPatternNode>()) {
if (is_pattern_op(arg_node, "relax.multiply") && is_expr_op(expr, "relax.multiply") &&
(is_expr_op(call_node->args[0], "relax.divide") ||
is_expr_op(call_node->args[1], "relax.divide"))) {
bool out = false;
for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]});
auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div});
out = VisitDFPattern(mul, expr);
if (out) {
return true;
} else {
ClearMap(watermark);
}
}
return out;
}
}
}
if (is_pattern_op(op, "relax.multiply")) {
// associate multiply/divide
for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
if (auto* arg_node = op->args[arg_id].as<CallPatternNode>()) {
if (is_pattern_op(arg_node, "relax.divide") && is_expr_op(expr, "relax.divide") &&
(is_expr_op(call_node->args[0], "relax.multiply") ||
is_expr_op(call_node->args[1], "relax.multiply"))) {
auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]});
auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]});
return VisitDFPattern(div, expr);
}
}
}
}
}
}
return false;
}