LogicalResult matchAndRewrite()

in lib/Dialect/mhlo/transforms/legalize_to_linalg.cc [2555:2714]


  LogicalResult matchAndRewrite(
      mhlo::GatherOp gatherOp, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = gatherOp.getLoc();

    Value startIndices = adaptor.start_indices();
    Value operand = adaptor.operand();

    RankedTensorType resultType =
        gatherOp.getResult().getType().dyn_cast<RankedTensorType>();
    RankedTensorType startIndicesType =
        startIndices.getType().dyn_cast<RankedTensorType>();
    // We could actually deal with an unranked result by inferring the result
    // rank, but the current reifyReturnTypes doesn't support unranked either.
    if (!resultType || !startIndicesType)
      return rewriter.notifyMatchFailure(gatherOp,
                                         "unranked start indices or result");

    int resultRank = resultType.getRank();
    // slice_sizes has to have as many elements as the operand's rank, so
    // computing the rank this way also permits an unranked operand.
    int operandRank = gatherOp.slice_sizes().getNumElements();

    int64_t indexVectorDim = gatherOp.dimension_numbers().getIndexVectorDim();

    ArrayRef<int64_t> offsetDims = gatherOp.dimension_numbers().getOffsetDims();
    ArrayRef<int64_t> collapsedSliceDims =
        gatherOp.dimension_numbers().getCollapsedSliceDims();
    ArrayRef<int64_t> startIndexMap =
        gatherOp.dimension_numbers().getStartIndexMap();
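    // Illustrative running example for the comments below (the values are
    // chosen for illustration and are not from the original source): operand
    // is tensor<3x4x5xf32>, start_indices is tensor<2x2xi32>, slice_sizes is
    // [1, 4, 1], and the dimension numbers are offset_dims = [1],
    // collapsed_slice_dims = [0, 2], start_index_map = [0, 2],
    // index_vector_dim = 1. The result type is then tensor<2x4xf32>: result
    // dim 0 is a batch dimension taken from start_indices, and result dim 1
    // is the operand's surviving (non-collapsed) slice dimension.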

    auto extractAsIndex = [&](Value input, ArrayRef<Value> index) -> Value {
      return rewriter.create<arith::IndexCastOp>(
          loc, rewriter.getIndexType(),
          rewriter.create<tensor::ExtractOp>(loc, input, index));
    };

    // We'll need these later, and creating them on demand would produce
    // duplicates, which also makes lit tests really hard to write.
    SmallVector<Value> constants;
    for (unsigned i = 0; i < std::max(resultRank, operandRank); ++i)
      constants.push_back(
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i)));

    // Create ops to calculate the dynamic dimensions of the return shape, which
    // are needed for the init tensor.
    SmallVector<Value> dynDimSizes;
    if (!resultType.hasStaticShape()) {
      SmallVector<Value> returnShapes;
      if (failed(gatherOp.reifyReturnTypeShapes(rewriter, adaptor.getOperands(),
                                                returnShapes)))
        return rewriter.notifyMatchFailure(gatherOp,
                                           "could not reify return shape");
      assert(returnShapes.size() == 1);
      Value returnShape = returnShapes[0];
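      // returnShape is a rank-1 shape tensor for the result; the loop below
      // extracts its entries for the dynamic dimensions and casts them to
      // index.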

      for (int i = 0; i < resultRank; ++i)
        if (resultType.isDynamicDim(i))
          dynDimSizes.push_back(extractAsIndex(returnShape, constants[i]));
    }

    Value initOp = rewriter.create<linalg::InitTensorOp>(
        loc, dynDimSizes, resultType.getShape(), resultType.getElementType());

    ValueRange ins;
    SmallVector<AffineMap, 1> indexingMaps(
        {rewriter.getMultiDimIdentityMap(resultRank)});
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/resultType,
        /*inputs=*/ins,
        /*outputs=*/initOp, indexingMaps, GetNParallelLoopsAttrs(resultRank),
        /*bodyBuild=*/nullptr, PruneAttributeList(gatherOp));

    // Now populate the linalg generic region
    auto* region = &linalgOp.region();
    auto* block = rewriter.createBlock(region, region->end());
    block->addArguments(resultType.getElementType(), loc);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(block);

    // Dimensions in the result that aren't offset dimensions are called batch
    // dimensions.
    SmallVector<int64_t> batchDims;
    for (int dim = 0; dim < resultRank; ++dim)
      if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim);
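    // In the running example, offset_dims = [1] and the result has rank 2, so
    // batchDims = [0].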

    // As with the constants, creating these all up front is easier than
    // potentially ending up with duplicates later.
    SmallVector<Value> linalgIndices;
    for (unsigned i = 0; i < resultRank; ++i)
      linalgIndices.push_back(rewriter.create<linalg::IndexOp>(loc, i));

    // Now the complicated part. For a given output index we build up an index
    // into the operand. It's composed of two parts: the index coming from
    // start_indices, and the offset from that index along the offset
    // dimensions. Both parts involve dimension shuffling and remapping,
    // because gather is defined with extra attributes that allow operands of
    // any layout.
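    // In the running example, for result index (b, o) the code below builds
    // the operand index (start_indices[b, 0], o, start_indices[b, 1]): the
    // outer entries come from start_indices, the middle one from the offset
    // position o within the extracted slice.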

    // The base gather index (`G` in the documentation) points to a place in
    // start_indices along the batch dimensions.
    SmallVector<Value> gatherIndex;
    for (auto dim : batchDims) gatherIndex.push_back(linalgIndices[dim]);

    SmallVector<Value> indexFromStartIndices;
    for (unsigned i = 0; i < startIndexMap.size(); ++i) {
      // The index along the index_vector_dim dimension of start_indices
      // varies. Basically, indexFromStartIndices indexes into a "row" along
      // index_vector_dim, where the row is selected by the current output
      // index.
      // But if index_vector_dim is equal to start_indices.rank, then
      // start_indices gets a trailing 1 dimension added. In that case the row
      // we're extracting always has length 1 and the index into it is always
      // 0, so we just use the gather index directly.
      SmallVector<Value> gCombine(gatherIndex);
      if (indexVectorDim != startIndicesType.getRank()) {
        assert(indexVectorDim <= gCombine.size());
        gCombine.insert(gCombine.begin() + indexVectorDim, constants[i]);
      }

      indexFromStartIndices.push_back(extractAsIndex(startIndices, gCombine));
    }
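    // In the running example index_vector_dim (1) differs from the rank of
    // start_indices (2), so the loop above extracts start_indices[b, 0] and
    // start_indices[b, 1] into indexFromStartIndices.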

    // The start indices are then shuffled by the start index map. To make a
    // full index into the operand, the missing positions are filled with
    // zeros.
    SmallVector<Value> remappedIndexFromIndices(operandRank, constants[0]);
    for (auto& it : llvm::enumerate(startIndexMap))
      remappedIndexFromIndices[it.value()] = indexFromStartIndices[it.index()];
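    // Running example: with start_index_map = [0, 2] this gives
    // [start_indices[b, 0], 0, start_indices[b, 1]].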

    // Now we construct the index contribution from the offsets. First we need
    // to remap the offset dimensions by dropping the collapsed slice
    // dimensions.
    SmallVector<unsigned> remappedOffsetDims;
    for (unsigned i = 0; i < operandRank; ++i)
      if (!llvm::is_contained(collapsedSliceDims, i))
        remappedOffsetDims.push_back(i);

    assert(remappedOffsetDims.size() == offsetDims.size());
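    // Running example: collapsed_slice_dims = [0, 2] leaves operand dim 1, so
    // remappedOffsetDims = [1], matching the single entry in offset_dims.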

    // For the (remapped) offset dimensions, the index is the current index in
    // the output. As before this is expanded to a full index into the operand
    // by using zeros for the missing indices.
    SmallVector<Value> indexFromOffset(operandRank, constants[0]);
    for (unsigned k = 0; k < offsetDims.size(); ++k)
      indexFromOffset[remappedOffsetDims[k]] = linalgIndices[offsetDims[k]];
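    // Running example: indexFromOffset = [0, o, 0], where o is the result
    // index along the offset dimension (result dim 1).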

    // Now we add together our two indices to get the final index into the
    // operand.
    SmallVector<Value> combinedIndex;
    for (unsigned i = 0; i < operandRank; ++i)
      combinedIndex.push_back(rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), remappedIndexFromIndices[i],
          indexFromOffset[i]));
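    // Running example: combinedIndex is
    // [start_indices[b, 0] + 0, 0 + o, start_indices[b, 1] + 0], i.e. the
    // operand element that the gather semantics select for result[b, o].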

    Value element =
        rewriter.create<tensor::ExtractOp>(loc, operand, combinedIndex);
    rewriter.create<linalg::YieldOp>(loc, element);

    rewriter.replaceOp(gatherOp, linalgOp.getResults());

    return success();
  }