LogicalResult matchAndRewrite()

in lib/Dialect/mhlo/transforms/legalize_to_linalg.cc [1928:2063]
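This pattern lowers a 2-D depthwise mhlo::ConvOp in canonical NHWC/HWCM form to a
Linalg named depthwise-convolution op: linalg.depthwise_conv_2d_nhwc_hwcm when the
channel multiplier differs from 1, and linalg.depthwise_conv_2d_nhwc_hwc (after
collapsing the filter's unit dimension) when it equals 1.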


  LogicalResult matchAndRewrite(
      mhlo::ConvOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    if (op.batch_group_count() != 1) return failure();

    if (op.padding() && !isSplatValue(*op.padding(), 0)) {
      return rewriter.notifyMatchFailure(op,
                                         "non-zero padding unsupported yet");
    }

    if (op.lhs_dilation() && !isSplatValue(*op.lhs_dilation(), 1)) {
      return rewriter.notifyMatchFailure(
          op, "non-one lhs-dilation unsupported yet");
    }

    if (const mhlo::ConvDimensionNumbersAttr& dimension_numbers =
            op.dimension_numbers()) {
      // Make sure that this is 2-D convolution.
      const auto spatial_rank =
          llvm::size(dimension_numbers.getInputSpatialDimensions());
      if (spatial_rank != 2) {
        return rewriter.notifyMatchFailure(op,
                                           "only support 2-D cases for now");
      }

      // Make sure that this is depthwise convolution.
      int64_t input_feature_dim = dimension_numbers.getInputFeatureDimension();
      int64_t input_feature_count =
          op.lhs().getType().cast<ShapedType>().getDimSize(input_feature_dim);
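      // Illustrative example (shapes assumed, not from the source): for an
      // NHWC input of type tensor<1x113x113x96xf32> the input feature count
      // is 96, so only feature_group_count == 96 qualifies as depthwise here.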
      if (op.feature_group_count() != input_feature_count) {
        return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
      }

      // Make sure that this convolution has a canonical form.
      if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
        return rewriter.notifyMatchFailure(op, "does not have canonical form");
      }
    }

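    // Both window attributes are optional; when absent they default to ones
    // (unit stride and unit rhs-dilation), matching mhlo's implicit defaults.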
    DenseIntElementsAttr window_strides;
    if (op.window_strides()) {
      window_strides = op.window_strides().getValue();
    } else {
      window_strides = rewriter.getI64VectorAttr({1, 1});
    }

    DenseIntElementsAttr rhs_dilation;
    if (op.rhs_dilation()) {
      rhs_dilation = op.rhs_dilation().getValue();
    } else {
      rhs_dilation = rewriter.getI64VectorAttr({1, 1});
    }

    Location loc = op.getLoc();
    Value input = adaptor.lhs();
    Value filter = adaptor.rhs();
    auto result_type = op.getResult().getType().cast<RankedTensorType>();
    if (!result_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected output to have static shape");
    }

    auto filter_dims =
        llvm::to_vector<4>(op.rhs().getType().cast<ShapedType>().getShape());

    auto get_indices_vector = [](int start, int end) {
      return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
    };
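    // E.g. get_indices_vector(3, 5) yields {3, 4}: llvm::seq is half-open.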

    if (filter_dims[2] * filter_dims[3] != op.feature_group_count()) {
      // For cases where channel multiplier != 1
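      // Illustrative shapes (assumed, not from the source): an input
      // tensor<1x113x113x3xf32>, a filter tensor<3x3x3x2xf32> (HWCM),
      // feature_group_count = 3, and window strides [2, 2] give
      // channel_multiplier = 2; the tensor<1x56x56x6xf32> result is first
      // computed as tensor<1x56x56x3x2xf32> and collapsed below.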
      auto output_dims = result_type.getShape();
      auto channel_multiplier = filter_dims[3];
      SmallVector<int64_t> reshaped_output_dims;
      reshaped_output_dims.assign(output_dims.begin(), output_dims.end());
      reshaped_output_dims.push_back(channel_multiplier);
      reshaped_output_dims[3] /= channel_multiplier;

      Value init_tensor = rewriter.create<linalg::InitTensorOp>(
          loc, reshaped_output_dims, result_type.getElementType());
      auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
      Value zero = rewriter.create<arith::ConstantOp>(loc, zero_attr);
      Value zero_tensor =
          rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);

      auto reshaped_output_type = RankedTensorType::get(
          reshaped_output_dims, result_type.getElementType());
      auto conv = rewriter.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
          op.getLoc(), reshaped_output_type, ValueRange{input, filter},
          ValueRange{zero_tensor}, window_strides, rhs_dilation,
          PruneAttributeList(op));

      // Create a Linalg reshape op that converts the output from 5 dimensions
      // into 4 dimensions (by collapsing the last two dimensions). This is
      // needed because linalg.depthwise_conv_2d_nhwc_hwcm produces a 5-D
      // output.
      SmallVector<ReassociationIndices, 4> collapsed_dim_list = {
          get_indices_vector(0, 1), get_indices_vector(1, 2),
          get_indices_vector(2, 3), get_indices_vector(3, 5)};
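      // With the illustrative shapes above, reassociation {{0}, {1}, {2},
      // {3, 4}} collapses tensor<1x56x56x3x2xf32> back to the op's result
      // type tensor<1x56x56x6xf32>.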
      rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
          op, result_type, conv.getResult(0), collapsed_dim_list);
    } else {
      // For cases where channel multiplier == 1
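      // Illustrative shapes (assumed, not from the source): an input
      // tensor<1x113x113x96xf32>, a filter tensor<3x3x1x96xf32>, and
      // feature_group_count = 96 take this branch, since 1 * 96 == 96.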
      Value init_tensor = rewriter.create<linalg::InitTensorOp>(
          loc, result_type.getShape(), result_type.getElementType());
      auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
      Value zero = rewriter.create<arith::ConstantOp>(loc, zero_attr);
      Value zero_tensor =
          rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);

      // Create a Linalg reshape op that converts the filter from 4 dimensions
      // into 3 dimensions (by dropping the unit dimension). This is needed
      // because linalg.depthwise_conv_2d_nhwc_hwc expects a 3-D filter.

      filter_dims[2] = static_cast<int64_t>(op.feature_group_count());
      filter_dims.pop_back();

      RankedTensorType filter_shape =
          RankedTensorType::get(filter_dims, op.getType().getElementType());

      SmallVector<ReassociationIndices, 4> collapsed_dim_list = {
          get_indices_vector(0, 1), get_indices_vector(1, 2),
          get_indices_vector(2, 4)};

      Value reshaped_filter = rewriter.create<tensor::CollapseShapeOp>(
          loc, filter_shape, filter, collapsed_dim_list);
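      // With the illustrative shapes above, reassociation {{0}, {1}, {2, 3}}
      // collapses the filter tensor<3x3x1x96xf32> into tensor<3x3x96xf32>.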

      rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcOp>(
          op, result_type, ValueRange{input, reshaped_filter},
          ValueRange{zero_tensor}, window_strides, rhs_dilation,
          PruneAttributeList(op));
    }

    return success();
  }
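
For context, a conversion pattern like this only runs once it is registered with
a RewritePatternSet and driven by the dialect-conversion framework. The sketch
below is a minimal, hedged illustration, not the source's actual registration
code: DepthwiseConvOpConversion is a stand-in name for the class that owns this
matchAndRewrite, its constructor arguments depend on how the pattern is declared,
and include paths vary across MLIR versions.

  #include "mlir/Transforms/DialectConversion.h"

  void runDepthwiseConvLowering(mlir::ModuleOp module) {
    mlir::MLIRContext* context = module.getContext();

    // Collect the pattern; the real pass registers many HLO-to-Linalg
    // patterns at once.
    mlir::RewritePatternSet patterns(context);
    patterns.add<DepthwiseConvOpConversion>(context);

    // Ops created by the rewrite must be legal; any surviving
    // mhlo.convolution marks the conversion as failed.
    mlir::ConversionTarget target(*context);
    target.addLegalDialect<mlir::linalg::LinalgDialect,
                           mlir::arith::ArithmeticDialect,
                           mlir::tensor::TensorDialect>();
    target.addIllegalOp<mlir::mhlo::ConvOp>();

    if (failed(mlir::applyPartialConversion(module, target,
                                            std::move(patterns))))
      module.emitError("depthwise convolution lowering failed");
  }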