in lib/Dialect/mhlo/transforms/legalize_to_linalg.cc [2258:2417]
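// Converts a pooling-style mhlo.reduce_window in NHWC/NDHWC layout into the
// corresponding named linalg pooling op (min, max, or sum over a 2-D or 3-D
// spatial window). Roughly (sketch only, attributes abbreviated):
//   "mhlo.reduce_window"(%input, %init) ({ ... mhlo.maximum ... })
//       {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, ...}
// becomes a linalg.pooling_nhwc_max applied to a linalg.fill-ed init tensor.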
LogicalResult matchAndRewrite(
mhlo::ReduceWindowOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto loc = op.getLoc();
int rank = op.getResultTypes()[0].cast<ShapedType>().getRank();
if (rank != 4 && rank != 5) {
return rewriter.notifyMatchFailure(
op, "expected NHWC/NDHWC pooling-based op");
}
if (op.padding() && !isSplatValue(*op.padding(), 0)) {
return rewriter.notifyMatchFailure(op, "require paddings are all zero");
}
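// Collect the spatial window dimensions; they become the shape of the
// shape-only window operand passed to the linalg pooling op below.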
int last_dim = rank - 1;
SmallVector<int64_t, 2> fake_window_shapes;
for (int i = 1; i < last_dim; ++i) {
fake_window_shapes.push_back(
op.window_dimensions().getValues<int64_t>()[i]);
}
if (op.window_strides() &&
(op.window_strides().getValue().getValues<int64_t>()[0] != 1 ||
op.window_strides().getValue().getValues<int64_t>()[last_dim] != 1)) {
return rewriter.notifyMatchFailure(
op, "expected window_strides to be [1,x,y,(z),1]");
}
if (op.window_dimensions() &&
(op.window_dimensions().getValues<int64_t>()[0] != 1 ||
op.window_dimensions().getValues<int64_t>()[last_dim] != 1)) {
return rewriter.notifyMatchFailure(
op, "expected window_dimensions to be [1,x,y,(z),1]");
}
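// Extract the spatial strides and dilations, defaulting each to 1 when the
// corresponding attribute is absent.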
Attribute strides;
SmallVector<int64_t> vec;
if (op.window_strides()) {
for (int i = 1; i < last_dim; ++i) {
vec.push_back(op.window_strides().getValue().getValues<int64_t>()[i]);
}
} else {
vec.assign(rank - 2, 1);
}
strides = rewriter.getI64VectorAttr(vec);
Attribute dilations;
vec.clear();
if (op.window_dilations()) {
for (int i = 1; i < last_dim; ++i) {
vec.push_back(op.window_dilations().getValue().getValues<int64_t>()[i]);
}
} else {
vec.assign(rank - 2, 1);
}
dilations = rewriter.getI64VectorAttr(vec);
SmallVector<Value> pooling_ops;
ValueRange inputs = adaptor.inputs();
ValueRange init_values = adaptor.init_values();
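// Lower each (result, input, init_value) triple to its own linalg pooling op.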
for (auto it : llvm::zip(op.getResults(), inputs, init_values)) {
OpResult result = std::get<0>(it);
Value input = std::get<1>(it);
Value init_value = std::get<2>(it);
auto result_type = result.getType().cast<ShapedType>();
if (!input.getType().cast<ShapedType>().getElementType().isF32()) {
return rewriter.notifyMatchFailure(op,
"expected element type to be f32");
}
// Create a fake window tensor: the named linalg pooling ops take the window
// as a shape-only operand whose values are never read.
auto fake_window_dims = rewriter.create<linalg::InitTensorOp>(
loc, fake_window_shapes, result_type.getElementType());
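// Compute the dynamic result dimensions so that an init tensor of the right
// shape can be created for this result.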
SmallVector<Value> result_dynamic_dims;
for (auto& en : llvm::enumerate(result_type.getShape())) {
if (en.value() != ShapedType::kDynamicSize) continue;
Value dim_size = rewriter.create<tensor::DimOp>(loc, input, en.index());
if (en.index() == 0 || en.index() == rank - 1) {
// batch dims and channel dims can be derived from input dims
// directly.
result_dynamic_dims.push_back(dim_size);
} else {
auto i = en.index() - 1;
auto stride =
strides.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
auto dilation =
dilations.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
// For output position p along this spatial dim, the window starts at
// j = p * stride and spans window_size * dilation input elements, so
// dim_size = (dim_size - window_size * dilation) / stride + 1.
Value offset = rewriter.create<arith::ConstantIndexOp>(
loc, fake_window_shapes[i] * dilation);
dim_size = rewriter.create<arith::SubIOp>(loc, dim_size, offset);
dim_size = rewriter.create<arith::DivUIOp>(
loc, dim_size,
rewriter.create<arith::ConstantIndexOp>(loc, stride));
dim_size = rewriter.create<arith::AddIOp>(
loc, dim_size, rewriter.create<arith::ConstantIndexOp>(loc, 1));
result_dynamic_dims.push_back(dim_size);
}
}
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
loc, result_dynamic_dims, result_type.getShape(),
result_type.getElementType());
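// The init value is a 0-d tensor; extract the scalar and fill the init
// tensor with it to seed the reduction.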
init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
Value filled_init_tensor =
rewriter.create<linalg::FillOp>(loc, init_value, init_tensor)
.getResult(0);
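// Helper that builds the named linalg pooling op selected below, forwarding
// the shared operands and attributes; the null pointer only carries the type.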
auto create_op = [&](auto* type_ptr) -> linalg::LinalgOp {
return cast<linalg::LinalgOp>(
rewriter
.create<std::remove_pointer_t<decltype(type_ptr)>>(
loc, ArrayRef<Type>{result_type},
ValueRange{input, fake_window_dims.getResult()},
filled_init_tensor, strides, dilations,
PruneAttributeList(op))
.getOperation());
};
linalg::LinalgOp pooling_op;
PoolingType pooling_type = getPoolingType(op, result.getResultNumber());
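// Dispatch on the reduction kind (min/max/add) and spatial rank (2-D vs 3-D).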
switch (pooling_type) {
case PoolingType::k2DMin: {
pooling_op =
create_op(static_cast<linalg::PoolingNhwcMinOp*>(nullptr));
break;
}
case PoolingType::k3DMin: {
pooling_op =
create_op(static_cast<linalg::PoolingNdhwcMinOp*>(nullptr));
break;
}
case PoolingType::k2DMax: {
pooling_op =
create_op(static_cast<linalg::PoolingNhwcMaxOp*>(nullptr));
break;
}
case PoolingType::k3DMax: {
pooling_op =
create_op(static_cast<linalg::PoolingNdhwcMaxOp*>(nullptr));
break;
}
case PoolingType::k2DAdd: {
pooling_op =
create_op(static_cast<linalg::PoolingNhwcSumOp*>(nullptr));
break;
}
case PoolingType::k3DAdd: {
pooling_op =
create_op(static_cast<linalg::PoolingNdhwcSumOp*>(nullptr));
break;
}
case PoolingType::kInvalid:
return rewriter.notifyMatchFailure(op, "unknown reduction operation");
}
pooling_ops.push_back(pooling_op->getResult(0));
}
rewriter.replaceOp(op, pooling_ops);
return success();
}