in src/frontends/lazy.cpp [144:243]
// Lowers this lazy tensor (and, recursively, every tensor it depends on) into
// `ir`, returning the IR node that produces this tensor's value.
//
// ir       - IR under construction; nodes, vars and inputs are appended to it.
// var_map  - symbol id -> (IR var, concrete size). It is threaded through the
//            recursive calls, so every occurrence of a symbol id reuses the
//            same IR var. Shape symbols seen for the first time are added.
// impl_map - memo of tensors already lowered; a TensorImpl reachable through
//            several paths yields the same node each time.
//
// Fails an ASSERT if a new shape symbol has no size constraint, if that
// constraint cannot be evaluated, if a rename does not have exactly one
// input, or if op_ is not one of the operations handled below.
IR::NodeRef TensorImpl::resolve(
    IR& ir, std::unordered_map<int, std::pair<IR::VarRef, int64_t>>& var_map,
    std::unordered_map<const TensorImpl*, IR::NodeRef>& impl_map) const {
  // Single lookup into the memo instead of count() followed by at().
  const auto cached = impl_map.find(this);
  if (cached != impl_map.end()) {
    return cached->second;
  }

  // Lower dependencies first so their nodes exist (and their shape symbols
  // are already bound in var_map) before this node is created.
  std::vector<IR::NodeRef> node_deps;
  node_deps.reserve(deps_.size());
  for (const auto& d : deps_) {
    node_deps.emplace_back(d->resolve(ir, var_map, impl_map));
  }

  // Bind each shape symbol to an IR var. A symbol not yet in var_map must
  // have a size constraint that evaluates to a concrete size; a fresh var
  // named "<name>_<id>" is created for it and recorded with that size.
  std::vector<IR::VarRef> vars;
  for (const auto& s : shape()) {
    auto entry = var_map.find(s.id());
    if (entry == var_map.end()) {
      ASSERT(size_constraints().count(s.id()))
          << "unbound variable in compute " << s.name() << " (id: " << s.id()
          << ")";
      auto expr = size_constraints().at(s.id());
      ASSERT(expr.can_evaluate()) << "can't resolve size";
      const auto size = static_cast<int64_t>(expr.evaluate());
      std::stringstream s_name;
      s_name << s.name() << "_" << s.id();
      auto var = ir.create_var(s_name.str());
      entry = var_map.emplace(s.id(), std::make_pair(var, size)).first;
    }
    vars.emplace_back(entry->second.first);
  }

  // Snapshot of every symbol bound so far (not only this tensor's shape);
  // used to filter constraints and handed to view nodes.
  std::unordered_map<int, IR::VarRef> sym_var_map;
  sym_var_map.reserve(var_map.size());
  for (const auto& p : var_map) {
    sym_var_map.emplace(p.first, p.second.first);
  }

  // True when every symbol in `e` already has an IR var.
  const auto all_bound = [&](const Expr& e) {
    for (const auto& sym : e.symbols()) {
      if (!sym_var_map.count(sym.id())) {
        return false;
      }
    }
    return true;
  };

  // Keep only constraints whose two sides reference bound symbols exclusively.
  std::vector<Constraint> node_constraints;
  for (const auto& c : constraints_) {
    if (all_bound(c.first) && all_bound(c.second)) {
      node_constraints.emplace_back(c);
    }
  }

  IR::NodeRef node_ref = -1;
  switch (op_) {
    case Operation::name:
      // A rename creates no IR node of its own; it aliases its input's node.
      ASSERT(node_deps.size() == 1) << "invalid rename (only 1 input allowed)";
      node_ref = node_deps[0];
      break;
    case Operation::view:
      // Only view nodes receive the filtered constraints and symbol map.
      node_ref = ir.create_node(Operation::view, node_deps, vars,
                                node_constraints, sym_var_map);
      break;
    case Operation::constant:
      // A constant becomes a dependency-free read node registered as an input.
      node_ref = ir.create_node(Operation::read, {}, vars);
      ir.add_input(node_ref);
      break;
    case Operation::add:
    case Operation::subtract:
    case Operation::multiply:
    case Operation::divide:
    case Operation::max:
    case Operation::exp:
    case Operation::sqrt:
    case Operation::reciprocal:
    case Operation::negate:
      // These map one-to-one onto an IR node with the same operation.
      node_ref = ir.create_node(op_, node_deps, vars);
      break;
    default:
      break;
  }
  ASSERT(node_ref > -1) << "couldn't resolve node op: " << dump(op_);
  impl_map.emplace(this, node_ref);
  return node_ref;
}