std::string gen_guard()

in src/backends/cuda/cuda.cpp [461:524]


// Builds the CUDA source text of a boolean guard expression for tree node
// `ref`, or an empty string when no guard is needed.
//
// - ref == -1: returns "".
// - LOOP node: emits a guard only when both this loop and its parent are
//   threaded and the parent's thread count differs from
//   `loop.size * threaded[ref]`. In that case the guard is
//   `((_tid / expected_inner) % mod == 0)`.
// - Any other node: walks every ancestor loop and, for each threaded one:
//     * if its variable is NOT among the node's loop vars, adds
//       `(<var_name>_<var_depth> == 0)`;
//     * for the first (innermost) threaded ancestor whose variable IS used,
//       adds `(_tid % inner == 0)` when inner != 1.
//   Conditions are joined with " && ".
//
// The rationale (from the original design notes): threads in sibling trees
// are necessarily fewer than those of our first threaded parent, so a single
// `_tid % i == 0` check against that parent suffices for the vars we use,
// while threaded vars we do not use are pinned to 0.
// NOTE(review): `cuda_aux.threaded.at(x)` presumably holds the thread count
// spanned by loops nested inside x — confirm against CudaAux construction.
//
// `aux` and `unroll` are not read by this function.
std::string gen_guard(const LoopTree &lt, const Auxiliary &aux,
                      const CudaAux &cuda_aux, UnrollMap &unroll,
                      LoopTree::TreeRef ref) {
  std::stringstream ss;
  if (ref == -1) {
    return ss.str();
  }
  if (lt.kind(ref) == LoopTree::LOOP) {
    if (cuda_aux.threaded.count(ref)) {
      auto inner = cuda_aux.threaded.at(ref);
      auto parent = lt.parent(ref);
      auto loop = lt.loop(ref);
      auto expected_inner = loop.size * inner;
      if (cuda_aux.threaded.count(parent)) {
        auto outer = cuda_aux.threaded.at(parent);
        if (outer != expected_inner) {
          // NOTE(review): integer division; if outer is not a multiple of
          // expected_inner the remainder is silently dropped — confirm intent.
          auto mod = outer / expected_inner;
          ASSERT(mod != 0)
              << "Unexpected threading mismatch cannot be reconciled";
          ss << "((_tid / " << expected_inner << ") % " << mod << " == 0)";
        }
      }
    }
    return ss.str();
  }

  // Tracks whether a condition has been written yet, so the separator can be
  // emitted without copying the stream's buffer via ss.str() each time.
  bool has_cond = false;
  // Writes " && " before every condition except the first and returns the
  // stream so the caller can append the condition text.
  auto next_cond = [&]() -> std::stringstream & {
    if (has_cond) {
      ss << " && ";
    }
    has_cond = true;
    return ss;
  };

  auto vs = lt.ir.loop_vars(lt.node(ref));
  std::unordered_set<IR::VarRef> vars(vs.begin(), vs.end());
  bool first_parent = false;
  for (auto parent = lt.parent(ref); parent != -1; parent = lt.parent(parent)) {
    if (!cuda_aux.threaded.count(parent)) {
      continue;
    }
    auto loop = lt.loop(parent);
    auto v = loop.var;
    if (vars.count(v) == 0) {
      // A threaded variable this node does not depend on: only the thread
      // with index 0 along it should execute the node.
      next_cond() << "(" << lt.ir.var(v).name() << "_" << loop.var_depth
                  << " == 0)";
    } else if (!first_parent) {
      first_parent = true;
      auto inner = cuda_aux.threaded.at(parent);
      if (inner != 1) {
        next_cond() << "(_tid % " << inner << " == 0)";
      }
    }
  }
  return ss.str();
}