in src/core/compile.cpp [1554:1720]
// Generates symbolic constraints relating the node scheduled at `ref` to the
// scheduled "base" node that `node_ref` resolves to through `resolved_reads`.
// If the node at `ref` is a view, the base is resolved through the view's
// single input instead.
//
// - If `node_ref` is that base node, returns the view constraints of the node
//   at `ref` (empty unless that node is a view).
// - Otherwise, walks forward from the base node toward the node at `ref`,
//   following each node's first output. For every base-node symbol defined by
//   a view on that path, it returns an expression for that symbol, with the
//   substitutions from all later views on the path applied.
std::vector<symbolic::Constraint> Compiler::gen_constraints(
    IR::NodeRef node_ref, LoopTree::TreeRef ref) const {
  // Find a route to a scheduled base node.
  auto base_node_ref = resolved_reads.at(node_ref);
  const auto &node = lt.ir.node(lt.node(ref));
  if (node.op() == Operation::view) {
    ASSERT(node.inputs().size() == 1);
    base_node_ref = resolved_reads.at(node.inputs().at(0));
  }

  // Symbols of the base node's variables. Only constraints that define one
  // of these are carried into the result.
  std::unordered_set<Symbol, Hash<Symbol>> base_syms;
  for (const auto &v : lt.ir.node(base_node_ref).vars()) {
    if (var_to_sym.count(v)) {
      base_syms.insert(var_to_sym.at(v));
    }
  }

  auto to_syms = [&](const std::vector<IR::VarRef> &vs) {
    std::vector<Symbol> syms;
    syms.reserve(vs.size());
    for (const auto &v : vs) {
      syms.emplace_back(var_to_sym.at(v));
    }
    return syms;
  };

  // Only view nodes carry constraints. This keeps the view's constraints of
  // the form `input_sym = expr` where `expr` mentions no input symbol of the
  // view. Presumably `expr` is then written in terms of the view's output
  // symbols; confirm against the view construction code.
  auto get_constraints = [&](IR::NodeRef view_ref) {
    std::unordered_map<Symbol, Expr, Hash<Symbol>> out;
    const auto &view_node = lt.ir.node(view_ref);
    if (view_node.op() != Operation::view) {
      return out;
    }
    ASSERT(view_node.inputs().size() == 1);
    const auto &input_vars = lt.ir.node(view_node.inputs().at(0)).vars();
    auto input_syms = to_set<Symbol, Hash>(to_syms(input_vars));
    for (const auto &c : view_node.constraints()) {
      if (c.first.type() != Expr::Type::symbol) {
        continue;
      }
      auto lhs_sym = c.first.symbol();
      if (!input_syms.count(lhs_sym)) {
        continue;
      }
      const auto &expr = c.second;
      bool mentions_input = false;
      for (const auto &rhs_sym : expr.symbols()) {
        if (input_syms.count(rhs_sym)) {
          mentions_input = true;
          break;
        }
      }
      if (mentions_input) {
        continue;
      }
      out.emplace(lhs_sym, expr);
    }
    return out;
  };

  // Constraints of the node at `ref` itself. This is the whole answer when
  // no traversal from the base node is needed.
  std::vector<Constraint> constraints;
  for (const auto &c : get_constraints(lt.node(ref))) {
    constraints.emplace_back(std::make_pair(Expr(c.first), c.second));
  }
  if (node_ref == base_node_ref) {
    return constraints;
  }

  // Path from the base node toward the node at `ref` (exclusive), following
  // first outputs. The walk stops early if a node has no outputs.
  std::vector<IR::NodeRef> path;
  const auto dest_node_ref = lt.node(ref);
  for (auto cur_node_ref = base_node_ref; cur_node_ref != dest_node_ref;) {
    path.emplace_back(cur_node_ref);
    const auto &outputs = lt.ir.node(cur_node_ref).outputs();
    if (outputs.empty()) {
      break;
    }
    cur_node_ref = outputs.at(0);
  }

  // Accumulate constraints node by node, starting from the base. A constraint
  // that defines a base symbol is recorded. Every constraint is then
  // substituted into all recorded expressions, so each base symbol ends up
  // expressed in terms of symbols further along the path.
  // Only `out_cs` contributes to the result on this path.
  std::unordered_map<Symbol, Expr, Hash<Symbol>> out_cs;
  for (const auto cur_node_ref : path) {
    for (const auto &c : get_constraints(cur_node_ref)) {
      if (base_syms.count(c.first)) {
        ASSERT(!out_cs.count(c.first));
        out_cs.emplace(c.first, c.second);
      }
      for (auto &p : out_cs) {
        p.second = p.second.replace(c.first, c.second);
      }
    }
  }

  std::vector<Constraint> result;
  result.reserve(out_cs.size());
  for (const auto &c : out_cs) {
    result.emplace_back(std::make_pair(Expr(c.first), c.second));
  }
  return result;
}