std::string gen_loop()

in src/backends/cuda/cuda.cpp [343:458]


// Emits CUDA source for the loop node `ref` of `lt` and its whole subtree.
//
// The loop is emitted in up to two parts:
//   1. A main part covering `loop_size` iterations, rendered as one of:
//        - a fully unrolled sequence of copies (ref is in cuda_aux.unrolled),
//        - a thread-mapped scope whose index is derived from `_tid`
//          (ref is in cuda_aux.threaded), or
//        - a plain serial `for` loop.
//   2. A tail part, only when tail_size > 0. It declares the loop variable
//      with the value `loop_size` and emits the children once more. For a
//      threaded loop it is guarded so that only the `_tid` slot with index 0
//      executes it.
//
// If cuda_aux.tail holds a non-zero budget for this loop's variable, the main
// trip count and tail size are derived from that budget and inner_size. They
// are not taken from loop.size / loop.tail. The budget entry is then reduced
// by what this loop consumes.
//
// NOTE: although `cuda_aux` is passed by const reference, this function
// mutates cuda_aux.tail (via const_cast) so that the recursive gen_cuda calls
// for the children observe the updated tail budget.
//
// Params:
//   lt       - loop tree being lowered.
//   aux      - auxiliary info; aux.inner_size.at(ref) is read for this loop.
//   cuda_aux - CUDA-specific info: unrolled/threaded sets and tail budgets.
//   unroll   - unroll assignments keyed by (var, var_depth). The entry for
//              this loop is set while its unrolled copies are emitted and
//              erased afterwards.
//   ref      - loop node to emit.
// Returns the generated source text.
std::string gen_loop(const LoopTree &lt, const Auxiliary &aux,
                     const CudaAux &cuda_aux, UnrollMap &unroll,
                     LoopTree::TreeRef ref) {
  std::stringstream ss;
  const auto depth = lt.tree_node(ref).depth;
  const auto &loop = lt.loop(ref);
  auto v = lt.ir.var(loop.var).name();

  // Mutable view of the tail budgets; see the NOTE in the header comment.
  auto &tail_budgets = const_cast<CudaAux &>(cuda_aux).tail;

  // First emit the main loop, then the tail loop (if there is a tail).
  const auto tail_it = cuda_aux.tail.find(loop.var);
  const bool is_tail = tail_it != cuda_aux.tail.end() && tail_it->second;
  const auto inner_size = aux.inner_size.at(ref);

  int loop_size = loop.size;
  int tail_size = loop.tail;
  if (is_tail) {
    ASSERT(inner_size > 0) << "Tail budget needs a positive inner size";
    loop_size = tail_it->second / inner_size;
    tail_size = tail_it->second % inner_size;
  }
  ASSERT(loop.size >= loop_size);

  // Amount of the tail budget this loop uses up (only meaningful if is_tail).
  const auto consumed =
      (is_tail && loop_size) ? loop_size * inner_size + tail_size : 0;
  if (is_tail) {
    tail_budgets[loop.var] = tail_it->second - consumed;
  }

  std::stringstream v_ss;
  v_ss << v << "_" << loop.var_depth;
  const auto v_str = v_ss.str();

  // Appends the generated code for every child of this loop node.
  auto emit_children = [&]() {
    for (auto c : lt.tree_node(ref).children) {
      ss << gen_cuda(lt, aux, cuda_aux, unroll, c);
    }
  };

  // Emits a debugging comment describing the tail budget, if any.
  auto emit_tail_note = [&]() {
    if (is_tail) {
      ss << indent(depth) << "// tail! " << consumed << " consumed "
         << cuda_aux.tail.at(loop.var) << " left \n";
    }
  };

  const bool is_threaded = cuda_aux.threaded.count(ref) > 0;

  if (loop_size) {
    if (cuda_aux.unrolled.count(ref)) {
      ASSERT(loop_size > -1) << "Can only unroll sized loops";
      ASSERT(!is_threaded) << "Can only unroll non-threaded loops";
      const std::pair<IR::VarRef, int> key = {loop.var, loop.var_depth};

      emit_tail_note();
      ss << indent(depth) << "// unrolling " << v_str
         << (is_tail ? " (tail)" : "") << "\n";
      for (auto i = 0; i < loop_size; ++i) {
        // Each unrolled copy gets its own scope when it needs fresh
        // declarations, so the copies do not collide.
        const auto reset_str = gen_mem_decl(lt, aux, cuda_aux, ref);
        const bool scoped = reset_str.size() > 0;
        if (scoped) {
          ss << indent(depth) << "{\n";
          ss << reset_str;
        }
        unroll[key] = i;
        emit_children();
        if (scoped) {
          ss << indent(depth) << "}\n";
        }
      }
      unroll.erase(key);
    } else {
      // Only a threaded loop whose trip count was shortened needs an extra
      // bounds guard (and therefore an extra closing brace).
      const bool guarded = is_threaded && loop_size != loop.size;
      if (is_threaded) {
        const auto inner = cuda_aux.threaded.at(ref);
        ASSERT(inner > -1) << "Never calculated inner size of threaded loop";
        ss << indent(depth) << "{\n";
        ss << indent(depth) << "int " << v_str << " = (_tid / " << inner
           << ") % " << loop.size << ";\n";
        if (guarded) {
          ss << indent(depth) << "if (" << v_str << " < " << loop_size
             << ") {\n";
        }
      } else {
        emit_tail_note();
        ss << indent(depth) << "for (int " << v_str << " = 0;";
        ss << " " << v_str << " < " << loop_size << ";";
        ss << " ++" << v_str << ") {\n";
      }
      ss << gen_mem_decl(lt, aux, cuda_aux, ref);
      emit_children();
      ss << indent(depth) << "}\n";
      if (guarded) {
        ss << indent(depth) << "}\n";
      }
    }
  }

  if (tail_size > 0) {
    ss << indent(depth) << "// Tail logic for " << v_str << " (L" << ref
       << ")\n";
    if (is_threaded) {
      const auto inner = cuda_aux.threaded.at(ref);
      ss << indent(depth) << "if (0 == (_tid / " << inner << ") % "
         << loop.size << ") {\n";
    } else {
      ss << indent(depth) << "{\n";
    }
    ss << indent(depth) << "int " << v_str << " = " << loop_size << ";\n";
    ss << indent(depth) << "{ // tail\n";
    ss << gen_mem_decl(lt, aux, cuda_aux, ref);
    for (auto c : lt.tree_node(ref).children) {
      // Each child sees exactly this loop's tail size as its budget; the
      // budget is cleared again once the child has been emitted.
      tail_budgets[loop.var] = tail_size;
      ss << gen_cuda(lt, aux, cuda_aux, unroll, c);
      tail_budgets[loop.var] = 0;
    }
    ss << indent(depth) << "} // killing tail\n";
    ss << indent(depth) << "}\n";
  }
  return ss.str();
}