void runOnOperation() in lib/Dialect/lhlo/transforms/lhlo_fuse_linalg.cc [57:197]


  void runOnOperation() override {
    auto func = getOperation();

    // TODO(pifon): Remove assumption that the function has a single block.
    if (!llvm::hasSingleElement(func)) {
      emitError(func.getLoc(), "The function needs to have a single block.");
      signalPassFailure();
      return;
    }

    // The fusion in Linalg is currently possible only when the consumer op is
    // tiled. In order to greedily fuse the ops, we have to start from the tiled
    // root linalg ops, i.e. linalg ops that write to output buffers of the
    // function or are returned in case of escaping allocations.
    llvm::SmallDenseSet<Value> result_buffers;
    for (auto func_arg : func.getArguments()) {
      result_buffers.insert(func_arg);
    }
    for (auto& block : func) {
      auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
      if (!returnOp) continue;
      for (auto operand : returnOp.getOperands()) {
        result_buffers.insert(operand);
      }
    }
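    // At this point the seed set contains every buffer the outside world can
    // observe: all function arguments plus anything handed back through a
    // return. For example, in
    //   func @fusion(%in: memref<8xf32>, %out: memref<8xf32>)
    // both %in and %out are seeds, as is any allocation that escapes via the
    // terminator. (Illustrative IR, not taken from a test.)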
    // Resolve aliasing operations (like casts) on the results to identify the
    // underlying buffers. This only handles escaping results.
    // TODO(herhut): Use BufferizeAliasAnalysis for this.
    llvm::SmallVector<Value, 4> worklist(result_buffers.begin(),
                                         result_buffers.end());
    while (!worklist.empty()) {
      Value result = worklist.pop_back_val();
      auto* definingOp = result.getDefiningOp();
      if (!definingOp) {
        continue;
      }

      if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
        auto alias = viewLike.getViewSource();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }
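      // ViewLikeOpInterface covers the memref view chain (e.g. memref.cast,
      // memref.subview, memref.view), each of which aliases exactly one
      // source buffer.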

      if (auto to_tensor = dyn_cast<bufferization::ToTensorOp>(definingOp)) {
        auto alias = to_tensor.memref();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto to_memref = dyn_cast<bufferization::ToMemrefOp>(definingOp)) {
        auto alias = to_memref.tensor();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto tensor_cast = dyn_cast<tensor::CastOp>(definingOp)) {
        auto alias = tensor_cast.source();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
        continue;
      }

      if (auto regionInterface =
              dyn_cast<RegionBranchOpInterface>(definingOp)) {
        for (Region& region : regionInterface.getOperation()->getRegions()) {
          // Only consider regions that can return to the parent region.
          SmallVector<RegionSuccessor, 2> successorRegions;
          regionInterface.getSuccessorRegions(region.getRegionNumber(),
                                              successorRegions);
          if (llvm::none_of(successorRegions, [&](auto successorRegion) {
                return successorRegion.isParent();
              }))
            continue;

          // Iterate over all immediate terminators and record the values
          // corresponding to result_buffers of interest.
          for (Block& block : region) {
            if (block.empty()) continue;
            Operation& operation = block.back();
            if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
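            // The i-th result of the region op maps to the i-th operand of
            // each return-like terminator; 'result' is guaranteed to be an
            // OpResult here because it has a defining op.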
            auto idx = result.cast<OpResult>().getResultNumber();
            if (result_buffers.insert(operation.getOperand(idx)).second) {
              worklist.push_back(operation.getOperand(idx));
            }
          }
        }
      }
    }
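    // The worklist has now reached a fixed point: result_buffers holds every
    // buffer that transitively aliases an escaping result, e.g. the %buf in
    //   %view = memref.cast %buf : memref<8xf32> to memref<?xf32>
    // when %view is returned. (Illustrative IR.)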

    MLIRContext* ctx = func.getContext();
    OpBuilder b(func);
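    // Tile every linalg.generic that writes into one of the result buffers.
    // Tiling rewrites the op into a loop nest over tiles; the untiled
    // original is erased once tiling succeeds. When no tile sizes were
    // requested, a tile size of 1 per loop is used so that the op is still
    // tiled and therefore fusable.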
    func.walk([&](linalg::GenericOp generic_op) {
      SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(),
                                         tile_sizes_.end());
      if (tile_sizes.empty()) {
        tile_sizes = SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
      }
      auto op = cast<LinalgOp>(generic_op.getOperation());
      for (OpOperand* op_operand : op.getOutputBufferOperands()) {
        if (!result_buffers.count(op_operand->get())) continue;
        if (tileGenericOp(op, tile_sizes, &b)) {
          generic_op.erase();
          return;
        }
      }
    });
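    // Apply the canonicalization patterns that accompany linalg tiling to
    // clean up the loop nests produced above.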
    auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
      return signalPassFailure();

    // Fuse producers of tiled linalg ops.
    llvm::SmallDenseSet<Operation*> erase_set;
    SmallVector<LinalgOp, 8> linalg_ops;
    func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
    for (LinalgOp op : llvm::reverse(linalg_ops)) {
      for (OpOperand* inputOperand : op.getInputOperands()) {
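        // Rebuild the alias analysis and dependence graph for every
        // candidate operand: each successful fusion mutates the IR, which
        // invalidates any previously computed graph.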
        linalg::Aliases aliases;
        linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
        auto info = fuseProducerOfBuffer(b, *inputOperand, graph);
        if (failed(info)) continue;
        auto* originalOp = info->originalProducer.getOperation();
        erase_set.insert(originalOp);
        auto* originalOpInLinalgOpsVector =
            llvm::find_if(linalg_ops, [&](const Operation* nested_op) {
              return nested_op == originalOp;
            });
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }

      auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
      if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
        return signalPassFailure();
    }
    for (auto* e : erase_set) e->erase();
  }
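
For context, tileGenericOp is the helper invoked above but defined outside the
excerpted range [57:197]. Below is a minimal sketch of what it does, assuming
the Optional-returning linalg::tileLinalgOp entry point of this MLIR vintage
and the pass options use_parallel_loops_ and tile_sizes_; the real
implementation lives in the same file and may differ in detail:

  bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b) {
    // Tile to parallel loops when requested, otherwise to sequential loops.
    auto loop_type = use_parallel_loops_
                         ? linalg::LinalgTilingLoopType::ParallelLoops
                         : linalg::LinalgTilingLoopType::Loops;
    auto tiled_op = linalg::tileLinalgOp(*b, op,
                                         linalg::LinalgTilingOptions()
                                             .setTileSizes(tile_sizes)
                                             .setLoopType(loop_type));
    // Success here is what allows runOnOperation to erase the untiled op.
    return tiled_op.hasValue();
  }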