in src/core/compile.cpp [597:683]
// Builds the executable closure for a loop node whose full iterations are run
// on separate threads.
//
// A callable is generated for every child of the node (via gen_fn) together
// with a memory callable (via gen_mem). The returned function starts one
// std::thread per full iteration, joins all of them, and then runs the tail
// iteration (if there is one) on the calling thread.
//
// Parameters:
//   lt       - loop tree that contains the node `ref`.
//   aux      - auxiliary tables; `var_idx`, `inner_size` and `thread_memory`
//              are read here, and `aux` is forwarded to gen_fn / gen_mem.
//   ref      - the loop node to generate code for.
//   callback - forwarded unchanged to gen_fn for each child.
//
// Returns a function of (memory, indices, tails):
//   * worker threads operate on private copies of `indices` and `tails`, so
//     the caller's arrays are not written by them;
//   * the tail iteration writes `indices[depth]` and `tails[var_idx]` in the
//     caller's arrays directly.
//
// NOTE(review): an exception escaping a child callable inside a worker thread
// is not caught here and therefore still ends in std::terminate.
InnerFnType gen_parallel_loop(const LoopTree &lt, const Auxiliary &aux,
                              LoopTree::TreeRef ref,
                              const GenFnType &callback) {
  auto tree_node = lt.tree_node(ref);
  auto depth = tree_node.depth;
  auto loop = tree_node.loop;
  auto size = loop.size;
  auto tail_size = loop.tail;
  auto var_idx = aux.var_idx.at(loop.var);

  ASSERT(size > 0);
  ASSERT(tail_size >= 0);
  std::vector<InnerFnType> fns;
  for (auto c : tree_node.children) {
    fns.emplace_back(gen_fn(lt, aux, c, callback));
  }

  auto inner_size = aux.inner_size.at(ref);
  auto memory_fn = gen_mem(lt, aux, ref);

  // Threaded memory entries are indexed relative to the end of the IR's
  // inputs and outputs within the flat memory vector.
  auto alloc_off = lt.ir.inputs().size() + lt.ir.outputs().size();

  // Resolve the (index, stride) list for this node once, at generation time.
  // Capturing only this list keeps the closures from carrying a copy of the
  // whole Auxiliary object and removes a map lookup from every invocation.
  // The list stays empty when no memory is threaded over this loop.
  auto thread_it = aux.thread_memory.find(ref);
  decltype(thread_it->second) thread_strides{};
  if (thread_it != aux.thread_memory.end()) {
    thread_strides = thread_it->second;
  }

  // Returns a copy of `memory_` in which every threaded pointer has been
  // advanced by `i * stride`. The arithmetic is done on float pointers, i.e.
  // strides are counted in floats.
  auto offset_memory = [alloc_off, thread_strides](
                           const std::vector<void *> &memory_, int i) {
    auto memory = memory_;
    for (const auto &p : thread_strides) {
      auto mem_idx = alloc_off + p.first;
      auto fmem = static_cast<float *>(memory[mem_idx]);
      memory[mem_idx] = fmem + i * p.second;
    }
    return memory;
  };

  return [=](const std::vector<void *> &memory_, int indices[MAX_DEPTH],
             int tails[MAX_DEPTH]) {
    auto run = [&](int n_size, int t_size) {
      std::vector<std::thread> threads;
      if (n_size > 0) {
        // Avoid reallocating (and moving live threads) while spawning.
        threads.reserve(n_size);
      }
      auto join_all = [&threads]() {
        for (auto &t : threads) {
          if (t.joinable()) {
            t.join();
          }
        }
      };

      try {
        for (auto i = 0; i < n_size; ++i) {
          auto memory = offset_memory(memory_, i);
          threads.emplace_back([=]() {
            // Each worker gets its own index/tail state; the caller's arrays
            // are only read here and are not modified until after the join.
            int indices_[MAX_DEPTH];
            std::copy(indices, indices + MAX_DEPTH, indices_);
            int tails_[MAX_DEPTH];
            std::copy(tails, tails + MAX_DEPTH, tails_);
            memory_fn(memory);
            for (const auto &fn : fns) {
              // Re-established before every child call.
              indices_[depth] = i;
              tails_[var_idx] = 0;
              fn(memory, indices_, tails_);
            }
          });
        }
      } catch (...) {
        // Thread creation (or memory setup) failed part-way through.
        // Destroying a joinable std::thread calls std::terminate, so wait for
        // the workers that did start before letting the error propagate.
        join_all();
        throw;
      }
      join_all();

      if (t_size) {
        // The partial trailing iteration runs on the calling thread.
        auto memory = offset_memory(memory_, n_size);
        memory_fn(memory);
        for (const auto &fn : fns) {
          indices[depth] = n_size;
          tails[var_idx] = t_size;
          fn(memory, indices, tails);
        }
      }
    };

    // A non-zero incoming tail for this variable overrides the loop's static
    // extent: it is split into full iterations plus a remainder using
    // `inner_size`.
    auto tail = tails[var_idx];
    if (tail) {
      auto N = tail / inner_size;
      auto T = tail % inner_size;
      run(N, T);
      return;
    }

    run(size, tail_size);
  };
}