in lib/Dialect/mhlo/transforms/legalize_to_linalg.cc [2555:2714]
LogicalResult matchAndRewrite(
mhlo::GatherOp gatherOp, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
Location loc = gatherOp.getLoc();
Value startIndices = adaptor.start_indices();
Value operand = adaptor.operand();
RankedTensorType resultType =
gatherOp.getResult().getType().dyn_cast<RankedTensorType>();
RankedTensorType startIndicesType =
startIndices.getType().dyn_cast<RankedTensorType>();
// We could actually deal with an unranked result by inferring the result
// rank, but the current reifyReturnTypeShapes doesn't support unranked
// either.
if (!resultType || !startIndicesType)
return rewriter.notifyMatchFailure(gatherOp,
"unranked start indices or result");
int resultRank = resultType.getRank();
// slice_sizes has to have one element per operand dimension, so taking the
// rank from it also permits an unranked operand.
int operandRank = gatherOp.slice_sizes().getNumElements();
int64_t indexVectorDim = gatherOp.dimension_numbers().getIndexVectorDim();
ArrayRef<int64_t> offsetDims = gatherOp.dimension_numbers().getOffsetDims();
ArrayRef<int64_t> collapsedSliceDims =
gatherOp.dimension_numbers().getCollapsedSliceDims();
ArrayRef<int64_t> startIndexMap =
gatherOp.dimension_numbers().getStartIndexMap();
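// Running example for the comments below (shapes are illustrative, not taken
// from any particular caller): operand : tensor<5x8xf32>,
// start_indices : tensor<3x1xi32>, index_vector_dim = 1,
// start_index_map = [0], collapsed_slice_dims = [0], offset_dims = [1],
// slice_sizes = [1, 8]. This gathers whole rows of the operand, producing a
// tensor<3x8xf32> result with result[b, o] = operand[start_indices[b, 0], o].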
// Load a scalar from `input` at `index` and cast it to the index type.
auto extractAsIndex = [&](Value input, ArrayRef<Value> index) -> Value {
return rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(),
rewriter.create<tensor::ExtractOp>(loc, input, index));
};
// We'll need these constants later; creating them on demand would produce
// duplicates, which also makes lit tests really hard to write.
SmallVector<Value> constants;
for (unsigned i = 0; i < std::max(resultRank, operandRank); ++i)
constants.push_back(
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i)));
// Create ops to calculate the dynamic dimensions of the return shape, which
// are needed for the init tensor.
SmallVector<Value> dynDimSizes;
if (!resultType.hasStaticShape()) {
SmallVector<Value> returnShapes;
if (failed(gatherOp.reifyReturnTypeShapes(rewriter, adaptor.getOperands(),
returnShapes)))
return rewriter.notifyMatchFailure(gatherOp,
"could not reify return shape");
assert(returnShapes.size() == 1);
Value returnShape = returnShapes[0];
for (int i = 0; i < resultRank; ++i)
if (resultType.isDynamicDim(i))
dynDimSizes.push_back(extractAsIndex(returnShape, constants[i]));
}
Value initOp = rewriter.create<linalg::InitTensorOp>(
loc, dynDimSizes, resultType.getShape(), resultType.getElementType());
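// Note: linalg.init_tensor only provides the shape and element type of the
// output; every element is overwritten by the region built below. The
// generic op takes no inputs; the operand is instead read via tensor.extract
// inside the body at the computed index.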
ValueRange ins;
SmallVector<AffineMap, 1> indexingMaps(
{rewriter.getMultiDimIdentityMap(resultRank)});
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/resultType,
/*inputs=*/ins,
/*outputs=*/initOp, indexingMaps, GetNParallelLoopsAttrs(resultRank),
/*bodyBuild=*/nullptr, PruneAttributeList(gatherOp));
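// For the running example, the op created here (once the region is filled in
// below) looks roughly as follows; value names and the f32 element type are
// illustrative:
//   %init = linalg.init_tensor [3, 8] : tensor<3x8xf32>
//   %result = linalg.generic {
//       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
//       iterator_types = ["parallel", "parallel"]}
//       outs(%init : tensor<3x8xf32>) {
//   ^bb0(%out: f32):
//     ...index computation built below...
//     linalg.yield %element : f32
//   } -> tensor<3x8xf32>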
// Now populate the linalg.generic region.
auto* region = &linalgOp.region();
auto* block = rewriter.createBlock(region, region->end());
block->addArguments(resultType.getElementType(), loc);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(block);
// Dimensions in the result that aren't offset dimensions are called batch
// dimensions.
SmallVector<int64_t> batchDims;
for (int dim = 0; dim < resultRank; ++dim)
if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim);
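// Running example: offset_dims = [1], so batchDims = [0].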
// Same as with the constants: creating these all up front is easier than
// deduplicating them later.
SmallVector<Value> linalgIndices;
for (unsigned i = 0; i < resultRank; ++i)
linalgIndices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
// Now the complicated part. For a given output index we build up an index
// into the operand. It's composed of two parts: the index coming from
// start_indices, and the offset from that index along the offset dimensions.
// Both involve dimension shuffling and remapping, because gather is defined
// to allow any-layout input through its extra attributes.
// The base gather index (`G` in the documentation) points to a place in
// start_indices along the batch dimensions.
SmallVector<Value> gatherIndex;
for (auto dim : batchDims) gatherIndex.push_back(linalgIndices[dim]);
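// Running example: gatherIndex = [%idx0], the current index along the single
// batch dimension (where %idxN denotes linalg.index N).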
SmallVector<Value> indexFromStartIndices;
for (unsigned i = 0; i < startIndexMap.size(); ++i) {
// The index along the index_vector dimension of start_indices varies.
// Basically, indexFromStartIndices indexes into a "row" along
// index_vector_dim, where the row is selected by the current output index.
// But if index_vector_dim is equal to start_indices.rank, then start_indices
// gets a trailing 1 dimension added, so the row we're extracting always has
// length 1 and the index into it is always 0; in that case we just use the
// gather index directly.
SmallVector<Value> gCombine(gatherIndex);
if (indexVectorDim != startIndicesType.getRank()) {
assert(indexVectorDim <= gCombine.size());
gCombine.insert(gCombine.begin() + indexVectorDim, constants[i]);
}
indexFromStartIndices.push_back(extractAsIndex(startIndices, gCombine));
}
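// Running example: index_vector_dim (1) is not equal to the start_indices
// rank (2), so for i = 0 we extract start_indices[%idx0, 0];
// indexFromStartIndices holds that single scalar.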
// The start indices are then shuffled into operand dimensions by the start
// index map. To make a full index into the operand, all unmapped dimensions
// get index zero.
SmallVector<Value> remappedIndexFromIndices(operandRank, constants[0]);
for (auto& it : llvm::enumerate(startIndexMap))
remappedIndexFromIndices[it.value()] = indexFromStartIndices[it.index()];
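// Running example: start_index_map = [0] places that scalar at operand
// dimension 0, giving remappedIndexFromIndices =
// [start_indices[%idx0, 0], 0].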
// Now we construct the index contribution from the offset. First we need to
// remap the offset dimensions by dropping the collapsed slice dimensions.
SmallVector<unsigned> remappedOffsetDims;
for (unsigned i = 0; i < operandRank; ++i)
if (!llvm::is_contained(collapsedSliceDims, i))
remappedOffsetDims.push_back(i);
assert(remappedOffsetDims.size() == offsetDims.size());
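// Running example: collapsed_slice_dims = [0] drops operand dimension 0, so
// remappedOffsetDims = [1].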
// For the (remapped) offset dimensions, the index is the current index in
// the output. As before, this is expanded to a full index into the operand
// by using zeroes for the missing indices.
SmallVector<Value> indexFromOffset(operandRank, constants[0]);
for (unsigned k = 0; k < offsetDims.size(); ++k)
indexFromOffset[remappedOffsetDims[k]] = linalgIndices[offsetDims[k]];
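// Running example: indexFromOffset = [0, %idx1].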
// Now we add together our two indices to get the final index into the
// operand.
SmallVector<Value> combinedIndex;
for (unsigned i = 0; i < operandRank; ++i)
combinedIndex.push_back(rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), remappedIndexFromIndices[i],
indexFromOffset[i]));
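// Running example: combinedIndex = [start_indices[%idx0, 0] + 0, 0 + %idx1];
// the extracted start index selects the operand row and %idx1 walks along it.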
Value element =
rewriter.create<tensor::ExtractOp>(loc, operand, combinedIndex);
rewriter.create<linalg::YieldOp>(loc, element);
rewriter.replaceOp(gatherOp, linalgOp.getResults());
return success();
}