in lib/Dialect/mhlo/transforms/legalize_to_linalg.cc [1928:2063]
/// Lowers a 2-D depthwise `mhlo::ConvOp` to a Linalg depthwise convolution.
///
/// Only this form is matched; anything else is reported as a match failure
/// before any IR is created:
///   * batch_group_count == 1, all-zero padding, all-one lhs dilation;
///   * exactly two spatial dimensions;
///   * feature_group_count equal to the input feature dimension size
///     (depthwise convolution);
///   * dimension numbers accepted by HasCanonicalDimensionNumbers (the
///     indexing below presumably relies on an NHWC input / HWIO filter
///     layout -- TODO confirm against that helper);
///   * a ranked rank-4 filter whose two feature dimensions are static;
///   * a statically shaped result.
///
/// With C = feature_group_count and channel multiplier M:
///   * M == 1: the filter is collapsed to [H, W, C] and lowered to
///     linalg::DepthwiseConv2DNhwcHwcOp.
///   * M != 1: the filter [H, W, 1, C*M] is reshaped to [H, W, C, M], the
///     convolution is computed by linalg::DepthwiseConv2DNhwcHwcmOp into a
///     5-D [N, OH, OW, C, M] tensor, and that is collapsed back to the 4-D
///     result type.
LogicalResult matchAndRewrite(
    mhlo::ConvOp op, OpAdaptor adaptor,
    ConversionPatternRewriter& rewriter) const override {
  if (op.batch_group_count() != 1) return failure();
  if (op.padding() && !isSplatValue(*op.padding(), 0)) {
    return rewriter.notifyMatchFailure(op,
                                       "non-zero padding unsupported yet");
  }
  if ((op.lhs_dilation() && !isSplatValue(*op.lhs_dilation(), 1))) {
    return rewriter.notifyMatchFailure(
        op, "non-one lhs- dialation unsupported yet");
  }

  // Every structural assumption below (2-D, depthwise, canonical layout)
  // is derived from the dimension numbers, so refuse to proceed without them
  // instead of silently skipping those checks.
  mhlo::ConvDimensionNumbersAttr dimension_numbers = op.dimension_numbers();
  if (!dimension_numbers) {
    return rewriter.notifyMatchFailure(op, "missing dimension numbers");
  }
  // Make sure that this is 2-D convolution.
  const auto spatial_rank =
      llvm::size(dimension_numbers.getInputSpatialDimensions());
  if (spatial_rank != 2) {
    return rewriter.notifyMatchFailure(op, "only support 2-D cases for now");
  }
  // Make sure that this is depthwise convolution.
  int64_t input_feature_dim = dimension_numbers.getInputFeatureDimension();
  int64_t input_feature_count =
      op.lhs().getType().cast<ShapedType>().getDimSize(input_feature_dim);
  if (op.feature_group_count() != input_feature_count) {
    return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
  }
  // Make sure that this convolution has a canonical form.
  if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
    return rewriter.notifyMatchFailure(op, "does not have canonical form");
  }

  DenseIntElementsAttr window_strides;
  if (op.window_strides()) {
    window_strides = op.window_strides().getValue();
  } else {
    window_strides = rewriter.getI64VectorAttr({1, 1});
  }
  DenseIntElementsAttr rhs_dilation;
  if (op.rhs_dilation()) {
    rhs_dilation = op.rhs_dilation().getValue();
  } else {
    rhs_dilation = rewriter.getI64VectorAttr({1, 1});
  }

  Location loc = op.getLoc();
  Value input = adaptor.lhs();
  Value filter = adaptor.rhs();
  auto result_type = op.getResult().getType().cast<RankedTensorType>();
  if (!result_type.hasStaticShape()) {
    return rewriter.notifyMatchFailure(op,
                                       "expected output has static shapes");
  }

  // filter_dims[2] / filter_dims[3] are indexed below, so the filter must be
  // ranked with exactly four dimensions, and its feature dimensions must be
  // static for the channel arithmetic to be meaningful.
  auto filter_type = op.rhs().getType().dyn_cast<RankedTensorType>();
  if (!filter_type || filter_type.getRank() != 4) {
    return rewriter.notifyMatchFailure(op, "expected a rank-4 filter");
  }
  auto filter_dims = llvm::to_vector<4>(filter_type.getShape());
  if (ShapedType::isDynamic(filter_dims[2]) ||
      ShapedType::isDynamic(filter_dims[3])) {
    return rewriter.notifyMatchFailure(
        op, "expected static filter feature dimensions");
  }
  Type filter_element_type = filter_type.getElementType();
  const int64_t feature_group_count =
      static_cast<int64_t>(op.feature_group_count());

  auto get_indices_vector = [](int start, int end) {
    return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
  };
  // Creates the zero-filled accumulator of `shape` (result element type)
  // that the Linalg convolution writes into.
  auto create_zero_tensor = [&](ArrayRef<int64_t> shape) -> Value {
    Value init_tensor = rewriter.create<linalg::InitTensorOp>(
        loc, shape, result_type.getElementType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(result_type.getElementType()));
    return rewriter.create<linalg::FillOp>(loc, zero, init_tensor)
        .getResult(0);
  };

  if (filter_dims[2] * filter_dims[3] != feature_group_count) {
    // Channel multiplier M != 1. A depthwise filter has one input feature
    // per group, i.e. [H, W, 1, C*M]; DepthwiseConv2DNhwcHwcmOp wants
    // [H, W, C, M]. Validate before creating any IR so that a failure here
    // leaves the IR untouched.
    if (feature_group_count <= 0 || filter_dims[2] != 1 ||
        filter_dims[3] <= 0 || filter_dims[3] % feature_group_count != 0) {
      return rewriter.notifyMatchFailure(
          op, "unexpected filter shape for depthwise convolution");
    }
    const int64_t channel_multiplier = filter_dims[3] / feature_group_count;

    Value hwcm_filter = filter;
    if (filter_dims[2] != feature_group_count) {
      // [H, W, 1, C*M] -> [H, W, C*M] -> [H, W, C, M]. The split keeps the
      // multiplier as the fastest-varying index, matching the c*M + m
      // feature order restored by the final collapse below. Skipped when
      // C == 1, where the filter is already [H, W, C, M].
      SmallVector<ReassociationIndices, 4> merge_feature_dims = {
          get_indices_vector(0, 1), get_indices_vector(1, 2),
          get_indices_vector(2, 4)};
      auto flat_filter_type = RankedTensorType::get(
          {filter_dims[0], filter_dims[1], filter_dims[3]},
          filter_element_type);
      Value flat_filter = rewriter.create<tensor::CollapseShapeOp>(
          loc, flat_filter_type, filter, merge_feature_dims);
      auto hwcm_filter_type = RankedTensorType::get(
          {filter_dims[0], filter_dims[1], feature_group_count,
           channel_multiplier},
          filter_element_type);
      hwcm_filter = rewriter.create<tensor::ExpandShapeOp>(
          loc, hwcm_filter_type, flat_filter, merge_feature_dims);
    }

    // The HWCM op produces [N, OH, OW, C, M]: split M off the result's
    // feature dimension.
    SmallVector<int64_t> reshaped_output_dims(result_type.getShape().begin(),
                                              result_type.getShape().end());
    reshaped_output_dims[3] /= channel_multiplier;
    reshaped_output_dims.push_back(channel_multiplier);
    Value zero_tensor = create_zero_tensor(reshaped_output_dims);
    auto reshaped_output_type = RankedTensorType::get(
        reshaped_output_dims, result_type.getElementType());
    auto conv = rewriter.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
        loc, reshaped_output_type, ValueRange{input, hwcm_filter},
        ValueRange{zero_tensor}, window_strides, rhs_dilation,
        PruneAttributeList(op));

    // Collapse the 5-D [N, OH, OW, C, M] output of
    // DepthwiseConv2DNhwcHwcmOp back to the 4-D result by merging the last
    // two dimensions.
    SmallVector<ReassociationIndices, 4> collapsed_dim_list = {
        get_indices_vector(0, 1), get_indices_vector(1, 2),
        get_indices_vector(2, 3), get_indices_vector(3, 5)};
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        op, result_type, conv.getResult(0), collapsed_dim_list);
  } else {
    // Channel multiplier M == 1.
    Value zero_tensor = create_zero_tensor(result_type.getShape());
    // DepthwiseConv2DNhwcHwcOp expects a 3-D [H, W, C] filter, so merge the
    // two feature dimensions (one of which is the unit dimension) into one.
    // The collapsed type must keep the filter's element type, which may
    // differ from the result's.
    filter_dims[2] = feature_group_count;
    filter_dims.pop_back();
    RankedTensorType filter_shape =
        RankedTensorType::get(filter_dims, filter_element_type);
    SmallVector<ReassociationIndices, 4> collapsed_dim_list = {
        get_indices_vector(0, 1), get_indices_vector(1, 2),
        get_indices_vector(2, 4)};
    Value reshaped_filter = rewriter.create<tensor::CollapseShapeOp>(
        loc, filter_shape, filter, collapsed_dim_list);
    rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcOp>(
        op, result_type, ValueRange{input, reshaped_filter},
        ValueRange{zero_tensor}, window_strides, rhs_dilation,
        PruneAttributeList(op));
  }
  return success();
}