std::vector<symbolic::Constraint> Compiler::gen_constraints()

in src/core/compile.cpp [1554:1720]


/// Generates the index constraints relating the variables of the scheduled
/// base node behind `node_ref` to the node at loop-tree position `ref`.
///
/// Only `view` nodes carry constraints.  Each view met while walking from the
/// base node towards the node at `ref` contributes equations of the form
/// `input_sym = expr`, where `expr` mentions none of that view's input symbols.
/// These are substituted into one another so that every constrained base-node
/// symbol ends up expressed in terms of symbols further down the chain.
///
/// @param node_ref IR node whose read is being resolved (presumably an input
///                 of the node at `ref` -- not checked here).
/// @param ref      Loop-tree position of the node performing the read.
/// @return If `node_ref` equals the resolved base node: the constraints of the
///         node at `ref` itself (non-empty only when that node is a view).
///         Otherwise: one `(Expr(base_sym), expr)` pair per base-node symbol
///         constrained by some view on the walked path.
std::vector<symbolic::Constraint> Compiler::gen_constraints(
    IR::NodeRef node_ref, LoopTree::TreeRef ref) const {
  // Find a route to a scheduled base node.  If the node at `ref` is itself a
  // view, resolve through its single input instead.
  auto base_node_ref = resolved_reads.at(node_ref);

  const auto dest_node_ref = lt.node(ref);
  const auto &node = lt.ir.node(dest_node_ref);
  if (node.op() == Operation::view) {
    ASSERT(node.inputs().size() == 1);
    base_node_ref = resolved_reads.at(node.inputs().at(0));
  }

  auto to_syms = [&](const std::vector<IR::VarRef> &vs) {
    std::vector<Symbol> syms;
    syms.reserve(vs.size());
    for (auto v : vs) {
      syms.emplace_back(var_to_sym.at(v));
    }
    return syms;
  };

  // Only view nodes have constraints.  We keep the ones of the form
  // `input_sym = expr` where `expr` contains no input symbols, i.e. those that
  // map an input variable onto the view's output side.
  auto get_constraints = [&](IR::NodeRef view_ref) {
    std::unordered_map<Symbol, Expr, Hash<Symbol>> out;
    const auto &view = lt.ir.node(view_ref);
    if (view.op() != Operation::view) {
      return out;
    }
    ASSERT(view.inputs().size() == 1);
    auto input_vars = lt.ir.node(view.inputs().at(0)).vars();
    auto input_syms = to_set<Symbol, Hash>(to_syms(input_vars));
    for (const auto &c : view.constraints()) {
      if (c.first.type() != Expr::Type::symbol) {
        continue;
      }
      auto sym = c.first.symbol();
      if (!input_syms.count(sym)) {
        continue;
      }
      auto expr = c.second;
      bool mentions_input = false;
      for (auto s : expr.symbols()) {
        if (input_syms.count(s)) {
          mentions_input = true;
          break;
        }
      }
      if (!mentions_input) {
        out.emplace(sym, expr);
      }
    }
    return out;
  };

  // Eager exit: nothing to walk, report the constraints of the node at `ref`.
  if (node_ref == base_node_ref) {
    std::vector<Constraint> constraints;
    for (const auto &c : get_constraints(dest_node_ref)) {
      constraints.emplace_back(std::make_pair(Expr(c.first), c.second));
    }
    return constraints;
  }

  // Only constraints whose left-hand side is a base-node symbol are reported.
  std::unordered_set<Symbol, Hash<Symbol>> base_syms;
  for (auto v : lt.ir.node(base_node_ref).vars()) {
    if (var_to_sym.count(v)) {
      base_syms.insert(var_to_sym.at(v));
    }
  }

  // Path from the base node towards the node at `ref` (destination excluded),
  // built by repeatedly following the first output.
  // NOTE(review): only outputs().at(0) is followed.  If a node on the way fans
  // out and the destination lies behind a later output, the walk runs on to a
  // sink and the path contains unrelated nodes -- confirm the IR rules this
  // out.
  std::vector<IR::NodeRef> path;
  for (auto cur = base_node_ref; cur != dest_node_ref;) {
    path.emplace_back(cur);
    const auto &outputs = lt.ir.node(cur).outputs();
    if (outputs.empty()) {
      break;
    }
    cur = outputs.at(0);
  }

  // Node by node, starting at the base, record constraints on base symbols and
  // substitute each newly seen `input_sym = expr` into everything collected so
  // far.
  std::unordered_map<Symbol, Expr, Hash<Symbol>> out_cs;
  for (auto cur_node_ref : path) {
    auto cs = get_constraints(cur_node_ref);
    for (const auto &c : cs) {
      if (base_syms.count(c.first)) {
        ASSERT(!out_cs.count(c.first));
        out_cs.emplace(c.first, c.second);
      }
      for (auto &p : out_cs) {
        p.second = p.second.replace(c.first, c.second);
      }
    }
  }

  std::vector<Constraint> constraints;
  constraints.reserve(out_cs.size());
  for (const auto &c : out_cs) {
    constraints.emplace_back(std::make_pair(Expr(c.first), c.second));
  }
  return constraints;
}