LogicalResult matchAndRewrite()

in lib/Dialect/mhlo/transforms/lower_general_dot.cc [181:288]


  // Lowers an mhlo.dot_general without batching dimensions to a 2-D mhlo.dot.
  //
  // Operands that are not already a matrix in dot layout (lhs [M, K]
  // contracting on dim 1, rhs [K, N] contracting on dim 0) are rearranged by
  // ProcessDotArg. The 2-D dot result is then reshaped to the dot_general
  // result type: statically when that type is fully static, otherwise through
  // mhlo.dynamic_reshape with sizes queried at runtime.
  //
  // Fails to match when batching dimensions are present or when either
  // operand is not a ranked tensor.
  LogicalResult matchAndRewrite(DotGeneralOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto dot_element_type = getElementTypeOrSelf(op);

    auto dot_numbers = op.dot_dimension_numbers();
    // Batching dimensions are not handled by this lowering.
    if (!dot_numbers.getLhsBatchingDimensions().empty() ||
        !dot_numbers.getRhsBatchingDimensions().empty()) {
      return failure();
    }

    auto lhs_contracting_dims = dot_numbers.getLhsContractingDimensions();
    auto rhs_contracting_dims = dot_numbers.getRhsContractingDimensions();

    auto lhs = op.lhs();
    auto rhs = op.rhs();

    RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
    RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
    if (!lhs_ty || !rhs_ty) return failure();

    // An operand may bypass ProcessDotArg only if it is already a matrix in
    // the layout mhlo.dot expects. The rank check is essential: e.g. a rank-1
    // rhs contracting on dim 0 must still be reshaped to [K, 1], otherwise the
    // `rhs_shape[1]` access below is out of bounds and mhlo.dot would get a
    // non-matrix operand.
    const bool lhs_is_dot_matrix = lhs_ty.getRank() == 2 &&
                                   lhs_contracting_dims.size() == 1 &&
                                   lhs_contracting_dims.front() == 1;
    const bool rhs_is_dot_matrix = rhs_ty.getRank() == 2 &&
                                   rhs_contracting_dims.size() == 1 &&
                                   rhs_contracting_dims.front() == 0;

    if (!lhs_is_dot_matrix) {
      lhs = ProcessDotArg(op.lhs(), loc, lhs_contracting_dims,
                          /*outer_dims_first=*/true, rewriter);
    }

    if (!rhs_is_dot_matrix) {
      rhs = ProcessDotArg(op.rhs(), loc, rhs_contracting_dims,
                          /*outer_dims_first=*/false, rewriter);
    }

    // Both (possibly rearranged) operands must be shaped types to read their
    // 2-D shapes below.
    // NOTE(review): returning failure() here after ProcessDotArg may already
    // have created ops; confirm ProcessDotArg always yields a shaped value.
    auto lhs_shape_type = lhs.getType().dyn_cast_or_null<ShapedType>();
    auto rhs_shape_type = rhs.getType().dyn_cast_or_null<ShapedType>();
    if (!lhs_shape_type || !rhs_shape_type) return failure();

    // Result shape of the 2-D dot: [lhs outer, rhs outer].
    auto lhs_shape = lhs_shape_type.getShape();
    auto rhs_shape = rhs_shape_type.getShape();
    auto new_dot_type =
        RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);

    ArrayAttr precision_config;
    if (op.precision_config()) precision_config = *op.precision_config();
    Value new_dot_op = rewriter
                           .create<DotOp>(loc, new_dot_type, lhs, rhs,
                                          precision_config)
                           .getResult();

    // When each operand keeps exactly one non-contracting dimension, the 2-D
    // dot already produces the dot_general result layout; no reshape needed.
    // Explicit casts avoid a signed/unsigned comparison.
    if (static_cast<int64_t>(lhs_contracting_dims.size()) ==
            lhs_ty.getRank() - 1 &&
        static_cast<int64_t>(rhs_contracting_dims.size()) ==
            rhs_ty.getRank() - 1) {
      rewriter.replaceOp(op, new_dot_op);
      return success();
    }

    ShapedType result_ty = op.getType().cast<ShapedType>();

    // We can avoid all the computation below if we know the static shape.
    if (result_ty.hasStaticShape()) {
      rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(op, result_ty, new_dot_op);
      return success();
    }

    llvm::SmallVector<int64_t> static_dims;
    llvm::SmallVector<Value> dyn_dims;

    // Appends, for every non-contracting dimension of `arg` in increasing
    // order, its static size to `static_dims` and a runtime size query to
    // `dyn_dims`. Contracting dimensions may be listed in any order, so they
    // are recorded in a mask instead of being assumed sorted.
    auto getDynamicDims = [&](Value arg,
                              llvm::ArrayRef<int64_t> contracting_dims) {
      RankedTensorType ty = arg.getType().cast<RankedTensorType>();
      llvm::SmallVector<bool> is_contracting(ty.getRank(), false);
      for (int64_t dim : contracting_dims) is_contracting[dim] = true;

      for (int64_t index = 0; index < ty.getRank(); ++index) {
        if (is_contracting[index]) continue;
        static_dims.push_back(ty.getDimSize(index));
        dyn_dims.push_back(rewriter.create<mhlo::GetDimensionSizeOp>(
            loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg,
            rewriter.getI64IntegerAttr(index)));
      }
    };

    // Without batching, the dot_general result dims are the lhs free dims
    // followed by the rhs free dims.
    getDynamicDims(op.lhs(), lhs_contracting_dims);
    getDynamicDims(op.rhs(), rhs_contracting_dims);

    Value reshape_dims_tensor = rewriter.create<mhlo::ConcatenateOp>(
        loc,
        RankedTensorType::get({static_cast<int64_t>(dyn_dims.size())},
                              rewriter.getI32Type()),
        dyn_dims, rewriter.getI64IntegerAttr(0));

    Value result = rewriter.create<DynamicReshapeOp>(
        loc, RankedTensorType::get(static_dims, result_ty.getElementType()),
        new_dot_op, reshape_dims_tensor);

    rewriter.replaceOp(op, result);
    return success();
  }