in lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc [32:128]
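  // Matches an mhlo.gather that selects whole slices along operand dimension
  // 0 using a trailing size-1 index vector, and rewrites it to
  // torch_index_select + reshape. An illustrative lowering (hypothetical
  // shapes, chosen only for this sketch):
  //
  //   %r = "mhlo.gather"(%operand, %start_indices) {
  //     dimension_numbers = #mhlo.gather<
  //       offset_dims = [1],
  //       collapsed_slice_dims = [0],
  //       start_index_map = [0],
  //       index_vector_dim = 1>,
  //     indices_are_sorted = false,
  //     slice_sizes = dense<[1, 4]> : tensor<2xi64>
  //   } : (tensor<5x4xf32>, tensor<2x1xi32>) -> tensor<2x4xf32>
  //
  // becomes:
  //
  //   %0 = "mhlo.torch_index_select"(%operand, %start_indices) {
  //     dim = 0 : i64, batch_dims = 0 : i64
  //   } : (tensor<5x4xf32>, tensor<2x1xi32>) -> tensor<2x1x4xf32>
  //   %r = "mhlo.reshape"(%0) : (tensor<2x1x4xf32>) -> tensor<2x4xf32>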
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    auto start_indices = gather.start_indices();
    auto start_indices_ty = start_indices.getType().cast<ShapedType>();
    if (!start_indices_ty.hasRank()) {
      return rewriter.notifyMatchFailure(gather, "unranked start_indices");
    }

    auto operand = gather.operand();
    auto operand_ty = operand.getType().cast<ShapedType>();
    if (!operand_ty.hasRank()) {
      return rewriter.notifyMatchFailure(gather, "unranked operand");
    }

    int64_t index_vector_dim =
        std::max<int64_t>(0, start_indices_ty.getRank() - 1);

    // We can use torch_index_select if the last dimension represents the
    // gather indices.
    auto dimension_numbers = gather.dimension_numbers();
    if (dimension_numbers.getIndexVectorDim() != index_vector_dim) {
      return rewriter.notifyMatchFailure(
          gather, "index_vector_dim not last dimension of start_indices");
    }

    // Index select only works across a single dimension.
    if (!start_indices_ty.getShape().empty() &&
        start_indices_ty.getShape().back() != 1) {
      return rewriter.notifyMatchFailure(
          gather, "start_indices index vector dimension not 1");
    }
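
    // For example (hypothetical shapes): start_indices of type tensor<2x1xi32>
    // passes both checks above, since index_vector_dim is its last dimension
    // (1) and that dimension has size 1, i.e. each index vector selects along
    // a single operand dimension.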

    // Only support the default case for start_index_map.
    if (dimension_numbers.getStartIndexMap().size() != 1 ||
        dimension_numbers.getStartIndexMap()[0] != 0) {
      return rewriter.notifyMatchFailure(gather, "start_index_map != [0]");
    }
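
    // start_index_map = [0] means each index value addresses dimension 0 of
    // the operand, matching the dim = 0 attribute on the torch_index_select
    // created below.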

    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return rewriter.notifyMatchFailure(gather, "unranked result");
    }

    // Offset dimensions should be the defaults.
    if (static_cast<int64_t>(dimension_numbers.getOffsetDims().size()) !=
        result_ty.getRank() - index_vector_dim) {
      return rewriter.notifyMatchFailure(
          gather, "offset_dims.size not result rank minus index_vector_dim");
    }

    for (const auto &it : llvm::enumerate(dimension_numbers.getOffsetDims())) {
      if (static_cast<int64_t>(it.index()) + index_vector_dim != it.value()) {
        return rewriter.notifyMatchFailure(
            gather, "offset_dims != [index_vector_dim, result.rank)");
      }
    }
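
    // E.g. (hypothetical): with index_vector_dim = 1 and a rank-3 result, the
    // only accepted offset_dims is [1, 2].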

    for (const auto &it :
         llvm::enumerate(gather.slice_sizes().getValues<APInt>())) {
      // First shape value must be 1.
      if (it.index() == 0) {
        if (it.value().getSExtValue() != 1) {
          return rewriter.notifyMatchFailure(gather, "slice_size[0] != 1");
        }
        continue;
      }

      // The gather needs to index the entire slice for each other dimension.
      if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
        return rewriter.notifyMatchFailure(
            gather, "slice_size doesn't match operand dimension");
      }
    }
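
    // E.g. (hypothetical): for an operand of type tensor<5x4xf32>, the only
    // accepted slice_sizes is dense<[1, 4]>: one row, taken in full.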

    llvm::SmallVector<int64_t, 4> index_select_shape =
        llvm::to_vector<4>(start_indices_ty.getShape());

    for (auto dim : operand_ty.getShape().drop_front()) {
      index_select_shape.push_back(dim);
    }
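
    // The intermediate shape keeps the trailing size-1 index vector dimension;
    // the reshape below folds it away. E.g. (hypothetical): start_indices
    // tensor<2x1xi32> and operand tensor<5x4xf32> give [2, 1, 4].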

    if (dimension_numbers.getCollapsedSliceDims().size() != 1 ||
        dimension_numbers.getCollapsedSliceDims()[0] != 0) {
      return rewriter.notifyMatchFailure(gather, "collapsed_slice_dims != [0]");
    }

    auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
        gather.getLoc(),
        RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
        operand, start_indices, rewriter.getI64IntegerAttr(0),
        rewriter.getI64IntegerAttr(0));

    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
                                           torch_index_select);

    return success();
  }