LogicalResult matchAndRewrite()

in lib/Dialect/mhlo/transforms/legalize_to_linalg.cc [2425:2541]


  // Lowers mhlo.torch_index_select to a linalg.generic that gathers slices of
  // `input` along dimension `dim` using the values of `index`, honoring the
  // leading `batch_dims` dimensions shared between `input` and `index`.
  //
  // The output shape is `params[:axis] + indices[batch_dims:] +
  // params[axis + 1:]`. The generic's body performs a tensor.extract from the
  // captured input, substituting the index value for the `axis` coordinate.
  LogicalResult matchAndRewrite(
      mhlo::TorchIndexSelectOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    int axis = static_cast<int>(op.dim());
    int batch = static_cast<int>(op.batch_dims());
    auto index_shaped_type = adaptor.index().getType().cast<ShapedType>();
    int num_indices = static_cast<int>(index_shaped_type.getRank());
    auto input_shaped_type = adaptor.input().getType().cast<ShapedType>();
    // Negative attribute values count from the back, per the op's semantics.
    if (axis < 0) axis += static_cast<int>(input_shaped_type.getRank());
    if (batch < 0) batch += num_indices;

    Location loc = op.getLoc();
    auto result_type =
        this->typeConverter->convertType(op.getResult().getType())
            .cast<ShapedType>();
    int rank = static_cast<int>(result_type.getRank());

    // The output shape is
    //   `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
    // Collect dynamic sizes for the result init tensor, mapping each dynamic
    // result dimension back to the operand it originates from.
    SmallVector<Value, 4> dyn_sizes;
    for (int i = 0; i < rank; ++i) {
      if (!result_type.isDynamicDim(i)) continue;
      if (i < axis) {
        // Leading dims come straight from the input.
        dyn_sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, adaptor.input(), i));
      } else if (i < (axis + num_indices - batch)) {
        // Middle dims come from the non-batch dims of the index tensor.
        int idx = i - axis + batch;
        dyn_sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, adaptor.index(), idx));
      } else {
        // Trailing dims come from the input dims after `axis`.
        int idx = i - (axis + num_indices - batch) + axis + 1;
        dyn_sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, adaptor.input(), idx));
      }
    }

    // Generate dummy tensor to preserve slice shape information.
    SmallVector<int64_t> slice_shape;
    SmallVector<Value, 4> dyn_slice_sizes;
    SmallVector<AffineExpr, 4> slice_exprs;
    auto result_shape = result_type.getShape();
    for (int i = 0; i < axis; ++i) {
      slice_exprs.push_back(rewriter.getAffineDimExpr(i));
      slice_shape.push_back(result_shape[i]);
      if (!result_type.isDynamicDim(i)) continue;
      dyn_slice_sizes.push_back(
          rewriter.create<tensor::DimOp>(loc, adaptor.input(), i));
    }
    for (int i = axis + num_indices - batch; i < rank; ++i) {
      slice_exprs.push_back(rewriter.getAffineDimExpr(i));
      slice_shape.push_back(result_shape[i]);
      if (!result_type.isDynamicDim(i)) continue;
      int idx = i - (axis + num_indices - batch) + axis + 1;
      dyn_slice_sizes.push_back(
          rewriter.create<tensor::DimOp>(loc, adaptor.input(), idx));
    }

    // Setup AffineMap for the index tensor (first input of the generic):
    // batch dims align with the leading output dims, the remaining index dims
    // align with the output dims starting at `axis`.
    SmallVector<AffineExpr, 4> exprs;
    for (int i = 0; i < batch; ++i) {
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
    for (int i = 0, e = num_indices - batch; i < e; ++i) {
      exprs.push_back(rewriter.getAffineDimExpr(axis + i));
    }

    SmallVector<AffineMap, 2> indexing_maps;
    indexing_maps.emplace_back(
        AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
    indexing_maps.emplace_back(AffineMap::get(
        rank, /*symbolCount=*/0, slice_exprs, rewriter.getContext()));
    indexing_maps.emplace_back(rewriter.getMultiDimIdentityMap(rank));

    Value slice_op = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_slice_sizes, slice_shape, result_type.getElementType());

    Value init_op = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
    // Create the generic with an empty body; the region is built by hand
    // below so the gather's tensor.extract can capture `adaptor.input()`.
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensors=*/ArrayRef<Type>{result_type},
        /*inputs=*/ValueRange{adaptor.index(), slice_op},
        /*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank),
        /*bodyBuild=*/nullptr, PruneAttributeList(op));

    SmallVector<Type, 4> body_arg_types;
    SmallVector<Value, 2> linalg_op_args = {adaptor.index(), slice_op};
    // Add a block to the region.
    auto* region = &linalg_op.region();
    auto* block = rewriter.createBlock(region, region->end());
    for (auto block_args : linalg_op_args) {
      body_arg_types.push_back(
          block_args.getType().cast<ShapedType>().getElementType());
    }
    block->addArguments(body_arg_types,
                        SmallVector<Location>(body_arg_types.size(), loc));
    // Output block argument: use the singular addArgument(Type, Location) —
    // a lone Type does not implicitly convert to the TypeRange that the
    // plural addArguments() overload expects.
    block->addArgument(result_type.getElementType(), loc);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(block);

    // The index element is an integer; cast it to `index` for tensor.extract.
    // Note the ODS builder takes the result type before the operand.
    Value casted_value = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getIndexType(), block->getArgument(0));

    // Build the extract coordinates: leading output dims, the gathered index
    // at `axis`, then the trailing output dims — input rank in total.
    SmallVector<Value, 4> indices;
    for (int i = 0; i < axis; ++i) {
      indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
    }
    indices.push_back(casted_value);
    for (int i = axis + num_indices - batch; i < rank; ++i) {
      indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
    }
    Value res =
        rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices);
    rewriter.create<linalg::YieldOp>(loc, res);

    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }