in lib/Dialect/mhlo/transforms/optimize_mhlo.cc [57:162]
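// Rewrites an mhlo.gather that selects a single slice of a statically shaped
// operand into mhlo.dynamic_slice followed by mhlo.reshape.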
LogicalResult matchAndRewrite(GatherOp gather,
PatternRewriter& rewriter) const override {
auto dimension_numbers = gather.dimension_numbers();
// Operand and start_indices must be ranked with static shapes to lower.
if (!gather.operand().getType().cast<ShapedType>().hasRank() ||
!gather.operand().getType().cast<ShapedType>().hasStaticShape() ||
!gather.start_indices().getType().cast<ShapedType>().hasRank() ||
!gather.start_indices().getType().cast<ShapedType>().hasStaticShape()) {
return rewriter.notifyMatchFailure(gather,
"non-static operand or start_indices");
}
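// Only the canonical layout is handled: the index vector must occupy the
// leading dimension of start_indices.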
if (dimension_numbers.getIndexVectorDim() != 0) {
return rewriter.notifyMatchFailure(gather, "non-zero index_vector_dim");
}
// TODO(suderman): Handle start index map != {0}.
if (dimension_numbers.getStartIndexMap().size() != 1 ||
dimension_numbers.getStartIndexMap()[0] != 0) {
return rewriter.notifyMatchFailure(gather, "start_index_map != [0]");
}
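// The result must be ranked so its shape can be compared against offset_dims
// and slice_sizes below.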
auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
if (!result_ty) {
return rewriter.notifyMatchFailure(gather, "unranked result");
}
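// Require offset_dims == [0, result.rank): every result dimension comes from
// the slice, in order, with no batch dimensions.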
if (dimension_numbers.getOffsetDims().size() !=
static_cast<size_t>(result_ty.getRank())) {
return rewriter.notifyMatchFailure(gather,
"offset_dims.size != result.rank");
}
for (const auto& it : llvm::enumerate(dimension_numbers.getOffsetDims())) {
if (it.index() != it.value()) {
return rewriter.notifyMatchFailure(gather,
"offset_dims != [0, result.rank)");
}
}
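// slice_sizes covers every operand dimension; entries 1..result.rank must
// match the result shape exactly (dimension 0 is the collapsed index dim).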
if (gather.slice_sizes().size() <= result_ty.getRank()) {
return rewriter.notifyMatchFailure(gather,
"slice_sizes.size <= result.rank");
}
for (const auto& it : llvm::enumerate(result_ty.getShape())) {
if (gather.slice_sizes().getValues<int64_t>()[it.index() + 1] !=
it.value()) {
return rewriter.notifyMatchFailure(
gather, "slice_sizes do not match result shape");
}
}
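// Unpack the start indices into scalars: mhlo.dynamic_slice takes one scalar
// start index per operand dimension.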
auto gather_start_indices = gather.start_indices();
auto gather_start_indices_ty =
gather_start_indices.getType().cast<ShapedType>();
llvm::SmallVector<Value, 4> slice_start_indices;
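// A rank-0 start_indices tensor is already the single scalar index; a rank-1
// tensor is unpacked element by element below. Higher ranks imply batching
// and are rejected.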
if (gather_start_indices_ty.getRank() == 0) {
slice_start_indices.push_back(gather_start_indices);
} else if (gather_start_indices_ty.getRank() == 1) {
for (int64_t i = 0; i < gather_start_indices_ty.getDimSize(0); i++) {
auto start = GetI64ElementsAttr({i}, &rewriter);
auto limit = GetI64ElementsAttr({i + 1}, &rewriter);
auto stride = GetI64ElementsAttr({1}, &rewriter);
auto indicesSlice = rewriter.create<SliceOp>(
gather.getLoc(), gather_start_indices, start, limit, stride);
auto reshaped = rewriter.create<ReshapeOp>(
gather.getLoc(),
RankedTensorType::get(
{}, indicesSlice.getType().cast<ShapedType>().getElementType()),
indicesSlice);
slice_start_indices.push_back(reshaped);
}
} else {
return rewriter.notifyMatchFailure(gather, "start_indices.rank > 1");
}
auto sliceSizesTy = gather.slice_sizes().getType();
// Start indices have implicit zeros when not specified: gather behaves like
// a slice in which dimensions without an explicit start index take a full
// slice beginning at zero. Pad the missing indices with zeros.
auto zero = rewriter.create<ConstOp>(
gather.getLoc(), rewriter.getZeroAttr(RankedTensorType::get(
{}, gather_start_indices_ty.getElementType())));
while (slice_start_indices.size() <
static_cast<size_t>(sliceSizesTy.getDimSize(0))) {
slice_start_indices.push_back(zero);
}
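// The dynamic_slice result type comes straight from slice_sizes; the final
// reshape then restores the gather's original result type.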
SmallVector<int64_t, 5> sliceShape;
for (auto shapeValue : gather.slice_sizes().getValues<APInt>()) {
sliceShape.push_back(shapeValue.getSExtValue());
}
auto sliceTy =
RankedTensorType::get(sliceShape, result_ty.getElementType());
auto slice = rewriter.create<DynamicSliceOp>(
gather.getLoc(), sliceTy, gather.operand(), slice_start_indices,
gather.slice_sizes());
rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);
return success();
}
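// Example of the rewrite this pattern performs (illustrative shapes;
// dimension_numbers syntax abbreviated): a gather with start_index_map = [0],
// index_vector_dim = 0, offset_dims = [0, 1], and slice_sizes = [1, 4, 4]
//   (tensor<8x4x4xf32>, tensor<1xi32>) -> tensor<4x4xf32>
// becomes a scalar start index (mhlo.slice + mhlo.reshape of the start
// indices), two zero-constant indices, an mhlo.dynamic_slice producing
// tensor<1x4x4xf32>, and a final mhlo.reshape to tensor<4x4xf32>.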