LogicalResult matchAndRewrite()

in lib/Dialect/lhlo/transforms/lhlo_legalize_to_gpu.cc [54:172]
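
This pattern lowers an lhlo ReduceOp into a gpu.launch: each thread along the
x dimension owns one element of the rank-1 output, stores the corresponding
init value to it, and then folds the input values along the reduction
dimension with a sequential scf.for loop, reusing the reduce op's own body as
the combiner.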


  LogicalResult matchAndRewrite(
      ReduceOp reduce_op, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = reduce_op.getLoc();
    // Only support 1-d reductions for now; all outputs must be rank-1
    // memrefs that agree on their size.
    int64_t size = 0;
    for (auto result : reduce_op.out()) {
      auto shaped_type = result.getType().dyn_cast<ShapedType>();
      if (!shaped_type || shaped_type.getRank() != 1) {
        return failure();
      }
      auto dim_size = shaped_type.getDimSize(0);
      if (size && size != dim_size) {
        return failure();
      }
      size = dim_size;
    }

    // The reduction runs along a single dimension, taken from the op's
    // `dimensions` attribute.
    auto reducing_dimension = *reduce_op.dimensions().value_begin<APInt>();

    // Require all inputs to have a static shape. The inputs are assumed to
    // share a shape, so the reduction dimension size is taken from the last
    // one checked.
    int64_t reduce_dim_size = 0;
    for (auto input : reduce_op.inputs()) {
      auto shaped_type = input.getType().dyn_cast<ShapedType>();
      if (!shaped_type || !shaped_type.hasStaticShape()) {
        return failure();
      }
      reduce_dim_size =
          shaped_type.getDimSize(reducing_dimension.getSExtValue());
    }

    // Create a launch that is parallel in the result dimension: a 1x1x1 grid
    // with a single `size`x1x1 block, one thread per output element.
    auto block_size_x = rewriter.create<mlir::arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIntegerAttr(rewriter.getIndexType(), size));
    auto one = rewriter.create<mlir::arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
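    // Note: arith::ConstantIndexOp is a shorthand for building index
    // constants like these, e.g.
    // rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1).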
    auto launch_op = rewriter.create<mlir::gpu::LaunchOp>(
        loc, one, one, one, block_size_x, one, one);
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToEnd(&launch_op.body().front());
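      // Each thread along x computes one element of the output.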
      auto index = launch_op.getThreadIds().x;

      // Load the initial value and store it to the output.
      for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
        auto init_value =
            rewriter.create<mlir::memref::LoadOp>(loc, std::get<0>(pair));
        rewriter.create<mlir::memref::StoreOp>(
            loc, init_value, std::get<1>(pair), ArrayRef<Value>{index});
      }

      // Insert a loop into the body to compute the reduction. The loop ranges
      // over [0, dim).
      auto zero = rewriter.create<mlir::arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
      // TODO(b/137624192) Use dimOp to make it shape independent.
      auto upper = rewriter.create<mlir::arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), reduce_dim_size));
      auto step = rewriter.create<mlir::arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
      auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);

      rewriter.setInsertionPointToStart(loop.getBody());
      // Create single-element (rank-0) subviews for the accumulator and the
      // value to reduce; this makes it easy to inline the reduction body on
      // memrefs directly.
      auto output = *reduce_op.out().begin();
      auto resType = MemRefType::get(
          llvm::None, getElementTypeOrSelf(output.getType()),
          makeStridedLinearLayoutMap(llvm::None,
                                     MemRefType::getDynamicStrideOrOffset(),
                                     rewriter.getContext()));
      OpFoldResult offset = launch_op.getThreadIds().x;
      auto oneAttr = rewriter.getI64IntegerAttr(1);
      OpFoldResult size = oneAttr;
      OpFoldResult stride = oneAttr;
      auto accumulator = rewriter.create<memref::SubViewOp>(
          loc, resType, output, offset, size, stride);
      Value input_buffer = reduce_op.inputs().front();
      auto input_type_rank =
          input_buffer.getType().cast<MemRefType>().getRank();

      Value input = *reduce_op.operand_begin();
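      // Index the reduced dimension with the loop induction variable and the
      // remaining dimensions with the thread id.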
      SmallVector<OpFoldResult> offsets = llvm::to_vector<4>(llvm::map_range(
          llvm::seq<int>(0, input_type_rank), [&](int dim) -> OpFoldResult {
            return dim == reducing_dimension ? loop.getInductionVar()
                                             : launch_op.getThreadIds().x;
          }));
      SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr);
      SmallVector<OpFoldResult> strides(input_type_rank, oneAttr);
      auto rhs = rewriter.create<memref::SubViewOp>(
          loc, accumulator.getType(), input, offsets, sizes, strides);

      // Now copy over the actual body of the reduction, leaving out the
      // terminator.
      BlockAndValueMapping mapping;
      // The reduce body takes (lhs, rhs, out) buffer arguments; map both the
      // lhs and the out argument to the accumulator subview so the body reads
      // and writes it in place.
      mapping.map(reduce_op.body().getArgument(0), accumulator);
      mapping.map(reduce_op.body().getArgument(1), rhs);
      mapping.map(reduce_op.body().getArgument(2), accumulator);
      for (auto& nested : reduce_op.body().front().without_terminator()) {
        auto* clone = rewriter.clone(nested, mapping);
        for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) {
          mapping.map(std::get<0>(pair), std::get<1>(pair));
        }
      }

      // Finally, insert the terminator for the launchOp.
      rewriter.setInsertionPointToEnd(&launch_op.body().front());
      rewriter.create<mlir::gpu::TerminatorOp>(loc);
    }

    rewriter.eraseOp(reduce_op);
    return success();
  }
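
For context, a conversion pattern like this is typically registered in a
RewritePatternSet and run through MLIR's dialect-conversion driver. The sketch
below is illustrative only: the function name runReduceLowering, the pattern
class name LhloReduceToGPULaunchConverter, and the exact set of legal dialects
are assumptions, not necessarily what lhlo_legalize_to_gpu.cc does.

  // Hypothetical driver for the pattern above; names are illustrative and the
  // dialect headers are elided for brevity.
  #include "mlir/Transforms/DialectConversion.h"

  mlir::LogicalResult runReduceLowering(mlir::Operation* op) {
    mlir::MLIRContext* ctx = op->getContext();

    mlir::RewritePatternSet patterns(ctx);
    patterns.add<LhloReduceToGPULaunchConverter>(ctx);  // the pattern above

    mlir::ConversionTarget target(*ctx);
    // lhlo.reduce must be rewritten away; the ops the pattern emits are legal.
    target.addIllegalOp<ReduceOp>();
    target.addLegalDialect<mlir::gpu::GPUDialect, mlir::scf::SCFDialect,
                           mlir::memref::MemRefDialect,
                           mlir::arith::ArithmeticDialect>();

    return mlir::applyPartialConversion(op, target, std::move(patterns));
  }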