in src/core/symbolic.cpp [659:911]
// Solves a list of equality constraints (`first == second`) over symbols and
// their sizes, returning normalized constraints of two forms:
//   * `Expr::size(sym) == expr`  — one per symbol whose size could be bound
//     (the max over all candidate expressions when several remain), and
//   * `sym == expr`              — one per distinct expression `sym` could be
//     isolated to.
//
// Internally every `Expr::size(sym)` is replaced by a fresh placeholder symbol
// named `<name>_size`, so that sizes can be isolated and substituted like
// ordinary symbols. Placeholders are mapped back to `Expr::size(sym)` in the
// output.
//
// Two different constant sizes for the same symbol fail the ASSERT in
// `resolve_values`. Each fixed-point substitution loop is capped at 1000
// passes.
std::vector<Constraint> unify(std::vector<Constraint> constraints_) {
  // 1. Get all symbols mentioned anywhere in the constraints.
  std::unordered_set<Symbol, Hash<Symbol>> all_symbols;
  auto get_all_syms = [](const Expr& expr) {
    std::vector<Symbol> syms;
    expr.walk([&](const Expr& e) {
      if (e.type() == Expr::Type::symbol) {
        syms.emplace_back(e.symbol());
      }
      return e;
    });
    return syms;
  };
  for (const auto& c : constraints_) {
    for (auto& sym : get_all_syms(c.first)) {
      all_symbols.insert(sym);
    }
    for (auto& sym : get_all_syms(c.second)) {
      all_symbols.insert(sym);
    }
  }

  // 2. Symbolicate all size expressions: each symbol gets a placeholder
  //    symbol that stands in for Expr::size(symbol).
  std::unordered_map<Symbol, Symbol, Hash<Symbol>> size_sym_map;
  for (const auto& sym : all_symbols) {
    size_sym_map[sym] = Symbol(sym.name() + "_size");
  }

  // 3. Remap size expressions in the constraint list to their placeholders.
  std::vector<Constraint> constraints;
  for (const auto& c : constraints_) {
    auto lhs = c.first;
    auto rhs = c.second;
    for (auto& p : size_sym_map) {
      auto size_expr = Expr::size(p.first);
      lhs = lhs.replace(size_expr, p.second);
      rhs = rhs.replace(size_expr, p.second);
    }
    constraints.emplace_back(lhs, rhs);
  }

  // 4. Collect size-only constraints (those mentioning no plain symbol, only
  //    size placeholders) and isolate every placeholder they allow.
  std::unordered_map<Symbol, std::unordered_set<Expr, Hash<Expr>>, Hash<Symbol>>
      size_constraints;
  for (const auto& c : constraints) {
    bool size_only = true;
    for (const auto& sym : all_symbols) {
      if (c.first.contains(sym) || c.second.contains(sym)) {
        size_only = false;
        break;
      }
    }
    if (!size_only) {
      continue;
    }
    for (const auto& p : size_sym_map) {
      auto sym = p.first;
      auto size_sym = p.second;
      if (!can_isolate(c, size_sym)) {
        continue;
      }
      const auto& expr = isolate(c, size_sym).second;
      size_constraints[sym].insert(expr);
    }
  }

  // 5. Simplify the size constraints.

  // A symbol is "sized" once its candidate set is exactly one constant.
  auto sized = [&](Symbol sym) {
    if (!size_constraints.count(sym)) {
      return false;
    }
    const auto& exprs = size_constraints.at(sym);
    if (exprs.size() != 1) {
      return false;
    }
    const auto& expr = *exprs.begin();
    if (expr.type() == Expr::Type::value) {
      return true;
    }
    return false;
  };
  // Snapshot of (symbol, simplified constant size) for every sized symbol.
  auto sized_syms = [&]() {
    std::vector<std::pair<Symbol, Expr>> sizes;
    for (const auto& s : size_constraints) {
      auto sym = s.first;
      if (sized(sym)) {
        auto expr = *s.second.begin();
        sizes.emplace_back(sym, expr.simplify());
      }
    }
    return sizes;
  };
  // One substitution pass: replace every known-constant size placeholder in
  // every candidate expression, then simplify. Returns true if any candidate
  // set changed. The set of sized symbols is re-read per expression, so sizes
  // resolved earlier in the same pass are already used.
  auto simplify_all_sizes = [&]() -> bool {
    bool changed = false;
    for (auto& p : size_constraints) {
      auto exprs = p.second;
      p.second.clear();
      for (const auto& expr_ : exprs) {
        auto expr = expr_;
        for (auto& s : sized_syms()) {
          expr = expr.replace(size_sym_map.at(s.first), s.second);
        }
        p.second.insert(expr.simplify());
      }
      if (exprs.size() == p.second.size()) {
        for (const auto& expr : exprs) {
          if (!p.second.count(expr)) {
            changed = true;
          }
        }
      } else {
        changed = true;
      }
    }
    return changed;
  };
  // Collapse any candidate set containing a constant down to that constant.
  // Two different constants for one symbol are a contradiction and fail the
  // ASSERT.
  auto resolve_values = [&]() {
    for (auto& p : size_constraints) {
      bool value_set = false;
      auto sym = p.first;
      // Copy: the set in the map is rewritten while we iterate.
      auto size_exprs = p.second;
      for (auto& e : size_exprs) {
        if (e.type() == Expr::Type::value) {
          ASSERT(!value_set)
              << "size of " << sym.name()
              << " set multiple times to different values"
              << " (new value: " << e.dump() << " old:"
              << " " << size_constraints[sym].begin()->dump() << ")";
          size_constraints[sym].clear();
          size_constraints[sym].insert(e);
          value_set = true;
        }
      }
    }
  };
  // Step 4 may have produced several candidates per symbol. Settle constants
  // (and detect contradictory ones) before iterating. This replaces an
  // ad-hoc pass that silently kept an arbitrary constant on conflict.
  resolve_values();
  {
    int limit = 1000;
    while (simplify_all_sizes() && (limit--) > 0) {
      resolve_values();
    }
  }

  // 6. Derive all indexing constraints: every expression each symbol can be
  //    isolated to.
  std::unordered_map<Symbol, std::unordered_set<Expr, Hash<Expr>>, Hash<Symbol>>
      index_constraints;
  for (const auto& c : constraints) {
    for (const auto& sym : all_symbols) {
      if (can_isolate(c, sym)) {
        const auto& expr = isolate(c, sym).second;
        index_constraints[sym].insert(expr);
      }
    }
  }

  // 7. Derive unknown size constraints from index constraints.
  auto derive_size_expressions = [&]() {
    for (const auto& sym : all_symbols) {
      if (sized(sym)) {
        continue;
      }
      if (!index_constraints.count(sym)) {
        continue;
      }
      // If our size was already bound by the size-only constraints, skip this
      // step entirely. This must be checked once per symbol, before deriving.
      // Checked inside the loop below, it would keep only the first index
      // constraint (in hash order) and drop the rest.
      if (size_constraints.count(sym)) {
        continue;
      }
      // derived sized functions
      // x = y + k -->
      // |x| - 1 = |y| - 1 + |k| - 1
      // |x| = |y| - 1 + |k| - 1 + 1
      for (const auto& expr : index_constraints.at(sym)) {
        auto size_expr = expr.walk([&](const Expr& e) {
          if (e.type() == Expr::Type::symbol &&
              size_sym_map.count(e.symbol())) {
            return Expr(size_sym_map.at(e.symbol())) - Expr(1);
          }
          return e;
        });
        size_constraints[sym].insert(size_expr + Expr(1));
      }
    }
  };
  derive_size_expressions();
  {
    int limit = 1000;
    while (simplify_all_sizes() && (limit--) > 0) {
    }
  }

  // 8. All done: emit sizes (taking the maximum if there are multiple size
  //    candidates) and index constraints, with placeholders mapped back to
  //    Expr::size(sym).
  std::vector<std::pair<Expr, Expr>> output_constraints;
  auto map_to_size_expr = [&](const Expr& e) {
    auto expr = e;
    for (auto& s : size_sym_map) {
      expr = expr.replace(s.second, Expr::size(s.first));
    }
    return expr.simplify();
  };
  for (auto& p : size_constraints) {
    const auto& exprs = p.second;
    if (exprs.empty()) {
      continue;
    }
    // Seed with the first candidate and fold in only the remaining ones, so
    // the first is not duplicated inside the max.
    auto it = exprs.begin();
    auto max_expr = *it;
    for (++it; it != exprs.end(); ++it) {
      max_expr = Expr::max(max_expr, *it);
    }
    output_constraints.emplace_back(Expr::size(p.first),
                                    map_to_size_expr(max_expr));
  }
  for (auto& p : index_constraints) {
    auto sym = p.first;
    std::unordered_set<Expr, Hash<Expr>> mapped_exprs;
    for (const auto& expr : p.second) {
      mapped_exprs.insert(map_to_size_expr(expr));
    }
    for (const auto& expr : mapped_exprs) {
      output_constraints.emplace_back(sym, expr);
    }
  }
  return output_constraints;
}