Stmt MakeAllreduce()

in src/tir/transforms/lower_thread_allreduce.cc [156:430]


  Stmt MakeAllreduce(const CallNode* call) {
    ICHECK(!reduce_combiner_.empty());
    const CommReducerNode* combiner = reduce_combiner_.back();
    size_t size = combiner->result.size();

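    // Argument layout of the tvm_thread_allreduce call, as consumed below:
    //   args[0]                      : number of reduction results (size)
    //   args[1 .. size]              : values to be reduced
    //   args[size + 1]               : reduction predicate
    //   args[size + 2 .. 2*size + 1] : destination buffer loads
    //   args[2*size + 2 ..]          : thread IterVars of the reduction domain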
    const IntImmNode* size_of_args = call->args[0].as<IntImmNode>();
    ICHECK(size_of_args) << call->args[0]->GetTypeKey();
    ICHECK_EQ(size, size_of_args->value);
    Array<PrimExpr> inits = combiner->identity_element;
    std::vector<PrimExpr> values(size);
    std::vector<DataType> types(size);
    PrimExpr cond = call->args[size + 1];
    for (size_t idx = 0; idx < size; ++idx) {
      values[idx] = call->args[1 + idx];
      if (!is_one(cond)) {
        values[idx] = Select(cond, values[idx], inits[idx]);
      }
      types[idx] = values[idx].dtype();
    }
    std::vector<Buffer> buffers(size);
    for (size_t idx = 0; idx < size; ++idx) {
      PrimExpr arg = call->args[2 + size + idx];
      // Loads from boolean buffers may have cast nodes inserted by
      // earlier passes.
      if (auto cast = arg.as<CastNode>()) {
        arg = cast->value;
      }
      buffers[idx] = Downcast<BufferLoad>(arg)->buffer;
    }

    std::unordered_set<const VarNode*> reduce_set;
    for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
      const VarNode* v = call->args[i].as<VarNode>();
      // A simple optimization replaces an iteration variable with a constant
      // when the extent of the iteration is 1. Since threaded IterVars always
      // start from 0, we can safely ignore such a variable here.
      if (v) {
        reduce_set.insert(v);
      } else {
        ICHECK(call->args[i].as<IntImmNode>() && call->args[i].as<IntImmNode>()->value == 0)
            << "arg " << i << " should be a VarNode or an IntImmNode";
      }
    }

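    // Partition the thread IterVars bound in the surrounding scope into
    // reduction threads (vred) and parallel "group" threads (vpar).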
    size_t nmatch = 0;
    std::vector<ThreadEntry> vred, vpar;
    for (const AttrStmtNode* attr : thread_extents_) {
      ThreadEntry e;
      IterVar iv = Downcast<IterVar>(attr->node);
      e.scope = runtime::ThreadScope::Create(iv->thread_tag);
      e.iv = iv;
      ICHECK_LE(e.scope.rank, 1);
      ICHECK_GE(e.scope.dim_index, 0) << "vthread does not work with cross-thread reduction";
      if (e.scope.rank == 1) {
        const auto* ptr = attr->value.as<IntImmNode>();
        ICHECK(ptr) << "Need constant extent for reduce set " << iv;
        e.extent = static_cast<int>(ptr->value);
        // Ignore thread indices with extent 1: they are always 0.
        if (e.extent == 1) {
          continue;
        }

        if (reduce_set.count(iv->var.get())) {
          vred.push_back(e);
          ++nmatch;
        } else {
          vpar.push_back(e);
        }
      }
    }
    ICHECK_EQ(nmatch, reduce_set.size()) << "Not all reduce indices are present in the context";
    std::sort(vred.begin(), vred.end());
    std::sort(vpar.begin(), vpar.end());
    // Flatten the reduce and group thread indices into linear indices;
    // reduce_extent and group_extent receive their total extents.
    int reduce_extent, group_extent;
    PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
    PrimExpr group_index = FlattenThread(vpar, &group_extent);
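    // reduce_index linearizes the threads that cooperate on one reduction, while
    // group_index identifies the independent reduction group a thread belongs to.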

    // the longest contiguous reduce extent after flattening
    int contiguous_reduce_extent = 1;
    std::vector<std::tuple<int, int, bool>> block_threads;  // tuple(dim_index, extent, is_reduce)
    for (const ThreadEntry& thr : vred) {
      if (thr.scope.rank == 1) {  // threadIdx
        block_threads.emplace_back(thr.scope.dim_index, thr.extent, true);
      }
    }
    for (const ThreadEntry& thr : vpar) {
      if (thr.scope.rank == 1) {  // threadIdx
        block_threads.emplace_back(thr.scope.dim_index, thr.extent, false);
      }
    }
    // sort according to dim_index
    std::sort(block_threads.begin(), block_threads.end());
    for (auto&& thr_attr : block_threads) {
      auto [dim_index, extent, is_reduce] = thr_attr;
      (void)dim_index;  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
      if (is_reduce) {
        contiguous_reduce_extent *= extent;
      } else {
        break;
      }
    }
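    // contiguous_reduce_extent is now the product of the reduce-thread extents over the
    // leading threadIdx dimensions, stopping at the first parallel dimension; it is passed
    // to IsWarpReduction below to help decide whether the shuffle-based path applies.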

    std::vector<Stmt> seq;
    std::vector<Buffer> new_alloc_bufs;
    //
    // This is an optimization. For small reduction sizes, it may be beneficial
    // for a single warp to perform the entire reduction. No trips to shared
    // memory and no cross-warp synchronizations are required.
    // The following code emits the reduction as follows:
    //
    // Allocate reduction vars v[i], i = 0..size-1
    //
    // for offset from WARP_SIZE down to 1, halving the offset each step
    //
    //   a    <- load(v[i])
    //   b    <- shuffle_down(load(v[i]), offset)
    //   v[i] <- reduction(a, b)
    //
    // broadcast results from lane 0 to all other lanes and store
    // the final reduction result to the proper location.
    //
    // When the thread extent is a multiple of the warp size, we can use a two-stage
    // warp-level reduction to optimize. This is implemented by applying the
    // algorithm above twice.
    //
    // For example, suppose we want to use 512 threads to reduce 512 elements
    // and the warp size is 32. In this case there are (512 / 32) = 16 warps.
    // In the first stage, each of the 16 warps reduces 32 elements. After this
    // stage, we have 16 remaining elements to be reduced, one for each warp.
    // We store the 16 elements in shared memory, and start the second stage.
    // In the second stage we use the first 16 lanes of the first warp to reduce
    // the remaining elements, and this reduction can also be optimized by
    // shuffle_down warp-level primitives.
    PrimExpr zero_index = make_const(reduce_index->dtype, 0);
    if (IsWarpReduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
      std::vector<PrimExpr> reduce_results;
      DataType mask_dtype = DataType::UInt(32);
      PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
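      // The active-lane mask is passed to MakeWarpAllreduce, which forwards it to the
      // warp shuffle intrinsics it emits.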

      if (reduce_extent <= warp_size_) {
        std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
            values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq);

        // Broadcast the reduction result from lane 0 to all other lanes.
        // This avoids emitting predicated stores, as all threads
        // uniformly write the same result.
        for (size_t i = 0; i < size; ++i) {
          Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
          PrimExpr val = BufferLoad(buf, {zero_index});
          ICHECK_EQ(val->dtype, types[i]);
          PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), new_alloc_bufs.back(), val,
                                       reduce_extent * group_index);
          seq.push_back(BufferStore(buf, splat, {zero_index}));
        }
      } else {
        int n_warps = reduce_extent / warp_size_;
        std::vector<Buffer> local_bufs;

        // 1. Create the staging buffer in shared memory.
        std::vector<Buffer> staging_shared_bufs;
        staging_shared_bufs.reserve(size);
        for (size_t i = 0; i < size; ++i) {
          Buffer staging_shared_buf = decl_buffer(
              /*shape=*/{make_const(reduce_index->dtype, n_warps * group_extent)},
              /*dtype=*/buffers[i]->dtype, /*name=*/"red_buf_staging", /*storage_scope=*/"shared");
          staging_shared_bufs.push_back(staging_shared_buf);
          new_alloc_bufs.push_back(staging_shared_buf);
        }

        // 2. First round of allreduce.
        std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
            values, types, combiner, reduce_index, warp_size_, group_index, mask, NullOpt, &seq);
        new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());

        // 3. Write allreduce results to staging buffer.
        std::vector<Stmt> write_staging_buf;
        write_staging_buf.reserve(size);
        for (size_t i = 0; i < size; ++i) {
          new_alloc_bufs.push_back(Downcast<BufferLoad>(reduce_results[i])->buffer);
          write_staging_buf.push_back(BufferStore(
              /*buffer=*/staging_shared_bufs[i],
              /*value=*/reduce_results[i],
              /*indices=*/{group_index * n_warps + floordiv(reduce_index, warp_size_)}));
        }
        PrimExpr cond = floormod(reduce_index, warp_size_) == zero_index;
        seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf)));
        seq.push_back(SyncThread("shared"));

        // 4. Load staging buffer.
        //    Second round of allreduce.
        for (size_t i = 0; i < size; ++i) {
          values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i],
                                 /*indices=*/{group_index * n_warps + reduce_index});
        }
        std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
            values, types, combiner, reduce_index, n_warps, group_index, mask,
            /*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq);
        new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());

        // 5. Create shared memory buffer(s) of `group_extent` elements, storing
        // the allreduce results so that every thread can access them.
        std::vector<Stmt> write_result;
        write_result.reserve(size);
        for (size_t i = 0; i < size; ++i) {
          new_alloc_bufs.push_back(Downcast<BufferLoad>(reduce_results[i])->buffer);
          Buffer broadcast_shared_buf = decl_buffer(
              /*shape=*/{make_const(reduce_index->dtype, group_extent)},
              /*dtype=*/buffers[i]->dtype, /*name=*/"red_result", /*storage_scope=*/"shared");
          write_result.push_back(
              BufferStore(broadcast_shared_buf, reduce_results[i], {group_index}));
          // Update `reduce_results`, pointing to the value loaded from the shared memory buffer.
          reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index});
        }
        seq.push_back(IfThenElse(reduce_index == zero_index, SeqStmt::Flatten(write_result)));
        seq.push_back(SyncThread("shared"));
      }

      // Write back allreduce results and update existing allocations.
      for (size_t i = 0; i < size; ++i) {
        ICHECK(!load_remap_.count(buffers[i]->data.get()));
        PrimExpr pred = const_true(types[i].lanes());
        Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
        ICHECK_EQ(reduce_results[i]->dtype, types[i]);
        load_remap_[buffers[i]->data.get()] = reduce_results[i];

        auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0));
        alloc_remap_[buffers[i]->data.get()] = buf;
        var_remap_[buffers[i]->data.get()] = buf->data;
        buf_remap_[buffers[i].get()] = buf;
      }
    } else {
      std::vector<Buffer> shared_bufs(size);
      if (reduce_extent == 1) {
        // special case, no reduction is needed.
        std::vector<Stmt> stores;
        for (size_t i = 0; i < size; ++i) {
          stores.push_back(BufferStore(buffers[i], values[i], {0}));
        }
        return SeqStmt::Flatten(stores);
      }
      // This sync is necessary because there might be an incomplete read of the
      // previous iteration on the same buffer.
      seq.emplace_back(SyncThread("shared"));
      for (size_t idx = 0; idx < size; ++idx) {
        shared_bufs[idx] = decl_buffer({IntImm(group_index->dtype, group_extent * reduce_extent)},
                                       types[idx], "red_buf" + std::to_string(idx), "shared");
        seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
                                     {BufIndex(reduce_index, group_index, reduce_extent)}));
      }
      seq.emplace_back(SyncThread("shared"));
      seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index,
                                        reduce_extent, group_extent, contiguous_reduce_extent));
      for (size_t idx = 0; idx < size; ++idx) {
        ICHECK(!load_remap_.count(buffers[idx]->data.get()));
        PrimExpr pred = const_true(types[idx].lanes());
        BufferLoad load(shared_bufs[idx],
                        {BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)});
        ICHECK_EQ(load->dtype, types[idx]);
        load_remap_[buffers[idx]->data.get()] = load;
        alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx];
        var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
        buf_remap_[buffers[idx].get()] = shared_bufs[idx];
      }
    }

    // Finalize all local allocations now that all statements have been built.
    Stmt body = SeqStmt::Flatten(seq);
    for (Buffer buf : new_alloc_bufs) {
      body = DeclBuffer(buf, body);
      body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
    }

    return body;
  }
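
The warp-level path described in the comment block above (the case reduce_extent <= warp_size_)
conceptually lowers to the standard CUDA shuffle-down reduction pattern. The sketch below is
illustrative only, assuming a sum reduction over a full 32-lane warp; warp_sum is a hypothetical
helper and not code emitted verbatim by this pass.

  // Illustrative sketch (assumption: sum reduction, warp size 32; not generated code).
  __device__ float warp_sum(float v) {
    unsigned mask = __activemask();            // lanes currently active in the warp
    for (int offset = 16; offset > 0; offset >>= 1) {
      v += __shfl_down_sync(mask, v, offset);  // b <- shuffle_down(v, offset); v <- v + b
    }
    return __shfl_sync(mask, v, 0);            // broadcast lane 0's result to all lanes
  }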