in src/core/compile.cpp [1052:1102]
// Builds a callable that re-initializes the buffers associated with `ref`.
//
// Every entry of `allocations` whose `lca` equals `ref` is inspected; based on
// the operation of the IR node it belongs to, a (memory slot, element count,
// fill value) record is collected. The returned callable writes the fill value
// into each recorded buffer, treating the buffer as an array of `float`.
//
// Fill values by operation:
//   add, subtract     -> 0
//   multiply, divide  -> 1
//   max               -> lowest finite float
//   read/write/view and the unary ops (exp, sqrt, negate, reciprocal) are
//   not reset at all.
// Any other operation trips the ASSERT below.
//
// NOTE(review): the fill values look like the starting values for an
// accumulation of the given kind -- confirm against the code that consumes
// these buffers.
InnerFnTypeImproved Compiler::gen_reset(LoopTree::TreeRef ref) const {
  // One record per buffer to reset: (index into `memory`, size, fill value).
  std::vector<std::tuple<int, int64_t, float>> resets;
  for (const auto &p : allocations) {
    const auto &alloc = p.second;
    if (alloc.lca != ref) {
      continue;
    }
    const auto &node = lt.ir.node(alloc.node_ref);
    switch (node.op()) {
    case Operation::add:
    case Operation::subtract:
      resets.emplace_back(alloc.mem_idx, alloc.size(), 0.0f);
      break;
    case Operation::multiply:
    case Operation::divide:
      resets.emplace_back(alloc.mem_idx, alloc.size(), 1.0f);
      break;
    case Operation::max:
      // lowest() is the most negative finite float, i.e. the same value as
      // -max(), so any finite operand compares >= the initial contents.
      resets.emplace_back(alloc.mem_idx, alloc.size(),
                          std::numeric_limits<float>::lowest());
      break;
    // memory ops
    case Operation::read:
    case Operation::write:
    case Operation::view:
    // unary ops
    case Operation::exp:
    case Operation::sqrt:
    case Operation::negate:
    case Operation::reciprocal:
      // Nothing to reset for these operations.
      break;
    default:
      ASSERT(0) << "cannot generate reset for op: "
                << lt.ir.dump(alloc.node_ref);
    }
  }
  // The closure owns its own copy of `resets`; nothing else is captured.
  // The second parameter (loop indices) is unused by the reset itself.
  return [resets](const std::vector<void *> &memory, int /*indices*/[MAX_DEPTH]) {
    for (const auto &reset : resets) {
      const int64_t count = std::get<1>(reset);
      if (count <= 0) {
        // Empty buffer: do not touch `memory` at all.
        continue;
      }
      // Hoist the loop-invariant lookups out of the fill loop.
      float *const buffer = static_cast<float *>(memory[std::get<0>(reset)]);
      const float value = std::get<2>(reset);
      for (int64_t i = 0; i < count; ++i) {
        buffer[i] = value;
      }
    }
  };
}