InnerFnType gen_parallel_loop()

in src/core/compile.cpp [597:683]


/// Generates the closure for a loop node whose full iterations run in
/// parallel.
///
/// Each full iteration `i` of the loop at `ref` runs on its own std::thread,
/// with private copies of `indices`/`tails`. It also gets a memory vector in
/// which every buffer listed in `aux.thread_memory` for this node has been
/// advanced by `i * stride` float elements. A trailing partial iteration (if
/// any) runs on the calling thread after all workers have been joined.
///
/// @param lt        loop tree being compiled.
/// @param aux       auxiliary compile data; `var_idx`, `inner_size` and
///                  `thread_memory` are read for this node.
/// @param ref       the loop node to generate.
/// @param callback  forwarded to gen_fn for every child of `ref`.
/// @return a closure `(memory, indices, tails)` that executes the whole loop.
///         It only reads the caller's `indices`/`tails` arrays; children
///         receive private copies. An exception escaping a child inside a
///         worker thread terminates the process (std::thread semantics).
InnerFnType gen_parallel_loop(const LoopTree &lt, const Auxiliary &aux,
                              LoopTree::TreeRef ref,
                              const GenFnType &callback) {
  const auto &tree_node = lt.tree_node(ref);
  const auto depth = tree_node.depth;
  const auto &loop = tree_node.loop;
  const auto size = loop.size;
  const auto tail_size = loop.tail;
  const auto var_idx = aux.var_idx.at(loop.var);

  ASSERT(size > 0);
  ASSERT(tail_size >= 0);

  std::vector<InnerFnType> fns;
  for (const auto &c : tree_node.children) {
    fns.emplace_back(gen_fn(lt, aux, c, callback));
  }

  const auto inner_size = aux.inner_size.at(ref);
  auto memory_fn = gen_mem(lt, aux, ref);

  // Allocations follow all IR inputs and outputs in the memory vector, so
  // thread_memory indices are relative to this offset.
  const auto alloc_off = lt.ir.inputs().size() + lt.ir.outputs().size();

  // (index, stride) pairs of buffers to offset per iteration. Copy just this
  // node's entry up front. The closures below then need not capture (and
  // thereby copy) all of `aux`, and no map lookup happens at run time. A node
  // without an entry gets an empty list, which means no offsetting.
  const auto thread_mem_it = aux.thread_memory.find(ref);
  decltype(thread_mem_it->second) thread_mem{};
  if (thread_mem_it != aux.thread_memory.end()) {
    thread_mem = thread_mem_it->second;
  }

  // Returns a copy of `memory_` with each threaded buffer advanced by
  // `i * stride` elements. The pointer arithmetic is on float*, so strides
  // are counted in float elements.
  auto offset_memory = [thread_mem, alloc_off](
                           const std::vector<void *> &memory_, int i) {
    auto memory = memory_;
    for (const auto &p : thread_mem) {
      const auto mem_idx = alloc_off + p.first;
      auto *fmem = static_cast<float *>(memory[mem_idx]);
      memory[mem_idx] = fmem + i * p.second;
    }
    return memory;
  };

  return [=](const std::vector<void *> &memory_, int indices[MAX_DEPTH],
             int tails[MAX_DEPTH]) {
    // Runs `n_size` full iterations concurrently, one std::thread each. Then,
    // if `t_size` != 0, runs one partial iteration of extent `t_size` on the
    // calling thread.
    auto run = [&](int n_size, int t_size) {
      std::vector<std::thread> threads;
      try {
        for (auto i = 0; i < n_size; ++i) {
          auto memory = offset_memory(memory_, i);
          // NOTE(review): `[=]` gives each worker its own deep copy of `fns`
          // and `memory_fn`. Capturing them by reference would be cheaper
          // (workers are joined before `run` returns), but it is only safe
          // if the child callables are stateless. Confirm before changing.
          threads.emplace_back([=]() {
            int indices_[MAX_DEPTH];
            std::copy(indices, indices + MAX_DEPTH, indices_);
            int tails_[MAX_DEPTH];
            std::copy(tails, tails + MAX_DEPTH, tails_);
            memory_fn(memory);
            for (const auto &fn : fns) {
              indices_[depth] = i;
              tails_[var_idx] = 0;
              fn(memory, indices_, tails_);
            }
          });
        }
      } catch (...) {
        // std::thread's constructor may throw std::system_error. Destroying a
        // still-joinable std::thread calls std::terminate, so join the
        // workers already started before propagating.
        for (auto &t : threads) {
          t.join();
        }
        throw;
      }
      for (auto &t : threads) {
        t.join();
      }

      if (t_size) {
        // Like the workers, use private copies so the caller's arrays stay
        // untouched. A leftover non-zero tails[var_idx] would otherwise be
        // seen by a later invocation, which checks it on entry.
        auto memory = offset_memory(memory_, n_size);
        int indices_[MAX_DEPTH];
        std::copy(indices, indices + MAX_DEPTH, indices_);
        int tails_[MAX_DEPTH];
        std::copy(tails, tails + MAX_DEPTH, tails_);
        memory_fn(memory);
        for (const auto &fn : fns) {
          indices_[depth] = n_size;
          tails_[var_idx] = t_size;
          fn(memory, indices_, tails_);
        }
      }
    };

    // A non-zero tails[var_idx] caps this loop at `tail` elements (presumably
    // set by an enclosing partial iteration over the same variable). Split it
    // into tail / inner_size full iterations plus a tail % inner_size
    // remainder.
    const auto tail = tails[var_idx];
    if (tail) {
      run(tail / inner_size, tail % inner_size);
      return;
    }

    run(size, tail_size);
  };
}