IR::NodeRef TensorImpl::resolve()

in src/frontends/lazy.cpp [144:243]


// Lowers this lazily-built tensor into nodes of `ir`, recursively lowering
// every dependency first.
//
//   ir        IR under construction; new vars and nodes are appended to it.
//   var_map   symbol id -> (IR var, concrete size). Shared across the whole
//             traversal, so a symbolic dimension gets an IR var only the
//             first time it is encountered.
//   impl_map  TensorImpl -> NodeRef memo. A tensor reachable along several
//             paths is lowered once; later visits return the cached node.
//
// Returns the NodeRef that represents this tensor in `ir`. Fails (via
// ASSERT) on an unbound or non-constant dimension size, on a rename with
// other than exactly one input, or on an op this function cannot lower.
IR::NodeRef TensorImpl::resolve(
    IR& ir, std::unordered_map<int, std::pair<IR::VarRef, int64_t>>& var_map,
    std::unordered_map<const TensorImpl*, IR::NodeRef>& impl_map) const {
  // Already lowered on an earlier visit: reuse that node.
  const auto cached = impl_map.find(this);
  if (cached != impl_map.end()) {
    return cached->second;
  }

  // Inputs are lowered before this node is created, in deps_ order.
  std::vector<IR::NodeRef> inputs;
  for (const auto& dep : deps_) {
    inputs.push_back(dep->resolve(ir, var_map, impl_map));
  }

  // One IR var per dimension of this tensor's shape. A symbol not yet in
  // var_map must have a size constraint that evaluates to a constant; its
  // var is named "<symbol name>_<symbol id>".
  std::vector<IR::VarRef> vars;
  for (const auto& dim : shape()) {
    const auto dim_id = dim.id();
    if (var_map.count(dim_id) == 0) {
      ASSERT(size_constraints().count(dim_id))
          << "unbound variable in compute " << dim.name() << " (id: " << dim_id
          << ")";
      auto size_expr = size_constraints().at(dim_id);
      ASSERT(size_expr.can_evaluate()) << "can't resolve size";
      const auto dim_size = static_cast<int64_t>(size_expr.evaluate());
      std::stringstream var_name;
      var_name << dim.name() << "_" << dim_id;
      const auto new_var = ir.create_var(var_name.str());
      ASSERT(var_map.count(dim_id) == 0);
      var_map[dim_id] = std::make_pair(new_var, dim_size);
    }
    vars.emplace_back(var_map.at(dim_id).first);
  }

  // Symbol id -> IR var for every symbol bound so far (not only this
  // tensor's dimensions); handed to view nodes and used to filter constraints.
  std::unordered_map<int, IR::VarRef> bound_syms;
  for (const auto& entry : var_map) {
    bound_syms[entry.first] = entry.second.first;
  }

  // True when every symbol appearing in `e` already has an IR var.
  const auto fully_bound = [&](const Expr& e) {
    for (const auto& sym : e.symbols()) {
      if (bound_syms.count(sym.id()) == 0) {
        return false;
      }
    }
    return true;
  };

  // Keep only the constraints that can be expressed with bound vars.
  std::vector<Constraint> usable_constraints;
  for (const auto& c : constraints_) {
    if (fully_bound(c.first) && fully_bound(c.second)) {
      usable_constraints.emplace_back(c);
    }
  }

  IR::NodeRef result = -1;
  switch (op_) {
    case Operation::name:
      // A rename introduces no IR node; it forwards its single input.
      ASSERT(inputs.size() == 1) << "invalid rename (only 1 input allowed)";
      result = inputs[0];
      break;
    case Operation::view:
      result = ir.create_node(Operation::view, inputs, vars,
                              usable_constraints, bound_syms);
      break;
    case Operation::constant:
      // Constants enter the IR as input-less read nodes registered as inputs.
      result = ir.create_node(Operation::read, {}, vars);
      ir.add_input(result);
      break;
    case Operation::add:
    case Operation::subtract:
    case Operation::multiply:
    case Operation::divide:
    case Operation::max:
    case Operation::exp:
    case Operation::sqrt:
    case Operation::reciprocal:
    case Operation::negate:
      // Element-wise ops map 1:1 onto an IR node of the same operation.
      result = ir.create_node(op_, inputs, vars);
      break;
    default:
      // Leaves result at -1 so the ASSERT below reports the op.
      break;
  }
  ASSERT(result > -1) << "couldn't resolve node op: " << dump(op_);
  impl_map.emplace(this, result);
  return result;
}