std::vector<Constraint> unify()

in src/core/symbolic.cpp [659:911]


// Solve a system of symbolic index/size constraints.
//
// Each input constraint is a pair (lhs, rhs) of expressions that may mention
// symbols and sizes of symbols (`Expr::size(sym)`). The solver:
//   1. collects every symbol mentioned by any constraint;
//   2. introduces a fresh `<name>_size` symbol for each of them;
//   3. rewrites every `Expr::size(sym)` into that fresh size symbol;
//   4. for constraints that mention no plain symbols ("size-only"), isolates
//      each size symbol that can be isolated;
//   5. repeatedly substitutes concrete sizes into the size expressions and
//      simplifies, until nothing changes (bounded by an iteration limit);
//   6. isolates every plain symbol in every constraint ("index constraints");
//   7. for symbols whose size is still unconstrained, derives candidate sizes
//      from each of their index constraints;
//   8. emits `size(sym) == max(candidates)` for every sized symbol, then
//      `sym == expr` for every index constraint, with size symbols mapped
//      back to `Expr::size(sym)`.
//
// `can_isolate(c, s)` / `isolate(c, s)` are assumed to test for / return the
// constraint rewritten as (s, expr) — NOTE(review): confirm against their
// definitions.
//
// constraints_: the constraints to solve (taken by value; not modified).
// Returns the size constraints followed by the index constraints.
// Fails an ASSERT if one symbol's size is pinned to two different concrete
// values.
std::vector<Constraint> unify(std::vector<Constraint> constraints_) {
  // 1. Get all symbols
  std::unordered_set<Symbol, Hash<Symbol>> all_symbols;

  auto get_all_syms = [](const Expr& expr) {
    std::vector<Symbol> syms;
    expr.walk([&](const Expr& e) {
      if (e.type() == Expr::Type::symbol) {
        syms.emplace_back(e.symbol());
      }
      return e;
    });
    return syms;
  };

  for (const auto& c : constraints_) {
    for (const auto& sym : get_all_syms(c.first)) {
      all_symbols.insert(sym);
    }
    for (const auto& sym : get_all_syms(c.second)) {
      all_symbols.insert(sym);
    }
  }

  // 2. Symbolicate all size expressions (map symbol -> Size(symbol))
  std::unordered_map<Symbol, Symbol, Hash<Symbol>> size_sym_map;

  for (const auto& sym : all_symbols) {
    size_sym_map[sym] = Symbol(sym.name() + "_size");
  }

  // 3. Remap size expressions in the constraint list
  std::vector<Constraint> constraints;

  for (const auto& c : constraints_) {
    auto lhs = c.first;
    auto rhs = c.second;
    for (const auto& p : size_sym_map) {
      auto size_expr = Expr::size(p.first);
      lhs = lhs.replace(size_expr, p.second);
      rhs = rhs.replace(size_expr, p.second);
    }
    constraints.emplace_back(lhs, rhs);
  }

  // 4. Collect size-only constraints (no plain symbols, only size symbols)
  std::unordered_map<Symbol, std::unordered_set<Expr, Hash<Expr>>, Hash<Symbol>>
      size_constraints;

  for (const auto& c : constraints) {
    bool size_only = true;
    for (const auto& sym : all_symbols) {
      if (c.first.contains(sym) || c.second.contains(sym)) {
        size_only = false;
        break;
      }
    }
    if (!size_only) {
      continue;
    }
    for (const auto& p : size_sym_map) {
      auto sym = p.first;
      auto size_sym = p.second;
      if (!can_isolate(c, size_sym)) {
        continue;
      }
      const auto& expr = isolate(c, size_sym).second;
      size_constraints[sym].insert(expr);
    }
  }

  // 5. Simplify the size constraints

  // True iff `sym` has exactly one size candidate and it is a concrete value.
  auto sized = [&](Symbol sym) {
    auto it = size_constraints.find(sym);
    if (it == size_constraints.end() || it->second.size() != 1) {
      return false;
    }
    return it->second.begin()->type() == Expr::Type::value;
  };

  // (symbol, simplified concrete size) for every currently sized symbol.
  auto sized_syms = [&]() {
    std::vector<std::pair<Symbol, Expr>> sizes;
    for (const auto& s : size_constraints) {
      auto sym = s.first;
      if (sized(sym)) {
        auto expr = *s.second.begin();
        sizes.emplace_back(sym, expr.simplify());
      }
    }
    return sizes;
  };

  // Substitute known concrete sizes into every size expression and simplify.
  // Returns true if any symbol's set of size expressions changed.
  auto simplify_all_sizes = [&]() -> bool {
    bool changed = false;
    for (auto& p : size_constraints) {
      auto exprs = p.second;
      p.second.clear();
      for (const auto& expr_ : exprs) {
        auto expr = expr_;
        // Recomputed per expression on purpose: sizes resolved earlier in this
        // pass are substituted immediately.
        for (auto& s : sized_syms()) {
          expr = expr.replace(size_sym_map.at(s.first), s.second);
        }
        p.second.insert(expr.simplify());
      }
      if (exprs.size() == p.second.size()) {
        for (const auto& expr : exprs) {
          if (!p.second.count(expr)) {
            changed = true;
          }
        }
      } else {
        changed = true;
      }
    }
    return changed;
  };

  // Collapse every size set that contains a concrete value down to that
  // value; two distinct concrete values for one symbol are an error.
  auto resolve_values = [&]() {
    for (auto& p : size_constraints) {
      const auto& sym = p.first;
      bool value_set = false;
      // Iterate a copy: p.second is rewritten inside the loop.
      auto size_exprs = p.second;
      for (const auto& e : size_exprs) {
        if (e.type() != Expr::Type::value) {
          continue;
        }
        ASSERT(!value_set)
            << "size of " << sym.name()
            << " set multiple times to different values"
            << " (new value: " << e.dump() << " old:"
            << " " << p.second.begin()->dump() << ")";
        p.second.clear();
        p.second.insert(e);
        value_set = true;
      }
    }
  };

  // Pin concrete sizes coming straight from the size-only constraints (and
  // reject conflicting ones) before the first simplification pass.
  resolve_values();
  {
    int limit = 1000;
    while (simplify_all_sizes() && (limit--) > 0) {
      resolve_values();
    }
  }

  // 6. Derive all indexing constraints
  std::unordered_map<Symbol, std::unordered_set<Expr, Hash<Expr>>, Hash<Symbol>>
      index_constraints;

  // isolate every plain symbol in every constraint where that is possible
  for (const auto& c : constraints) {
    for (const auto& sym : all_symbols) {
      if (can_isolate(c, sym)) {
        const auto& expr = isolate(c, sym).second;
        index_constraints[sym].insert(expr);
      }
    }
  }

  // 7. Derive unknown size constraints from index constraints
  auto derive_size_expressions = [&]() {
    for (const auto& sym : all_symbols) {
      // Sizes already bound by size-only constraints (this includes every
      // `sized` symbol) take precedence, so skip derivation entirely. This
      // check must happen once per symbol, not per index expression;
      // otherwise only the first derived candidate would ever be recorded.
      if (size_constraints.count(sym)) {
        continue;
      }
      auto idx_it = index_constraints.find(sym);
      if (idx_it == index_constraints.end() || idx_it->second.empty()) {
        continue;
      }
      // derived sized functions
      // x = y + k -->
      //  |x| - 1 = |y| - 1 + |k| - 1
      //  |x| = |y| - 1 + |k| - 1 + 1
      // Every index expression yields one candidate; step 8 takes the max.
      auto& candidates = size_constraints[sym];
      for (const auto& expr : idx_it->second) {
        auto size_expr = expr.walk([&](const Expr& e) {
          if (e.type() == Expr::Type::symbol &&
              size_sym_map.count(e.symbol())) {
            return Expr(size_sym_map.at(e.symbol())) - Expr(1);
          }
          return e;
        });
        candidates.insert(size_expr + Expr(1));
      }
    }
  };

  derive_size_expressions();
  {
    int limit = 1000;
    while (simplify_all_sizes() && (limit--) > 0) {
    }
  }

  // 8. All done, take the maximum if there are multiple size constraints
  std::vector<std::pair<Expr, Expr>> output_constraints;

  // Map every `<name>_size` symbol back to `Expr::size(<symbol>)`.
  auto map_to_size_expr = [&](const Expr& e) {
    auto expr = e;
    for (auto& s : size_sym_map) {
      expr = expr.replace(s.second, Expr::size(s.first));
    }
    return expr.simplify();
  };

  for (auto& p : size_constraints) {
    const auto& exprs = p.second;
    if (exprs.empty()) {
      continue;
    }
    // Fold max over the candidates, starting from the first one (a single
    // candidate is emitted unchanged, without a redundant max(a, a)).
    auto it = exprs.begin();
    auto max_expr = *it;
    for (++it; it != exprs.end(); ++it) {
      max_expr = Expr::max(max_expr, *it);
    }
    output_constraints.emplace_back(Expr::size(p.first),
                                    map_to_size_expr(max_expr));
  }

  for (auto& p : index_constraints) {
    auto sym = p.first;
    // Deduplicate after mapping: distinct expressions may map to the same one.
    std::unordered_set<Expr, Hash<Expr>> mapped_exprs;
    for (const auto& expr : p.second) {
      mapped_exprs.insert(map_to_size_expr(expr));
    }
    for (const auto& expr : mapped_exprs) {
      output_constraints.emplace_back(sym, expr);
    }
  }

  return output_constraints;
}