LogicalResult matchAndRewrite()

in lib/Dialect/mhlo/transforms/legalize_to_linalg.cc [2258:2417]


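  // Lowers mhlo.reduce_window to a named linalg pooling op
  // (linalg.pooling_nhwc_* / linalg.pooling_ndhwc_*). Only rank-4/5
  // NHWC/NDHWC layouts with trivial batch/channel window dimensions and
  // strides, zero padding, and f32 element types are handled; anything
  // else fails to match.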
  LogicalResult matchAndRewrite(
      mhlo::ReduceWindowOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    auto loc = op.getLoc();
    int rank = op.getResultTypes()[0].cast<ShapedType>().getRank();
    if (rank != 4 && rank != 5) {
      return rewriter.notifyMatchFailure(
          op, "expected NHWC/NDHWC pooling-based op");
    }

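    // The linalg pooling ops have no padding operand, so only the all-zero
    // padding case can be lowered directly.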
    if (op.padding() && !isSplatValue(*op.padding(), 0)) {
      return rewriter.notifyMatchFailure(op, "require paddings are all zero");
    }

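    // Collect the spatial window sizes; dim 0 (batch) and the last dim
    // (channel) must be trivial, which is verified below.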
    int last_dim = rank - 1;
    SmallVector<int64_t, 2> fake_window_shapes;
    for (int i = 1; i < last_dim; ++i) {
      fake_window_shapes.push_back(
          op.window_dimensions().getValues<int64_t>()[i]);
    }

    if (op.window_strides() &&
        (op.window_strides().getValue().getValues<int64_t>()[0] != 1 ||
         op.window_strides().getValue().getValues<int64_t>()[last_dim] != 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected window_strides to be [1,x,y,(z),1]");
    }
    if (op.window_dimensions() &&
        (op.window_dimensions().getValues<int64_t>()[0] != 1 ||
         op.window_dimensions().getValues<int64_t>()[last_dim] != 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected window_dimensions to be [1,x,y,(z),1]");
    }

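    // Extract the spatial strides, defaulting to 1 when the attribute is
    // absent.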
    Attribute strides;
    SmallVector<int64_t> vec;
    if (op.window_strides()) {
      for (int i = 1; i < last_dim; ++i) {
        vec.push_back(op.window_strides().getValue().getValues<int64_t>()[i]);
      }
    } else {
      vec.assign(rank - 2, 1);
    }
    strides = rewriter.getI64VectorAttr(vec);

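    // Likewise, extract the spatial dilations, defaulting to 1.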
    Attribute dilations;
    vec.clear();
    if (op.window_dilations()) {
      for (int i = 1; i < last_dim; ++i) {
        vec.push_back(op.window_dilations().getValue().getValues<int64_t>()[i]);
      }
    } else {
      vec.assign(rank - 2, 1);
    }
    dilations = rewriter.getI64VectorAttr(vec);

    SmallVector<Value> pooling_ops;

    ValueRange inputs = adaptor.inputs();
    ValueRange init_values = adaptor.init_values();
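    // Lower each (result, input, init_value) triple to its own pooling op.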
    for (auto it : llvm::zip(op.getResults(), inputs, init_values)) {
      OpResult result = std::get<0>(it);
      Value input = std::get<1>(it);
      Value init_value = std::get<2>(it);
      auto result_type = result.getType().cast<ShapedType>();
      if (!input.getType().cast<ShapedType>().getElementType().isF32()) {
        return rewriter.notifyMatchFailure(op,
                                           "expected element type to be f32");
      }

      // Create a fake window-dimension tensor; the linalg pooling ops take
      // the window as a tensor operand and use only its shape.
      auto fake_window_dims = rewriter.create<linalg::InitTensorOp>(
          loc, fake_window_shapes, result_type.getElementType());

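      // Derive any dynamic output dimensions from the input shape, window
      // sizes, strides, and dilations.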
      SmallVector<Value> result_dynamic_dims;
      for (auto& en : llvm::enumerate(result_type.getShape())) {
        if (en.value() != ShapedType::kDynamicSize) continue;
        Value dim_size = rewriter.create<tensor::DimOp>(loc, input, en.index());
        if (en.index() == 0 || en.index() == rank - 1) {
          // Batch and channel dims carry over from the input unchanged.
          result_dynamic_dims.push_back(dim_size);
        } else {
          auto i = en.index() - 1;
          auto stride =
              strides.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
          auto dilation =
              dilations.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
          // For spatial output index `o`, the window covers
          //   input[o * stride, o * stride + window_size * dilation),
          // so dim_size = (dim_size - window_size * dilation) / stride + 1.
          Value offset = rewriter.create<arith::ConstantIndexOp>(
              loc, fake_window_shapes[i] * dilation);
          dim_size = rewriter.create<arith::SubIOp>(loc, dim_size, offset);
          dim_size = rewriter.create<arith::DivUIOp>(
              loc, dim_size,
              rewriter.create<arith::ConstantIndexOp>(loc, stride));
          dim_size = rewriter.create<arith::AddIOp>(
              loc, dim_size, rewriter.create<arith::ConstantIndexOp>(loc, 1));
          result_dynamic_dims.push_back(dim_size);
        }
      }
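      // Materialize the output tensor and fill it with the scalar init
      // value so the pooling op starts from a valid accumulator.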
      Value init_tensor = rewriter.create<linalg::InitTensorOp>(
          loc, result_dynamic_dims, result_type.getShape(),
          result_type.getElementType());

      init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
      Value filled_init_tensor =
          rewriter.create<linalg::FillOp>(loc, init_value, init_tensor)
              .getResult(0);
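      // Helper that builds the concrete pooling op; the op class is chosen
      // by the switch below via a typed null pointer.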
      auto create_op = [&](auto* type_ptr) -> linalg::LinalgOp {
        return cast<linalg::LinalgOp>(
            rewriter
                .create<std::remove_pointer_t<decltype(type_ptr)>>(
                    loc, ArrayRef<Type>{result_type},
                    ValueRange{input, fake_window_dims.getResult()},
                    filled_init_tensor, strides, dilations,
                    PruneAttributeList(op))
                .getOperation());
      };
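      // Dispatch on the reduction kind recovered from the reduce_window
      // body.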
      linalg::LinalgOp pooling_op;
      PoolingType pooling_type = getPoolingType(op, result.getResultNumber());
      switch (pooling_type) {
        case PoolingType::k2DMin: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNhwcMinOp*>(nullptr));
          break;
        }
        case PoolingType::k3DMin: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNdhwcMinOp*>(nullptr));
          break;
        }
        case PoolingType::k2DMax: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNhwcMaxOp*>(nullptr));
          break;
        }
        case PoolingType::k3DMax: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNdhwcMaxOp*>(nullptr));
          break;
        }
        case PoolingType::k2DAdd: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNhwcSumOp*>(nullptr));
          break;
        }
        case PoolingType::k3DAdd: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNdhwcSumOp*>(nullptr));
          break;
        }
        case PoolingType::kInvalid:
          return rewriter.notifyMatchFailure(op, "unknown reduction operation");
      }
      pooling_ops.push_back(pooling_op->getResult(0));
    }
    rewriter.replaceOp(op, pooling_ops);
    return success();
  }
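
A pattern like this is normally driven through MLIR's dialect conversion
framework. The sketch below shows one plausible way to wire that up; the
pattern class name ReduceWindowOpToPoolingConversion and the wrapper
function are hypothetical stand-ins, not names taken from this file.

  // Minimal sketch, assuming the matchAndRewrite above lives in a pattern
  // class named ReduceWindowOpToPoolingConversion (hypothetical).
  void runReduceWindowLowering(ModuleOp module) {
    MLIRContext* ctx = module.getContext();
    RewritePatternSet patterns(ctx);
    patterns.add<ReduceWindowOpToPoolingConversion>(ctx);

    ConversionTarget target(*ctx);
    // mhlo.reduce_window must be rewritten away; the dialects the pattern
    // emits into are declared legal.
    target.addIllegalOp<mhlo::ReduceWindowOp>();
    target.addLegalDialect<linalg::LinalgDialect, arith::ArithmeticDialect,
                           tensor::TensorDialect>();

    if (failed(applyPartialConversion(module, target, std::move(patterns))))
      module.emitError("reduce_window to linalg pooling lowering failed");
  }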