std::string gen_mem_decl()

in src/backends/cuda/cuda.cpp [261:325]


/// Generates CUDA source text that declares (and optionally initializes) the
/// buffers registered in `aux.resets` for tree node `ref`.
///
/// For every allocation whose producer IR node has at least one output, one
/// line is emitted at indentation level `depth + 1`:
///
///   [__shared__ ][float4 |float ]mem_<idx>[<numel>];
///
/// followed, when `alloc.should_init` is set, by a loop that stores
/// `alloc.init_val` into every element of the buffer.
///
/// @param lt        loop tree; supplies node depth, parent links and IR nodes
/// @param aux       auxiliary data; only `aux.resets` is read here
/// @param cuda_aux  CUDA-specific data; only `cuda_aux.threaded` is read here
/// @param ref       tree node whose allocations are emitted; any value that is
///                  not `> -1` is treated as depth 0
/// @param declare   when false the element type keyword (`float4`/`float`) is
///                  omitted from the emitted text
/// @return the generated text; empty when `ref` has no entry in `aux.resets`
std::string gen_mem_decl(const LoopTree &lt, const Auxiliary &aux,
                         const CudaAux &cuda_aux, LoopTree::TreeRef ref,
                         bool declare = true) {
  std::stringstream ss;
  auto depth = ref > -1 ? lt.tree_node(ref).depth : 0;

  // Single lookup, and no copy of the allocation list.
  const auto resets_it = aux.resets.find(ref);
  if (resets_it == aux.resets.end()) {
    return ss.str();
  }

  // An allocation is emitted as __shared__ when any loop strictly between its
  // producer and its LCA is threaded.
  // we can traverse producer to LCA and check for threadedness
  // This is true because of the threading self-consistency invariant
  // If need-be this could change, but it gets really messy
  // NOTE(review): the walk assumes `alloc.lca` is reached by following parent
  // links from the producer -- confirm against the code that builds `resets`.
  // NOTE: a disabled assertion used to sit here recording a "known shared
  // memory issue when it's not at global top level scope" (i.e. the walk
  // found a threaded loop and stopped at a node other than -1).
  auto is_shared = [&](const Allocation &alloc) {
    for (auto p = lt.parent(alloc.producer); p != alloc.lca;
         p = lt.parent(p)) {
      if (cuda_aux.threaded.count(p)) {
        return true;
      }
    }
    return false;
  };

  for (const auto &alloc : resets_it->second) {
    // Producers without outputs get no buffer.
    if (lt.ir.node(lt.node(alloc.producer)).outputs().size() == 0) {
      continue;
    }

    const bool shared = is_shared(alloc);
    // always store to float4 if possible
    const bool vec4 = (alloc.size % 4 == 0);
    const auto numel = vec4 ? alloc.size / 4 : alloc.size;

    ss << indent(depth + 1);
    if (shared) {
      ss << "__shared__ ";
    }
    // NOTE(review): with declare == false the `[numel]` suffix (and any
    // `__shared__` prefix) is still emitted, exactly as before -- confirm
    // this is what callers passing declare=false expect.
    if (declare) {
      ss << (vec4 ? "float4 " : "float ");
    }
    ss << "mem_" << alloc.idx << "[" << numel << "]";

    if (!alloc.should_init) {
      ss << ";\n";
      continue;
    }

    // Fill every element explicitly. A brace initializer `= {v}` would set
    // only the first scalar to v and zero the remainder, which is wrong for
    // any non-zero init value; __shared__ variables additionally cannot carry
    // an initializer in their declaration, so one fill loop serves both cases.
    ss << "; for (int s_i = 0; s_i < " << numel << "; ++s_i) { ";
    ss << "mem_" << alloc.idx << "[s_i] = ";
    if (vec4) {
      ss << "make_float4(" << alloc.init_val << ", " << alloc.init_val << ", "
         << alloc.init_val << ", " << alloc.init_val << ")";
    } else {
      ss << alloc.init_val;
    }
    ss << "; };\n";
  }
  return ss.str();
}