Stmt TransformReductionBlock()

Excerpt from src/tir/transforms/lower_cross_thread_reduction.cc, lines 294–527.


/*!
 * \brief Lower a cross-thread reduction block into explicit reduction statements.
 *
 * The given reduction block is replaced by at most four statements, in order:
 *  1. (only when `it_buffers` is defined) a block initializing `it_buffers` with the reducer's
 *     identity elements;
 *  2. the in-thread reduction, produced by `InThreadReducerMaker::Make` (which may produce
 *     nothing);
 *  3. a block calling `tvm_thread_allreduce` whose destinations are `ct_buffers`;
 *  4. a write-back block storing element 0 of each `ct_buffers[i]` into `wb_buffers[i]`.
 * The resulting sequence is then wrapped by every thread-bound loop in `reduction_loops`, with
 * `reduction_loops[0]` ending up outermost.
 *
 * \param realize The BlockRealize of the reduction block being lowered.
 * \param it_buffers Buffers for the in-thread reduction, or NullOpt when no in-thread reduction
 *        block is generated. When defined, they replace `wb_buffers` in the in-thread block body.
 * \param ct_buffers Buffers receiving the cross-thread allreduce result; only element 0 of each
 *        is accessed here.
 * \param wb_buffers The buffers the original block writes its reduction result to.
 * \param old_wb_indices Indices into `wb_buffers` used by the original block; the block's
 *        non-reduction iter vars in them are replaced by fresh copies for the write-back block.
 * \param reducer The reducer; supplies the identity elements and annotates the allreduce.
 * \param combiner_rhs Values passed directly as allreduce sources when `it_buffers` is undefined.
 * \param reduction_loops The reduction loops of the block. Must be non-empty: the first loop is
 *        handed to InThreadReducerMaker, and the thread-bound loops are passed to the allreduce
 *        and re-wrap the result.
 * \return The lowered statement.
 */
Stmt TransformReductionBlock(const BlockRealizeNode* realize,            //
                             const Optional<Array<Buffer>>& it_buffers,  //
                             const Array<Buffer>& ct_buffers,            //
                             const Array<Buffer>& wb_buffers,            //
                             const Array<PrimExpr>& old_wb_indices,      //
                             const CommReducer& reducer,                 //
                             const Array<PrimExpr>& combiner_rhs,        //
                             const std::vector<const ForNode*>& reduction_loops) {
  // `reduction_loops[0]` is dereferenced below; an empty list would be undefined behavior.
  ICHECK(!reduction_loops.empty()) << "Expected at least one reduction loop";
  const int n_buffers = static_cast<int>(wb_buffers.size());
  const BlockNode* block = realize->block.get();

  // The temporary reduction buffers are only ever accessed at element 0, so their region is [0, 1).
  auto f_create_buffer_regions = [](const Array<Buffer>& buffers) {
    Array<BufferRegion> regions;
    regions.reserve(buffers.size());
    for (const Buffer& buffer : buffers) {
      regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)}));
    }
    return regions;
  };

  Array<BufferRegion> ct_buffer_regions = f_create_buffer_regions(ct_buffers);
  Optional<Array<BufferRegion>> it_buffer_regions = NullOpt;
  if (it_buffers.defined()) {
    it_buffer_regions = f_create_buffer_regions(it_buffers.value());
  }
  // In total, the block is transformed into at most 4 statements
  // - Stmt 1: initialize the buffer for in-thread reduction
  // - Stmt 2: do in-thread reduction
  // - Stmt 3: do cross-thread reduction
  // - Stmt 4: write cross-thread reduction result to the original buffer
  Array<Stmt> stmts;
  stmts.reserve(4);
  // Stmt 1: initialize the buffer for in-thread reduction
  if (it_buffers.defined()) {
    Array<Stmt> inits;
    inits.reserve(n_buffers);
    for (int i = 0; i < n_buffers; ++i) {
      inits.push_back(
          BufferStore(it_buffers.value()[i], reducer->identity_element[i], {Integer(0)}));
    }
    stmts.push_back(BlockRealize(/*iter_values=*/{},
                                 /*predicate=*/const_true(),
                                 /*block=*/
                                 Block(/*iter_vars=*/{},
                                       /*reads=*/{},
                                       /*writes=*/it_buffer_regions.value(),
                                       /*name_hint=*/block->name_hint + "_in_thread_init",
                                       /*body=*/n_buffers > 1 ? SeqStmt(inits) : inits[0])));
  }
  // Stmt 2: do in-thread reduction
  {
    Optional<BlockRealize> new_realize = NullOpt;
    // If need to generate in-thread reduction,
    // then replace `wb_buffers` with `it_buffers` accordingly in given BlockRealize
    // otherwise, directly remove given BlockRealize
    if (it_buffers.defined()) {
      ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
      // `reads` is intentionally kept as copied from the original block.
      new_block->writes = it_buffer_regions.value();
      new_block->name_hint = new_block->name_hint + "_in_thread";
      new_block->body =
          BufferReplacer::Run(wb_buffers, it_buffers.value(), std::move(new_block->body));
      new_block->init = NullOpt;
      ObjectPtr<BlockRealizeNode> n = make_object<BlockRealizeNode>(*realize);
      n->block = Block(new_block);
      new_realize = BlockRealize(n);
    }
    For loop = GetRef<For>(reduction_loops[0]);
    if (Optional<Stmt> stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) {
      stmts.push_back(stmt.value());
    }
  }
  // Stmt 3: do cross-thread reduction
  {
    // Step 3.1. Create the parameters to the intrinsic
    Array<PrimExpr> parameters;
    parameters.reserve(reduction_loops.size() + 4);
    // 1-st argument: number of buffers
    parameters.push_back(make_const(DataType::UInt(32), n_buffers));
    // Next `n_buffers` arguments: sources
    if (it_buffers.defined()) {
      for (int i = 0; i < n_buffers; ++i) {
        parameters.push_back(BufferLoad(it_buffers.value()[i], {Integer(0)}));
      }
    } else {
      parameters.insert(parameters.end(), combiner_rhs.begin(), combiner_rhs.end());
    }
    // Next argument: predicate
    parameters.push_back(const_true());
    // Next `n_buffers` arguments: destinations
    for (int i = 0; i < n_buffers; ++i) {
      parameters.push_back(BufferLoad(ct_buffers[i], {Integer(0)}));
    }
    // Next arguments: all the reduction threads
    for (const ForNode* reduction_loop : reduction_loops) {
      if (reduction_loop->thread_binding.defined()) {
        parameters.push_back(reduction_loop->loop_var);
      }
    }
    // Step 3.2. Create the block and the block-realize.
    // Without an in-thread stage the allreduce consumes the original block's values directly,
    // so it must keep the original iter vars, bindings and reads.
    Array<IterVar> iter_vars{nullptr};
    Array<PrimExpr> bindings{nullptr};
    Array<BufferRegion> reads{nullptr};
    if (it_buffers.defined()) {
      iter_vars = Array<IterVar>{};
      bindings = Array<PrimExpr>{};
      reads = it_buffer_regions.value();
    } else {
      iter_vars = block->iter_vars;
      bindings = realize->iter_values;
      reads = block->reads;
    }
    stmts.push_back(BlockRealize(
        /*iter_values=*/std::move(bindings),
        /*predicate=*/const_true(),
        /*block=*/
        Block(/*iter_vars=*/std::move(iter_vars),
              /*reads=*/std::move(reads),
              /*writes=*/ct_buffer_regions,
              /*name_hint=*/block->name_hint + "_cross_thread",
              /*body=*/
              AttrStmt(/*node=*/reducer,
                       /*attr_key=*/tir::attr::reduce_scope,
                       /*value=*/make_zero(DataType::Handle()),
                       /*body=*/
                       Evaluate(Call(/*dtype=*/DataType::Handle(),
                                     /*op=*/tir::builtin::tvm_thread_allreduce(),
                                     /*args=*/std::move(parameters)))))));
  }
  // Stmt 4: write cross-thread reduction result to the original buffer
  {
    ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
    int n_iter = static_cast<int>(block->iter_vars.size());
    Array<IterVar> iter_vars;
    Array<PrimExpr> bindings;
    Map<Var, Var> var_map;
    iter_vars.reserve(n_iter);
    bindings.reserve(n_iter);
    // Keep only the non-reduction iter vars, each replaced by a fresh copy of its Var.
    for (int i = 0; i < n_iter; ++i) {
      const IterVar& iter_var = block->iter_vars[i];
      const PrimExpr& binding = realize->iter_values[i];
      if (iter_var->iter_type != kCommReduce) {
        IterVar new_iter_var{nullptr};
        {
          ObjectPtr<IterVarNode> n = make_object<IterVarNode>(*iter_var.get());
          ObjectPtr<VarNode> v = make_object<VarNode>(*iter_var->var.get());
          n->var = Var(v);
          new_iter_var = IterVar(n);
        }
        iter_vars.push_back(new_iter_var);
        bindings.push_back(binding);
        var_map.Set(iter_var->var, new_iter_var->var);
      }
    }
    Array<Stmt> wb_updates;
    Array<BufferRegion> wb_regions;
    wb_updates.reserve(n_buffers);
    wb_regions.reserve(n_buffers);
    int n_dim = static_cast<int>(old_wb_indices.size());
    // All write-back buffers share the region of the original block's first write.
    Array<Range> region = Substitute(block->writes[0]->region, var_map);
    Array<PrimExpr> wb_indices;
    wb_indices.reserve(n_dim);
    for (int d = 0; d < n_dim; ++d) {
      wb_indices.push_back(Substitute(old_wb_indices[d], var_map));
    }
    for (int i = 0; i < n_buffers; ++i) {
      wb_updates.push_back(
          BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}), wb_indices));
      wb_regions.push_back(BufferRegion(wb_buffers[i], region));
    }

    // Construct the predicate of the write-back block. It is the conjunction of
    // - each predicate clause of the original block which does not use a reduction loop var, and
    // - `t == 0` for each reduction thread dim when the write-back buffer is not local.
    PrimExpr wb_predicate = const_true();
    std::unordered_set<const VarNode*> reduction_loop_vars;
    reduction_loop_vars.reserve(reduction_loops.size());
    for (const ForNode* reduction_loop : reduction_loops) {
      reduction_loop_vars.insert(reduction_loop->loop_var.get());
    }
    auto f_uses_reduction_var = [&reduction_loop_vars](const PrimExpr& expr) -> bool {
      for (const Var& var : UndefinedVars(expr)) {
        if (reduction_loop_vars.count(var.get())) {
          return true;
        }
      }
      return false;
    };
    // Split only the top-level `&&` tree of the predicate into its clauses, left to right.
    // A predicate that is not an `And` is itself one clause and must not be dropped, and
    // `And`s nested under `Or`/`Not` are part of their enclosing clause, not separate clauses.
    std::vector<PrimExpr> worklist{realize->predicate};
    while (!worklist.empty()) {
      PrimExpr clause = std::move(worklist.back());
      worklist.pop_back();
      if (const auto* and_node = clause.as<AndNode>()) {
        // Push the right operand first so that the left operand is processed first.
        worklist.push_back(and_node->b);
        worklist.push_back(and_node->a);
        continue;
      }
      if (!f_uses_reduction_var(clause)) {
        wb_predicate = wb_predicate && clause;
      }
    }
    if (wb_buffers[0].scope() != "local") {
      // Only one thread per reduction group should perform the write-back.
      for (const ForNode* loop : reduction_loops) {
        if (loop->thread_binding.defined()) {
          wb_predicate = wb_predicate && (loop->loop_var == IntImm(loop->loop_var->dtype, 0));
        }
      }
    }

    stmts.push_back(BlockRealize(
        /*iter_values=*/std::move(bindings),
        /*predicate=*/wb_predicate,
        /*block=*/
        Block(/*iter_vars=*/std::move(iter_vars),
              /*reads=*/std::move(ct_buffer_regions),
              /*writes=*/std::move(wb_regions),
              /*name_hint=*/block->name_hint + "_write_back",
              /*body=*/n_buffers > 1 ? SeqStmt(wb_updates) : wb_updates[0])));
  }
  // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx
  Stmt new_stmt = SeqStmt::Flatten(std::move(stmts));
  for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend(); ++rit) {
    const ForNode* loop = *rit;
    if (loop->thread_binding.defined()) {
      ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
      n->body = std::move(new_stmt);
      new_stmt = For(n);
    }
  }
  return new_stmt;
}