LogicalResult matchAndRewrite()

in lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc [32:128]


  // Rewrites a gather that selects whole slices along dimension 0 of
  // `operand` into a torch_index_select (dim = 0, batch_dims = 0) followed by
  // a reshape to the gather's original result type.
  //
  // The pattern matches only when all of the following hold (each is checked
  // below, in order, before any IR is created):
  //   * start_indices, operand and result are ranked;
  //   * index_vector_dim equals max(0, rank(start_indices) - 1), and the
  //     trailing dimension of a non-scalar start_indices has size 1;
  //   * start_index_map == [0];
  //   * offset_dims == [index_vector_dim, rank(result));
  //   * slice_sizes[0] == 1 and every other slice size equals the
  //     corresponding operand dimension;
  //   * collapsed_slice_dims == [0].
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    auto start_indices = gather.start_indices();
    auto start_indices_ty = start_indices.getType().cast<ShapedType>();
    if (!start_indices_ty.hasRank()) {
      return rewriter.notifyMatchFailure(gather, "unranked start_indices");
    }

    auto operand = gather.operand();
    auto operand_ty = operand.getType().cast<ShapedType>();
    if (!operand_ty.hasRank()) {
      return rewriter.notifyMatchFailure(gather, "unranked operand");
    }

    // Expected position of the index vector: the trailing dimension of
    // start_indices, clamped to 0 for rank-0 start_indices.
    int64_t index_vector_dim =
        std::max<int64_t>(0, start_indices_ty.getRank() - 1);

    // We can use torch_index_select if the last dimension represents the
    // gather indices.
    auto dimension_numbers = gather.dimension_numbers();
    if (dimension_numbers.getIndexVectorDim() != index_vector_dim) {
      return rewriter.notifyMatchFailure(
          gather, "index_vector_dim not last dimension of start_indices");
    }

    // Index select only works across a single dimension. A dynamic trailing
    // size also fails this check.
    if (!start_indices_ty.getShape().empty() &&
        start_indices_ty.getShape().back() != 1) {
      return rewriter.notifyMatchFailure(
          gather, "start_indices index vector dimension not 1");
    }

    // Only support the default case for start_index_map.
    if (dimension_numbers.getStartIndexMap().size() != 1 ||
        dimension_numbers.getStartIndexMap()[0] != 0) {
      return rewriter.notifyMatchFailure(gather, "start_index_map != [0]");
    }

    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return rewriter.notifyMatchFailure(gather, "unranked result");
    }

    // Offset dimensions should be the defaults. Compare in signed arithmetic
    // so a result rank smaller than index_vector_dim is handled explicitly
    // rather than via unsigned wrap-around.
    if (static_cast<int64_t>(dimension_numbers.getOffsetDims().size()) !=
        result_ty.getRank() - index_vector_dim) {
      return rewriter.notifyMatchFailure(
          gather, "offset_dims.size not result rank minus index_vector_dim");
    }

    for (const auto &it : llvm::enumerate(dimension_numbers.getOffsetDims())) {
      if (static_cast<int64_t>(it.index()) + index_vector_dim != it.value()) {
        return rewriter.notifyMatchFailure(
            gather, "offset_dims != [index_vector_dim, result.rank)");
      }
    }

    for (const auto &it :
         llvm::enumerate(gather.slice_sizes().getValues<APInt>())) {
      // First shape value must be 1.
      if (it.index() == 0) {
        if (it.value().getSExtValue() != 1) {
          return rewriter.notifyMatchFailure(gather, "slice_size[0] != 1");
        }
        continue;
      }

      // The gather needs to index the entire slice for each other dimension.
      // A dynamic operand dimension never equals a concrete slice size, so it
      // also fails here.
      if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
        return rewriter.notifyMatchFailure(
            gather, "slice_size doesn't match operand dimension");
      }
    }

    // Finish every legality check before computing anything for the rewrite.
    if (dimension_numbers.getCollapsedSliceDims().size() != 1 ||
        dimension_numbers.getCollapsedSliceDims()[0] != 0) {
      return rewriter.notifyMatchFailure(gather, "collapsed_slice_dims != [0]");
    }

    // torch_index_select with dim = 0 and batch_dims = 0 produces
    // shape(start_indices) ++ shape(operand)[1:]. This still contains the
    // size-1 index vector dimension, which the final reshape removes.
    llvm::SmallVector<int64_t, 4> index_select_shape =
        llvm::to_vector<4>(start_indices_ty.getShape());
    auto trailing_operand_dims = operand_ty.getShape().drop_front();
    index_select_shape.append(trailing_operand_dims.begin(),
                              trailing_operand_dims.end());

    auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
        gather.getLoc(),
        RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
        operand, start_indices, /*dim=*/rewriter.getI64IntegerAttr(0),
        /*batch_dims=*/rewriter.getI64IntegerAttr(0));

    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
                                           torch_index_select);

    return success();
  }