in lib/Dialect/mhlo/transforms/legalize_to_linalg.cc [2069:2223]
// Lowers mhlo.reduce_window to a single linalg.generic op.
//
// Iteration space: dims [0, rank) are parallel and index the output; one
// additional reduction dim is created for every window dimension whose size
// is not 1. The source index along dimension i is
//   i * window_strides[i] + w_i * window_dilations[i]
// where w_i is the window iteration dim (omitted for unit windows).
//
// Operands of the generated op:
//   * inputs  : the (possibly padded) reduce_window inputs, followed by an
//               extra tensor shaped like the non-unit window dims. That extra
//               tensor only provides the window iteration dims to the
//               indexing maps; the cloned reducer body never reads it.
//   * outputs : each init value broadcast to its result shape.
//
// Returns failure() before creating any IR when base_dilations contains a
// value other than 1, or when any result type is not statically shaped.
LogicalResult matchAndRewrite(
    mhlo::ReduceWindowOp op, OpAdaptor adaptor,
    ConversionPatternRewriter& rewriter) const override {
  MLIRContext* ctx = op->getContext();
  Location loc = op.getLoc();
  llvm::SmallVector<Value> init_values = adaptor.init_values();
  llvm::SmallVector<Type> result_types = llvm::to_vector(op.getResultTypes());
  auto num_operands = init_values.size();

  // Base dilation is not supported by this lowering. Gate on the presence of
  // base_dilations itself before dereferencing it.
  llvm::SmallVector<int64_t> base_dilations;
  if (op.base_dilations()) {
    base_dilations = Extract1DVector(*op.base_dilations());
    if (llvm::any_of(base_dilations, [](int64_t x) { return x != 1; }))
      return failure();
  }

  // The output broadcasts below require static result shapes; check up front
  // so the pattern fails without having created any IR.
  if (!llvm::all_of(result_types, [](Type t) {
        return t.cast<ShapedType>().hasStaticShape();
      }))
    return failure();

  llvm::SmallVector<int64_t> window_dimensions =
      Extract1DVector(op.window_dimensions());
  llvm::SmallVector<int64_t> padding;
  if (op.padding()) {
    padding = Extract1DVector(*op.padding());
  }
  llvm::SmallVector<int64_t> window_strides(window_dimensions.size(), 1);
  if (op.window_strides()) {
    window_strides = Extract1DVector(*op.window_strides());
  }
  llvm::SmallVector<int64_t> window_dilations(window_dimensions.size(), 1);
  if (op.window_dilations()) {
    window_dilations = Extract1DVector(*op.window_dilations());
  }

  // Build the affine expressions for source, window and destination.
  auto rank = window_dimensions.size();
  SmallVector<AffineExpr, 2> src_exprs;
  SmallVector<AffineExpr, 2> window_exprs;
  SmallVector<AffineExpr, 2> dst_exprs;
  SmallVector<int64_t> filtered_window_dims;
  int window_dim = 0;
  for (size_t i = 0; i < rank; ++i) {
    AffineExpr src_expr = mlir::getAffineDimExpr(i, ctx);
    if (window_strides[i] != 1) src_expr = src_expr * window_strides[i];
    // Unit windows get no reduction dim of their own.
    if (window_dimensions[i] != 1) {
      filtered_window_dims.push_back(window_dimensions[i]);
      AffineExpr window_expr = mlir::getAffineDimExpr(rank + window_dim, ctx);
      // The window operand is indexed by the undilated window dim.
      window_exprs.push_back(window_expr);
      if (window_dilations[i] != 1)
        window_expr = window_expr * window_dilations[i];
      src_expr = src_expr + window_expr;
      window_dim++;
    }
    src_exprs.push_back(src_expr);
    dst_exprs.push_back(mlir::getAffineDimExpr(i, ctx));
  }
  SmallVector<AffineMap, 4> inferred_maps =
      AffineMap::inferFromExprList({src_exprs, window_exprs, dst_exprs});
  // Map order mirrors the operand order: inputs, window tensor, outputs.
  SmallVector<AffineMap, 4> indexing_maps;
  indexing_maps.append(num_operands, inferred_maps[0]);
  indexing_maps.append(1, inferred_maps[1]);
  indexing_maps.append(num_operands, inferred_maps[2]);

  // Setup the initial values: broadcast each init value to its result shape.
  llvm::SmallVector<Value> broadcast_values;
  for (uint64_t i = 0; i < num_operands; ++i) {
    auto result_ty = result_types[i].cast<ShapedType>();
    auto broadcast_sizes = rewriter.getI64TensorAttr(result_ty.getShape());
    broadcast_values.push_back(rewriter.create<mhlo::BroadcastOp>(
        loc, result_ty, init_values[i], broadcast_sizes));
  }

  llvm::SmallVector<Value> inputs = llvm::to_vector(adaptor.inputs());
  // Pad as necessary. Padded elements take the corresponding init value, so
  // they do not perturb the reduction (e.g. a max must never see a padded 0).
  if (llvm::any_of(padding, [](int64_t v) { return v != 0; })) {
    // `padding` is flattened (low, high) pairs, one pair per dimension.
    llvm::SmallVector<int64_t> static_lows;
    llvm::SmallVector<int64_t> static_highs;
    for (size_t i = 0; i + 1 < padding.size(); i += 2) {
      static_lows.push_back(padding[i]);
      static_highs.push_back(padding[i + 1]);
    }
    for (size_t i = 0, e = inputs.size(); i < e; ++i) {
      Value& input = inputs[i];
      // Scalar pad value read out of the init tensor (rank 0, as implied by
      // the broadcast to the full result shape above).
      Value pad_value = rewriter.create<tensor::ExtractOp>(
          loc, init_values[i], ValueRange{});
      auto pad_op = rewriter.create<tensor::PadOp>(
          loc, input, static_lows, static_highs, ValueRange{}, ValueRange{});
      SmallVector<Type, 4> block_arg_types;
      block_arg_types.assign(input.getType().cast<ShapedType>().getRank(),
                             rewriter.getIndexType());
      auto& pad_region = pad_op.region();
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.createBlock(
          &pad_region, pad_region.end(), block_arg_types,
          SmallVector<Location>(block_arg_types.size(), loc));
      rewriter.create<tensor::YieldOp>(loc, pad_value);
      input = pad_op.getResult();
    }
  }

  // Add the extra input for the reduction dimension. Only its shape matters;
  // the element type is arbitrary because the body never reads it.
  inputs.push_back(rewriter.create<linalg::InitTensorOp>(
      loc, filtered_window_dims, rewriter.getF32Type()));
  rewriter.setInsertionPoint(op);
  auto linalg_op = rewriter.create<linalg::GenericOp>(
      loc, /*resultTensors=*/result_types,
      /*inputs=*/inputs,
      /*outputs=*/broadcast_values, indexing_maps,
      GetParallelAndReductionIterators(rank + filtered_window_dims.size(),
                                       filtered_window_dims.size()),
      /*bodyBuild=*/nullptr, PruneAttributeList(op));

  // Convert the signature of the body. This includes converting scalar
  // tensors to their scalar values and inserting an additional block arg for
  // the window arg.
  Region& region = linalg_op.region();
  rewriter.cloneRegionBefore(op.body(), region, region.end());
  // The reducer body has one argument per input followed by one per result.
  TypeConverter::SignatureConversion signature_converter(
      inputs.size() + op->getNumResults() - 1);
  for (uint64_t i = 0, s = inputs.size(); i < s - 1; i++) {
    signature_converter.addInputs(
        i, inputs[i].getType().cast<ShapedType>().getElementType());
  }
  // Fresh argument (no original counterpart) for the window tensor.
  signature_converter.addInputs(
      inputs.back().getType().cast<ShapedType>().getElementType());
  for (uint64_t i = 0, s = result_types.size(); i < s; i++) {
    auto idx = inputs.size() + i - 1;
    signature_converter.addInputs(
        idx, result_types[i].cast<ShapedType>().getElementType());
  }
  rewriter.applySignatureConversion(&region, signature_converter);
  rewriter.replaceOp(op, linalg_op.getResults());
  return success();
}