LogicalResult matchAndRewrite()

in lib/Dialect/mhlo/transforms/legalize_to_linalg.cc [2069:2223]


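  // Lowers mhlo.reduce_window to a linalg.generic: window dimensions become
  // extra reduction loops of the generic, driven by a shape-only "window"
  // input tensor, and the reduction body is cloned from the mhlo op's region.
  //
  // For illustration (a hand-written sketch, not verified pass output): a
  // single-operand sum reduction over a 2x2 window with unit strides,
  //
  //   %r = "mhlo.reduce_window"(%input, %init) ({
  //   ^bb0(%a: tensor<f32>, %b: tensor<f32>):
  //     %s = mhlo.add %a, %b : tensor<f32>
  //     "mhlo.return"(%s) : (tensor<f32>) -> ()
  //   }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>}
  //      : (tensor<1x17x17x1xf32>, tensor<f32>) -> tensor<1x16x16x1xf32>
  //
  // lowers to a linalg.generic whose output is the init value broadcast to
  // tensor<1x16x16x1xf32>, whose inputs are %input plus a shape-only
  // tensor<2x2xf32> window tensor, and whose iterators are four parallel
  // loops (one per output dimension) followed by two reduction loops for
  // the window.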
  LogicalResult matchAndRewrite(
      mhlo::ReduceWindowOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    MLIRContext* ctx = op->getContext();
    Location loc = op.getLoc();
    llvm::SmallVector<Value> init_values = adaptor.init_values();
    llvm::SmallVector<Type> result_types = llvm::to_vector(op.getResultTypes());
    auto num_operands = init_values.size();

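    // Extract the static window configuration. Window strides and window
    // dilations default to 1 when the corresponding attribute is absent.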
    llvm::SmallVector<int64_t> window_dimensions =
        Extract1DVector(op.window_dimensions());

    llvm::SmallVector<int64_t> padding;
    if (op.padding()) {
      padding = Extract1DVector(*op.padding());
    }

    // Base dilations (dilations applied to the input itself) are not
    // supported by this lowering; bail out unless they are all 1.
    llvm::SmallVector<int64_t> base_dilations;
    if (op.base_dilations()) {
      base_dilations = Extract1DVector(*op.base_dilations());
      if (llvm::any_of(base_dilations, [](int64_t x) { return x != 1; }))
        return failure();
    }

    llvm::SmallVector<int64_t> window_strides(window_dimensions.size(), 1);
    if (op.window_strides()) {
      window_strides = Extract1DVector(*op.window_strides());
    }

    llvm::SmallVector<int64_t> window_dilations(window_dimensions.size(), 1);
    if (op.window_dilations()) {
      window_dilations = Extract1DVector(*op.window_dilations());
    }

    auto rank = window_dimensions.size();
    SmallVector<AffineExpr, 2> src_exprs;
    SmallVector<AffineExpr, 2> window_exprs;
    SmallVector<AffineExpr, 2> dst_exprs;
    SmallVector<int64_t> filtered_window_dims;

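    // Build the indexing expressions. For each output dimension i, the
    // source is read at i * stride + w * dilation, where w is the loop over
    // the window; window dimensions of size 1 need no reduction loop and
    // are skipped.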
    int window_dim = 0;
    for (size_t i = 0; i < rank; i++) {
      AffineExpr src_expr = mlir::getAffineDimExpr(i, ctx);

      if (window_strides[i] != 1) src_expr = src_expr * window_strides[i];

      if (window_dimensions[i] != 1) {
        filtered_window_dims.push_back(window_dimensions[i]);
        AffineExpr window_expr = mlir::getAffineDimExpr(rank + window_dim, ctx);
        window_exprs.push_back(window_expr);

        if (window_dilations[i] != 1)
          window_expr = window_expr * window_dilations[i];

        src_expr = src_expr + window_expr;
        window_dim++;
      }

      src_exprs.push_back(src_expr);
      dst_exprs.push_back(mlir::getAffineDimExpr(i, ctx));
    }

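    // Infer maps over a shared set of loop dimensions: inferred_maps[0]
    // indexes the inputs, [1] the shape-only window input, [2] the outputs.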
    SmallVector<AffineMap, 4> inferred_maps =
        AffineMap::inferFromExprList({src_exprs, window_exprs, dst_exprs});

    SmallVector<AffineMap, 4> indexing_maps;

    indexing_maps.append(num_operands, inferred_maps[0]);
    indexing_maps.append(1, inferred_maps[1]);
    indexing_maps.append(num_operands, inferred_maps[2]);

    // Set up the initial values: broadcast each scalar init value to the
    // static result shape. These tensors become the outputs of the
    // linalg.generic and act as the reduction accumulators.
    llvm::SmallVector<Value> broadcast_values;
    for (uint64_t i = 0, s = init_values.size(); i < s; i++) {
      Value init_value = init_values[i];
      auto result_ty = result_types[i].cast<ShapedType>();
      if (!result_ty.hasStaticShape()) return failure();

      auto broadcast_sizes = rewriter.getI64TensorAttr(result_ty.getShape());
      broadcast_values.push_back(rewriter.create<mhlo::BroadcastOp>(
          loc, result_ty, init_value, broadcast_sizes));
    }

    llvm::SmallVector<Value> inputs = llvm::to_vector(adaptor.inputs());

    // Pad as necessary: materialize the mhlo padding with a zero-filled
    // tensor.pad. The padding attribute stores interleaved (low, high)
    // amounts, one pair per dimension.
    if (llvm::any_of(padding, [](int64_t v) { return v != 0; })) {
      llvm::SmallVector<int64_t> static_lows;
      llvm::SmallVector<int64_t> static_highs;
      for (size_t i = 0; i < padding.size(); i += 2) {
        static_lows.push_back(padding[i]);
        static_highs.push_back(padding[i + 1]);
      }
      for (auto& input : inputs) {
        Value zero = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getZeroAttr(
                     input.getType().cast<ShapedType>().getElementType()));
        auto pad_op = rewriter.create<tensor::PadOp>(
            loc, input, static_lows, static_highs, ValueRange{}, ValueRange{});

        SmallVector<Type, 4> block_arg_types;
        block_arg_types.assign(input.getType().cast<ShapedType>().getRank(),
                               rewriter.getIndexType());
        auto& region = pad_op.region();

        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.createBlock(
            &region, region.end(), block_arg_types,
            SmallVector<Location>(block_arg_types.size(), loc));
        rewriter.create<tensor::YieldOp>(loc, zero);

        input = pad_op.getResult();
      }
    }

    // Add the extra input for the reduction dimensions. Its values are never
    // read: it only gives the generic one loop per filtered window dimension,
    // so the element type (F32 here) is immaterial.
    inputs.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, filtered_window_dims, rewriter.getF32Type()));

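    // Create the linalg.generic. The first `rank` loops are parallel (one
    // per output dimension); the trailing loops, one per filtered window
    // dimension, are reductions.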
    rewriter.setInsertionPoint(op);
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensors=*/result_types,
        /*inputs=*/inputs,
        /*outputs=*/broadcast_values, indexing_maps,
        GetParallelAndReductionIterators(rank + filtered_window_dims.size(),
                                         filtered_window_dims.size()),
        /*bodyBuild=*/nullptr, PruneAttributeList(op));

    // Convert the signature of the body. This includes converting scalar
    // tensors to their scalar values and inserting an additional block arg for
    // the window arg.
    Region& region = linalg_op.region();
    rewriter.cloneRegionBefore(op.body(), region, region.end());

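    // The original region has one block argument per input and one per init
    // value (2 * num_operands in total); inputs.size() already counts the
    // extra window input, hence the - 1.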
    TypeConverter::SignatureConversion signature_converter(
        inputs.size() + op->getNumResults() - 1);

    for (uint64_t i = 0, s = inputs.size(); i < s - 1; i++) {
      signature_converter.addInputs(
          i, inputs[i].getType().cast<ShapedType>().getElementType());
    }

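    // Append a scalar argument for the shape-only window input; no original
    // block argument maps to it and the cloned body never reads it.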
    signature_converter.addInputs(
        inputs.back().getType().cast<ShapedType>().getElementType());

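    // Map the original accumulator arguments to the output element types.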
    for (uint64_t i = 0, s = result_types.size(); i < s; i++) {
      auto idx = inputs.size() + i - 1;
      signature_converter.addInputs(
          idx, result_types[i].cast<ShapedType>().getElementType());
    }

    rewriter.applySignatureConversion(&region, signature_converter);
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }