in src/core/compile.cpp [1723:1810]
// Compute the memory slot and buffer shape used to store node `node_ref`.
//
// Memory index (mem_idx):
//   * node is graph input i   -> i
//   * node is graph output i  -> inputs.size() + i (checked after inputs, so
//                                this wins if the node is also an input)
//   * otherwise               -> inputs.size() + outputs.size() plus the
//                                number of entries in `allocations` whose
//                                mem_idx is already at or past that point.
//                                NOTE(review): this assumes those non-I/O
//                                slots are packed contiguously — confirm.
//
// Sizes / LCA:
//   * Unscheduled read/write nodes: one size per node var, taken from the
//     `var_sizes` member; LCA is -1.
//   * Other unscheduled nodes: Allocation(mem_idx, node_ref) with no sizes.
//   * Scheduled nodes: walk the loop tree from the parent of the node's tree
//     ref up to (not including) `lca`, accumulating each crossed loop into
//     its variable's extent as `extent * size + tail`. Variables never
//     crossed get size 0.
//       - For read/write nodes, `lca` is -1, so the walk ends once parent()
//         yields -1 (presumably past the root — confirm against LoopTree).
//       - Otherwise `lca` is the lowest common ancestor of that parent and
//         every scheduled consumer reached through outputs(). Unscheduled
//         intermediates are looked through; an unscheduled write consumer
//         contributes -1.
Compiler::Allocation Compiler::gen_alloc(IR::NodeRef node_ref) const {
  const auto &inputs = lt.ir.inputs();
  const auto &outputs = lt.ir.outputs();
  // Number of slots reserved for graph-level I/O. Kept signed so it can be
  // compared with (signed) mem_idx values without unsigned wrap-around.
  const int num_io = static_cast<int>(inputs.size() + outputs.size());

  int mem_idx = -1;
  // Last match wins; outputs are scanned after inputs, so an output slot
  // takes precedence over an input slot for the same node.
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    if (inputs.at(i) == node_ref) {
      mem_idx = static_cast<int>(i);
    }
  }
  for (std::size_t i = 0; i < outputs.size(); ++i) {
    if (outputs.at(i) == node_ref) {
      mem_idx = static_cast<int>(i + inputs.size());
    }
  }
  // Not a graph input/output: take the next free slot after the I/O slots.
  if (mem_idx == -1) {
    mem_idx = num_io;
    for (const auto &p : allocations) {
      // These allocations already occupy a non-I/O slot. Signed comparison:
      // a negative mem_idx must not be counted as an occupied slot.
      if (p.second.mem_idx >= num_io) {
        mem_idx++;
      }
    }
  }

  const auto &node = lt.ir.node(node_ref);
  const bool is_io_op =
      node.op() == Operation::read || node.op() == Operation::write;

  if (!lt.scheduled.count(node_ref)) {
    if (is_io_op) {
      // Full-size buffer: one entry per node var, from the `var_sizes` member.
      std::vector<int64_t> sizes;
      for (auto v : node.vars()) {
        sizes.emplace_back(var_sizes.at(v));
      }
      return Allocation(mem_idx, node_ref, sizes, -1);
    }
    return Allocation(mem_idx, node_ref);
  }

  // Collect the loop-tree refs of the scheduled nodes adjacent to `nr`.
  // - io_switch == true follows inputs(); false follows outputs().
  // - Unscheduled neighbours are looked through recursively.
  // - An unscheduled write contributes -1 instead of being looked through.
  std::function<std::vector<LoopTree::TreeRef>(IR::NodeRef nr, bool io_switch)>
      get_scheduled_deps;
  get_scheduled_deps = [&](IR::NodeRef nr,
                           bool io_switch) -> std::vector<LoopTree::TreeRef> {
    auto &n = lt.ir.node(nr);
    std::vector<LoopTree::TreeRef> dep_refs;
    for (const auto &dep_ref : (io_switch ? n.inputs() : n.outputs())) {
      if (!lt.scheduled.count(dep_ref)) {
        if (lt.ir.node(dep_ref).op() == Operation::write) {
          dep_refs.emplace_back(-1);
          continue;
        }
        for (auto dep : get_scheduled_deps(dep_ref, io_switch)) {
          dep_refs.emplace_back(dep);
        }
      } else {
        dep_refs.emplace_back(lt.scheduled.at(dep_ref));
      }
    }
    return dep_refs;
  };

  auto ref = lt.parent(lt.scheduled.at(node_ref));
  auto lca = ref;
  if (is_io_op) {
    // Reads/writes always span all enclosing loops; no need to consult the
    // consumers (the original computed their LCA and then discarded it).
    lca = -1;
  } else {
    for (auto tr : get_scheduled_deps(node_ref, false)) {
      lca = lt.lca(lca, tr);
    }
  }

  // Per-variable extent of the loops between the node and `lca`. Named
  // distinctly so it does not shadow the `var_sizes` member used above.
  std::unordered_map<IR::VarRef, int64_t> loop_extents;
  while (ref != lca) {
    auto loop = lt.loop(ref);
    ref = lt.parent(ref);
    // Walking innermost -> outermost: an outer loop of `size` iterations
    // (plus `tail`) repeats the inner extent.
    auto &extent = loop_extents.emplace(loop.var, 1).first->second;
    extent = extent * loop.size + loop.tail;
  }

  std::vector<int64_t> sizes;
  for (auto v : node.vars()) {
    auto it = loop_extents.find(v);
    sizes.emplace_back(it != loop_extents.end() ? it->second : 0);
  }
  return Allocation(mem_idx, node_ref, sizes, lca);
}