in src/relax/ir/dataflow_expr_rewriter.cc [338:522]
// Walk `bindings` in order and collect the variable replacements
// produced by matching this rewriter's `patterns` against them.
//
// For every binding, the bound value is tested against each pattern
// (one pattern per tuple element).  A search is then made over
// combinations of bindings, one per pattern, in which the newest
// binding fills at least one tuple position.  The first combination
// accepted by `TryMatchByBindingIndex` contributes one replacement
// per participating variable, and those bindings are marked as used
// so that they cannot take part in a later match.
//
// \param bindings The bindings to inspect, in the order in which
//    they should be visited.
//
// \returns A map from each variable that should be replaced to its
//    replacement expression.
Map<Var, Expr> TupleRewriterNode::GenerateVariableRewrites(const Array<Binding>& bindings) const {
  // The replacements that have been generated so far.
  Map<Var, Expr> rewrites;

  // Bound values of the bindings visited so far that have not been
  // replaced.  Passed to `ExtractMatchedExpr` for each new binding.
  Map<Var, Expr> binding_lookup;

  // One entry per visited binding, in visitation order.
  std::vector<VarInfo> info_vec;

  // Lookup from a bound variable to its position in `info_vec`.
  std::unordered_map<Var, size_t> binding_index_lookup;

  // Find the most recent binding, excluding the newest one, that may
  // be used for a tuple element.
  //
  // A binding is a candidate if it matched the pattern at
  // `tuple_index`, has not been consumed by an earlier rewrite, and
  // has not already been selected for a later tuple element.
  //
  // \param indices The binding index selected for each tuple
  //    element.  Only the entries after `tuple_index` are inspected,
  //    as entries at or before `tuple_index` may be unset or stale.
  //
  // \param tuple_index The tuple element for which a binding is
  //    required.  Must be less than `indices.size()`.
  //
  // \returns The index into `info_vec` of the candidate, or
  //    `std::nullopt` if no candidate exists.
  auto find_latest_candidate = [&](const std::vector<size_t>& indices,
                                   size_t tuple_index) -> std::optional<size_t> {
    for (size_t j_rev = 1; j_rev < info_vec.size(); j_rev++) {
      size_t j = info_vec.size() - j_rev - 1;
      // The offset into `indices` must be the tuple position.  The
      // binding index `j` may exceed `indices.size()`, and must not
      // be used as an offset into `indices`.
      if (info_vec[j].matches[tuple_index] && !info_vec[j].used &&
          std::all_of(indices.begin() + (tuple_index + 1), indices.end(),
                      [j](size_t later_binding_index) { return j != later_binding_index; })) {
        return j;
      }
    }
    return std::nullopt;
  };

  // Initialize a vector of indices, each of which corresponds to a
  // potential match for a tuple element.
  //
  // \param tuple_index_of_current_expr The index for the most recent
  //    binding.
  //
  // \param indices An output vector, into which indices will be
  //    generated.
  //
  // \returns bool True if the indices could be initialized to a
  //    potential match.  False, otherwise.
  auto initialize_indices = [&](size_t tuple_index_of_current_expr,
                                std::vector<size_t>& indices) -> bool {
    if (!info_vec.back().matches[tuple_index_of_current_expr]) {
      return false;
    }

    // Every entry starts as `info_vec.size()`, which is not a valid
    // binding index and therefore never compares equal to a
    // candidate.
    indices = std::vector<size_t>(patterns.size(), info_vec.size());

    indices[tuple_index_of_current_expr] = info_vec.size() - 1;

    // Fill from the last tuple element to the first, so that the
    // entries inspected by `find_latest_candidate` are already set.
    for (size_t i_rev = 0; i_rev < indices.size(); i_rev++) {
      size_t i = indices.size() - i_rev - 1;
      if (indices[i] == info_vec.size() - 1) {
        // This element is provided by the most recent binding.
        continue;
      }

      auto binding_index = find_latest_candidate(indices, i);
      if (binding_index.has_value()) {
        indices[i] = binding_index.value();
      } else {
        return false;
      }
    }

    return true;
  };

  // Advance a vector of indices to the next potential match.
  //
  // \param indices The indices to be updated in-place.  Must have
  //    one entry per pattern.
  //
  // \returns bool True if the indices were advanced to another
  //    potential match.  False if no further potential match could
  //    be generated.
  auto decrement_indices = [&](std::vector<size_t>& indices) -> bool {
    ICHECK_EQ(indices.size(), patterns.size());

    // Step 1, find the first index that can be decremented, while
    // still generating a valid set of indices.
    size_t i_forward;
    for (i_forward = 0; i_forward < indices.size(); i_forward++) {
      if (indices[i_forward] == info_vec.size() - 1) {
        // The entry for the most recent binding stays fixed.
        continue;
      }

      bool found_valid = false;
      size_t& index = indices[i_forward];
      while (index) {
        index--;
        if (info_vec[index].matches[i_forward] && !info_vec[index].used &&
            std::all_of(
                indices.begin() + (i_forward + 1), indices.end(),
                [index](size_t later_binding_index) { return index != later_binding_index; })) {
          found_valid = true;
          break;
        }
      }
      if (found_valid) {
        break;
      }
    }

    // Step 2, if we reached the end, then all indices were
    // decremented to zero without finding anything.  Return false to
    // indicate that we've reached the end.
    if (i_forward == indices.size()) {
      return false;
    }

    // Step 3, refill the indices before `i_forward`, which were
    // decremented to zero in step 1.  These are refilled from
    // `i_forward - 1` down to zero, each with the most recent
    // binding that is still available.
    for (size_t i = 0; i < i_forward; i++) {
      size_t i_backward = i_forward - (i + 1);
      if (indices[i_backward] == info_vec.size() - 1) {
        continue;
      }

      auto binding_index = find_latest_candidate(indices, i_backward);
      if (binding_index.has_value()) {
        indices[i_backward] = binding_index.value();
      } else {
        return false;
      }
    }

    return true;
  };

  for (size_t i_binding = 0; i_binding < bindings.size(); i_binding++) {
    const auto& binding = bindings[i_binding];

    auto expr = GetBoundValue(binding);

    binding_index_lookup[binding->var] = i_binding;

    // Record the binding, along with the result of matching its
    // bound value against each of the patterns.  The set of
    // downstream users starts empty, and the binding starts unused.
    info_vec.push_back(VarInfo{
        binding->var,
        expr,
        patterns.Map(
            [&](const DFPattern& pat) { return ExtractMatchedExpr(pat, expr, binding_lookup); }),
        std::unordered_set<Var>(),
        false,
    });

    // Search for a set of bindings that includes the new binding and
    // is accepted by `TryMatchByBindingIndex`.  The new binding is
    // tried at each tuple position, from last to first.
    auto new_match = [&]() -> std::optional<std::pair<std::vector<size_t>, std::vector<Expr>>> {
      std::vector<size_t> indices;
      for (size_t i = 0; i < patterns.size(); i++) {
        if (initialize_indices(patterns.size() - i - 1, indices)) {
          do {
            if (auto match = TryMatchByBindingIndex(info_vec, indices)) {
              return std::pair{indices, match.value()};
            }
          } while (decrement_indices(indices));
        }
      }
      return std::nullopt;
    }();

    if (new_match) {
      const auto& [indices, exprs] = new_match.value();
      ICHECK_EQ(indices.size(), exprs.size());
      for (size_t i = 0; i < indices.size(); i++) {
        ICHECK_LT(indices[i], info_vec.size());
        auto& info = info_vec[indices[i]];

        ICHECK(!info.used) << "InternalError: "
                           << "Produced multiple replacements for variable " << info.var;

        rewrites.Set(info.var, exprs[i]);
        // A replaced variable is no longer made available when
        // matching later bindings.
        binding_lookup.erase(info.var);
        info.used = true;
      }
    } else {
      binding_lookup.Set(binding->var, expr);
    }

    // Record this binding as a downstream user of each previously
    // visited binding whose variable occurs free in the bound value.
    for (const auto& prev_var : FreeVars(expr)) {
      if (auto it = binding_index_lookup.find(prev_var); it != binding_index_lookup.end()) {
        info_vec[it->second].downstream_usage.insert(binding->var);
      }
    }
  }

  return rewrites;
}