LogicalResult matchAndRewrite()

in lib/Dialect/mhlo/transforms/optimize_mhlo.cc [57:162]


  // Rewrites a GatherOp that selects a single contiguous window of its operand
  // into a DynamicSliceOp followed by a ReshapeOp to the gather's result type.
  //
  // The rewrite only applies when all of the following hold:
  //   * operand and start_indices are ranked with static shapes;
  //   * index_vector_dim == 0 and start_index_map == [0];
  //   * the result is a ranked tensor and offset_dims == [0, result.rank);
  //   * slice_sizes has more entries than the result has dimensions, and
  //     slice_sizes[i + 1] == result.shape[i] for every result dimension i;
  //   * start_indices has rank 0 or 1.
  //
  // Returns success() after replacing `gather`; otherwise reports the reason
  // through notifyMatchFailure and leaves the IR untouched.
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter& rewriter) const override {
    const auto loc = gather.getLoc();
    auto dimension_numbers = gather.dimension_numbers();

    // Inputs need to be ranked (and fully static) to lower.
    auto is_ranked_and_static = [](Value value) {
      auto shaped_ty = value.getType().cast<ShapedType>();
      return shaped_ty.hasRank() && shaped_ty.hasStaticShape();
    };
    if (!is_ranked_and_static(gather.operand()) ||
        !is_ranked_and_static(gather.start_indices())) {
      return rewriter.notifyMatchFailure(gather,
                                         "non-static operand or start_indices");
    }

    if (dimension_numbers.getIndexVectorDim() != 0) {
      return rewriter.notifyMatchFailure(gather, "non-zero index_vector_dim");
    }

    // TODO(suderman): Handle start index map != {0}.
    auto start_index_map = dimension_numbers.getStartIndexMap();
    if (start_index_map.size() != 1 || start_index_map[0] != 0) {
      return rewriter.notifyMatchFailure(gather, "start_index_map != [0]");
    }

    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return rewriter.notifyMatchFailure(gather, "unranked result");
    }
    const int64_t result_rank = result_ty.getRank();

    // Every result dimension must be an offset dimension, in order.
    auto offset_dims = dimension_numbers.getOffsetDims();
    if (static_cast<int64_t>(offset_dims.size()) != result_rank) {
      return rewriter.notifyMatchFailure(gather,
                                         "offset_dims.size != result.rank");
    }
    for (const auto& it : llvm::enumerate(offset_dims)) {
      if (static_cast<int64_t>(it.index()) != it.value()) {
        return rewriter.notifyMatchFailure(gather,
                                           "offset_dims != [0, result.rank)");
      }
    }

    // The shape comparison below reads slice_sizes[i + 1] for every result
    // dimension i, so slice_sizes needs at least result.rank + 1 entries.
    auto slice_sizes = gather.slice_sizes();
    if (slice_sizes.size() <= result_rank) {
      return rewriter.notifyMatchFailure(gather,
                                         "slice_sizes.size <= result.rank");
    }

    auto slice_size_values = slice_sizes.getValues<int64_t>();
    for (const auto& it : llvm::enumerate(result_ty.getShape())) {
      if (slice_size_values[it.index() + 1] != it.value()) {
        return rewriter.notifyMatchFailure(
            gather, "slice_sizes[i + 1] != result.shape[i]");
      }
    }

    auto start_indices = gather.start_indices();
    auto start_indices_ty = start_indices.getType().cast<ShapedType>();
    const int64_t start_indices_rank = start_indices_ty.getRank();
    if (start_indices_rank > 1) {
      return rewriter.notifyMatchFailure(gather, "start_indices.rank > 1");
    }

    // All preconditions hold; only from here on is new IR created.

    // Type of a single scalar start index, as consumed by DynamicSliceOp.
    auto scalar_index_ty =
        RankedTensorType::get({}, start_indices_ty.getElementType());

    llvm::SmallVector<Value, 4> slice_start_indices;
    if (start_indices_rank == 0) {
      // A scalar start_indices tensor is already a usable start index.
      slice_start_indices.push_back(start_indices);
    } else {
      // Split the 1-D start_indices tensor into one scalar tensor per element.
      const int64_t num_indices = start_indices_ty.getDimSize(0);
      for (int64_t i = 0; i < num_indices; ++i) {
        auto start = GetI64ElementsAttr({i}, &rewriter);
        auto limit = GetI64ElementsAttr({i + 1}, &rewriter);
        auto stride = GetI64ElementsAttr({1}, &rewriter);
        auto index_slice = rewriter.create<SliceOp>(loc, start_indices, start,
                                                    limit, stride);
        auto scalar_index =
            rewriter.create<ReshapeOp>(loc, scalar_index_ty, index_slice);
        slice_start_indices.push_back(scalar_index);
      }
    }

    // Start indices have implicit zeros when not specified. This is because
    // Gather occurs similar to slicing where full slices are inferred. Add any
    // missing zeros as necessary; the constant is only materialized when at
    // least one index actually has to be padded.
    const int64_t num_slice_dims = slice_sizes.getType().getDimSize(0);
    if (static_cast<int64_t>(slice_start_indices.size()) < num_slice_dims) {
      Value zero = rewriter.create<ConstOp>(
          loc, rewriter.getZeroAttr(scalar_index_ty));
      slice_start_indices.resize(num_slice_dims, zero);
    }

    // The dynamic slice yields a tensor shaped exactly like slice_sizes.
    SmallVector<int64_t, 5> slice_shape;
    for (int64_t slice_size : slice_size_values) {
      slice_shape.push_back(slice_size);
    }

    auto slice_ty =
        RankedTensorType::get(slice_shape, result_ty.getElementType());
    auto slice = rewriter.create<DynamicSliceOp>(
        loc, slice_ty, gather.operand(), slice_start_indices, slice_sizes);

    // Reshape the slice to the gather's result type (drops the extra
    // slice_sizes dimensions that the result does not have).
    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);

    return success();
  }