in lib/Dialect/mhlo/transforms/lower_general_dot.cc [181:288]
LogicalResult matchAndRewrite(DotGeneralOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto dot_element_type = getElementTypeOrSelf(op);
auto dot_numbers = op.dot_dimension_numbers();
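// This pattern only handles dot_general ops without batching dimensions.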
if (!dot_numbers.getLhsBatchingDimensions().empty() ||
!dot_numbers.getRhsBatchingDimensions().empty()) {
return failure();
}
auto lhs_contracting_dims = dot_numbers.getLhsContractingDimensions();
auto rhs_contracting_dims = dot_numbers.getRhsContractingDimensions();
auto lhs = op.lhs();
auto rhs = op.rhs();
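// Only ranked operands can be handled here.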
RankedTensorType lhs_ty = lhs.getType().dyn_cast<RankedTensorType>();
RankedTensorType rhs_ty = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhs_ty || !rhs_ty) return failure();
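// Put each operand's contracting dimensions into the position mhlo.dot
// expects: trailing on the LHS and leading on the RHS. The transpose/reshape
// is skipped when the single contracting dimension is already there.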
if (!(lhs_contracting_dims.size() == 1 &&
lhs_contracting_dims.front() == 1)) {
lhs = ProcessDotArg(op.lhs(), op.getLoc(),
dot_numbers.getLhsContractingDimensions(),
/*outer_dims_first=*/true, rewriter);
}
if (!(rhs_contracting_dims.size() == 1 &&
rhs_contracting_dims.front() == 0)) {
rhs = ProcessDotArg(op.rhs(), op.getLoc(),
dot_numbers.getRhsContractingDimensions(),
/*outer_dims_first=*/false, rewriter);
}
// Accept only shaped types for the (possibly collapsed) operands.
auto lhs_shape_type = lhs.getType().dyn_cast_or_null<ShapedType>();
auto rhs_shape_type = rhs.getType().dyn_cast_or_null<ShapedType>();
if (!lhs_shape_type || !rhs_shape_type) return failure();
// Resulting shape of the new dot: outer dimension of the LHS by outer
// dimension of the RHS.
auto lhs_shape = lhs_shape_type.getShape();
auto rhs_shape = rhs_shape_type.getShape();
auto new_dot_type =
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
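// Forward the precision config from the original op, if present.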
ArrayAttr precision_config;
if (op.precision_config()) precision_config = *op.precision_config();
Value new_dot_op = rewriter
.create<DotOp>(op.getLoc(), new_dot_type, lhs, rhs,
precision_config)
.getResult();
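// When each operand has exactly one non-contracting dimension, the rank-2
// dot result already matches the dot_general result shape, so no reshape is
// needed.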
if (lhs_contracting_dims.size() == (lhs_ty.getRank() - 1) &&
rhs_contracting_dims.size() == (rhs_ty.getRank() - 1)) {
rewriter.replaceOp(op, new_dot_op);
return success();
}
ShapedType result_ty = op.getType().cast<ShapedType>();
// We can avoid all the computation below if we know the static shape.
if (result_ty.hasStaticShape()) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(op, result_ty, new_dot_op);
return success();
}
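// The result shape is dynamic: walk the non-contracting dimensions of each
// operand, recording each size for the result type and its runtime size
// (a 1-element i32 tensor) for the dynamic_reshape shape operand.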
llvm::SmallVector<int64_t> static_dims;
llvm::SmallVector<Value> dyn_dims;
auto getDynamicDims = [&](Value arg,
llvm::ArrayRef<int64_t> contracting_dims) {
RankedTensorType ty = arg.getType().cast<RankedTensorType>();
int index = 0;
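// For each contracting dimension, record the dimensions preceding it, then
// step past the contracting dimension itself.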
for (auto contracting_dim : contracting_dims) {
for (; index < contracting_dim; index++) {
static_dims.push_back(ty.getDimSize(index));
dyn_dims.push_back(rewriter.create<mhlo::GetDimensionSizeOp>(
loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg,
rewriter.getI64IntegerAttr(index)));
}
index++;
}
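// Record the dimensions that follow the last contracting dimension.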
for (; index < ty.getRank(); index++) {
static_dims.push_back(ty.getDimSize(index));
dyn_dims.push_back(rewriter.create<mhlo::GetDimensionSizeOp>(
loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg,
rewriter.getI64IntegerAttr(index)));
}
};
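// With no batching dimensions, the dot_general result dimensions are the LHS
// non-contracting dimensions followed by the RHS non-contracting dimensions.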
getDynamicDims(op.lhs(), lhs_contracting_dims);
getDynamicDims(op.rhs(), rhs_contracting_dims);
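// Concatenate the per-dimension sizes into the shape operand for the
// dynamic reshape.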
Value reshape_dims_tensor = rewriter.create<mhlo::ConcatenateOp>(
loc,
RankedTensorType::get({static_cast<int64_t>(dyn_dims.size())},
rewriter.getI32Type()),
dyn_dims, rewriter.getI64IntegerAttr(0));
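// Expand the rank-2 dot result back to the dot_general result shape.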
Value result = rewriter.create<DynamicReshapeOp>(
op.getLoc(),
RankedTensorType::get(static_dims, result_ty.getElementType()),
new_dot_op, reshape_dims_tensor);
rewriter.replaceOp(op, result);
return success();
}