in lib/Dialect/mhlo/transforms/legalize_to_linalg.cc [2425:2541]
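// Lowers mhlo.torch_index_select to a linalg.generic whose body gathers
// elements from the original input tensor with tensor.extract.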
LogicalResult matchAndRewrite(
mhlo::TorchIndexSelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
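// Read the `dim` and `batch_dims` attributes and normalize negative values
// against the input rank and the index rank, respectively.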
int axis = static_cast<int>(op.dim());
int batch = static_cast<int>(op.batch_dims());
auto index_shaped_type = adaptor.index().getType().cast<ShapedType>();
int num_indices = static_cast<int>(index_shaped_type.getRank());
auto input_shaped_type = adaptor.input().getType().cast<ShapedType>();
if (axis < 0) axis += static_cast<int>(input_shaped_type.getRank());
if (batch < 0) batch += num_indices;
Location loc = op.getLoc();
auto result_type =
this->typeConverter->convertType(op.getResult().getType())
.cast<ShapedType>();
int rank = static_cast<int>(result_type.getRank());
// The output shape is
// `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
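// Collect the dynamic sizes of the result, taking each size from either the
// input or the index tensor depending on which segment of the output shape
// the dimension falls in.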
SmallVector<Value, 4> dyn_sizes;
for (int i = 0; i < rank; ++i) {
if (!result_type.isDynamicDim(i)) continue;
if (i < axis) {
dyn_sizes.push_back(
rewriter.create<tensor::DimOp>(loc, adaptor.input(), i));
} else if (i < (axis + num_indices - batch)) {
int idx = i - axis + batch;
dyn_sizes.push_back(
rewriter.create<tensor::DimOp>(loc, adaptor.index(), idx));
} else {
int idx = i - (axis + num_indices - batch) + axis + 1;
dyn_sizes.push_back(
rewriter.create<tensor::DimOp>(loc, adaptor.input(), idx));
}
}
// Generate a dummy tensor to preserve the slice shape information, i.e. the
// result dimensions that do not come from the index tensor.
SmallVector<int64_t> slice_shape;
SmallVector<Value, 4> dyn_slice_sizes;
SmallVector<AffineExpr, 4> slice_exprs;
auto result_shape = result_type.getShape();
for (int i = 0; i < axis; ++i) {
slice_exprs.push_back(rewriter.getAffineDimExpr(i));
slice_shape.push_back(result_shape[i]);
if (!result_type.isDynamicDim(i)) continue;
dyn_slice_sizes.push_back(
rewriter.create<tensor::DimOp>(loc, adaptor.input(), i));
}
for (int i = axis + num_indices - batch; i < rank; ++i) {
slice_exprs.push_back(rewriter.getAffineDimExpr(i));
slice_shape.push_back(result_shape[i]);
if (!result_type.isDynamicDim(i)) continue;
int idx = i - (axis + num_indices - batch) + axis + 1;
dyn_slice_sizes.push_back(
rewriter.create<tensor::DimOp>(loc, adaptor.input(), idx));
}
// Set up the indexing map for the index operand: its batch dims map to the
// leading output dims, and its remaining dims map to the output dims that
// correspond to `indices[batch_dims:]`.
SmallVector<AffineExpr, 4> exprs;
for (int i = 0; i < batch; ++i) {
exprs.push_back(rewriter.getAffineDimExpr(i));
}
for (int i = 0, e = num_indices - batch; i < e; ++i) {
exprs.push_back(rewriter.getAffineDimExpr(axis + i));
}
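// One indexing map per linalg.generic operand: the index tensor, the dummy
// slice tensor, and an identity map for the output.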
SmallVector<AffineMap, 2> indexing_maps;
indexing_maps.emplace_back(
AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
indexing_maps.emplace_back(AffineMap::get(
rank, /*symbolCount=*/0, slice_exprs, rewriter.getContext()));
indexing_maps.emplace_back(rewriter.getMultiDimIdentityMap(rank));
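// Materialize the dummy slice tensor and the uninitialized output tensor.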
Value slice_op = rewriter.create<linalg::InitTensorOp>(
loc, dyn_slice_sizes, slice_shape, result_type.getElementType());
Value init_op = rewriter.create<linalg::InitTensorOp>(
loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
auto linalg_op = rewriter.create<linalg::GenericOp>(
loc, /*resultTensors=*/ArrayRef<Type>{result_type},
/*inputs=*/ValueRange{adaptor.index(), slice_op},
/*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank),
/*bodyBuild=*/nullptr, PruneAttributeList(op));
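// Since bodyBuild is nullptr above, the region body is constructed manually:
// one block argument per generic input element plus one for the output
// element.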
SmallVector<Type, 4> body_arg_types;
SmallVector<Value, 2> linalg_op_args = {adaptor.index(), slice_op};
// Add a block to the region.
auto* region = &linalg_op.region();
auto* block = rewriter.createBlock(region, region->end());
for (auto block_args : linalg_op_args) {
body_arg_types.push_back(
block_args.getType().cast<ShapedType>().getElementType());
}
block->addArguments(body_arg_types,
SmallVector<Location>(body_arg_types.size(), loc));
block->addArguments(result_type.getElementType(), loc);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(block);
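// The gathered index value arrives as the first block argument; cast it to
// `index` type so it can be used by tensor.extract.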
Value casted_value = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), block->getArgument(0));
SmallVector<Value, 4> indices;
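// Assemble the extraction coordinates into the original input: the leading
// loop indices, the gathered index along `axis`, then the loop indices that
// follow the index dimensions of the output.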
for (int i = 0; i < axis; ++i) {
indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
}
indices.push_back(casted_value);
for (int i = axis + num_indices - batch; i < rank; ++i) {
indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
}
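// Extract the selected element from the original input and yield it as the
// result of the generic op body.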
Value res =
rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices);
rewriter.create<linalg::YieldOp>(loc, res);
rewriter.replaceOp(op, linalg_op.getResults());
return success();
}