InnerFnTypeImproved Compiler::gen_reset()

in src/core/compile.cpp [1052:1102]


// Builds the callable that re-initializes the buffers associated with `ref`.
//
// Every entry of `allocations` whose `lca` equals `ref` is inspected, and the
// op of its IR node selects the value its buffer is filled with:
//   add, subtract     -> 0
//   multiply, divide  -> 1
//   max               -> -std::numeric_limits<float>::max()
// (presumably the starting value of the corresponding accumulation -- confirm
// against the code that consumes these buffers).
// Memory ops (read/write/view) and unary ops (exp/sqrt/negate/reciprocal)
// record no reset. Any other op trips the ASSERT in the default case.
//
// The returned function, for each recorded reset, writes the fill value to
// the first `size` floats of `memory[mem_idx]`. It does not read `indices`.
InnerFnTypeImproved Compiler::gen_reset(LoopTree::TreeRef ref) const {
  // Each entry is (memory index, element count, fill value).
  std::vector<std::tuple<int, int64_t, float>> resets;
  for (const auto &p : allocations) {
    const auto &alloc = p.second;
    if (alloc.lca != ref) {
      continue;
    }
    const auto &node = lt.ir.node(alloc.node_ref);
    switch (node.op()) {
      case Operation::add:
      case Operation::subtract:
        resets.emplace_back(alloc.mem_idx, alloc.size(), 0.0f);
        break;
      case Operation::multiply:
      case Operation::divide:
        resets.emplace_back(alloc.mem_idx, alloc.size(), 1.0f);
        break;
      case Operation::max:
        resets.emplace_back(alloc.mem_idx, alloc.size(),
                            -std::numeric_limits<float>::max());
        break;
      // memory ops
      case Operation::read:
      case Operation::write:
      case Operation::view:
      // unary ops
      case Operation::exp:
      case Operation::sqrt:
      case Operation::negate:
      case Operation::reciprocal:
        break;
      default:
        ASSERT(0) << "cannot generate reset for op: "
                  << lt.ir.dump(alloc.node_ref);
    }
  }
  // Only `resets` is needed by the closure, so capture it explicitly (by
  // copy, since the closure outlives this call).
  return [resets](const std::vector<void *> &memory, int indices[MAX_DEPTH]) {
    for (const auto &reset : resets) {
      // Read the tuple fields once per buffer rather than once per element.
      float *dst = static_cast<float *>(memory[std::get<0>(reset)]);
      const int64_t count = std::get<1>(reset);
      const float value = std::get<2>(reset);
      for (int64_t i = 0; i < count; ++i) {
        dst[i] = value;
      }
    }
  };
}