override protected def logic()

in core/src/main/scala/org/apache/pekko/kafka/internal/TransactionalSources.scala [219:281]


  /**
   * Builds the stage logic that emits `(TopicPartition, Source[TransactionalMessage])` pairs,
   * creating each partition's sub-source with a `TransactionalSubSourceStageLogic`.
   *
   * On top of the regular `SubSourceLogic` behaviour, an extra partition-assignment handler is
   * chained in. When partitions are revoked or lost, it blocks until every current sub-source has
   * acknowledged a `Drain` request, waiting at most `txConsumerSettings.commitTimeout`. It then
   * sends `Revoked` to the sub-sources. If draining fails, the source actor receives a
   * `Status.Failure` carrying the underlying cause, and the consumer actor receives `StopFromStage(id)`.
   */
  override protected def logic(
      shape: SourceShape[(TopicPartition, Source[TransactionalMessage[K, V], NotUsed])])
      : GraphStageLogic with Control = {
    // Produces the per-partition transactional sub-source logic, sharing this source's settings.
    val factory = new SubSourceStageLogicFactory[K, V, TransactionalMessage[K, V]] {
      def create(
          shape: SourceShape[TransactionalMessage[K, V]],
          tp: TopicPartition,
          consumerActor: ActorRef,
          subSourceStartedCb: AsyncCallback[SubSourceStageLogicControl],
          subSourceCancelledCb: AsyncCallback[(TopicPartition, SubSourceCancellationStrategy)],
          actorNumber: Int): SubSourceStageLogic[K, V, TransactionalMessage[K, V]] =
        new TransactionalSubSourceStageLogic(shape,
          tp,
          consumerActor,
          subSourceStartedCb,
          subSourceCancelledCb,
          actorNumber,
          txConsumerSettings)
    }

    new SubSourceLogic(shape, txConsumerSettings, subscription, subSourceStageLogicFactory = factory) {

      override protected def addToPartitionAssignmentHandler(
          handler: PartitionAssignmentHandler): PartitionAssignmentHandler = {
        val blockingRevokedCall = new PartitionAssignmentHandler {
          override def onAssign(assignedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()

          // This is invoked in the KafkaConsumerActor thread when doing poll. It deliberately
          // blocks that thread until the sub-sources have drained the revoked partitions.
          // NOTE(review): `subSources` is read here from the consumer actor's thread rather than
          // the stage's own thread; confirm that this cross-thread read is safe.
          override def onRevoke(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
            if (revokedTps.nonEmpty) {
              waitForDraining(revokedTps) match {
                case Right(()) =>
                  val revoked = Revoked(revokedTps.toList)
                  subSources.values
                    .map(_.controlAndStageActor.stageActor)
                    .foreach(_.tell(revoked, stageActor.ref))
                case Left(cause) =>
                  // Attach the underlying failure (e.g. the await timing out, or an interrupt)
                  // so it is not lost when the source is failed.
                  sourceActor.ref.tell(
                    Status.Failure(new Error("Timeout while draining", cause)),
                    stageActor.ref)
                  consumerActor.tell(KafkaConsumerActor.Internal.StopFromStage(id), stageActor.ref)
              }
            }

          // Lost partitions get the same drain-then-revoke treatment as revoked ones.
          override def onLost(lostTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit =
            onRevoke(lostTps, consumer)

          override def onStop(revokedTps: Set[TopicPartition], consumer: RestrictedConsumer): Unit = ()
        }
        new PartitionAssignmentHelpers.Chain(handler, blockingRevokedCall)
      }

      /**
       * Asks every current sub-source to drain `partitions` and blocks the calling thread until
       * all of them reply, waiting at most `txConsumerSettings.commitTimeout`.
       *
       * @param partitions the partitions the sub-sources must drain
       * @return `Right(())` when every drain request was acknowledged in time. Otherwise `Left`
       *         with the non-fatal failure (or interrupt) that prevented it. Fatal errors are
       *         not caught.
       */
      private def waitForDraining(partitions: Set[TopicPartition]): Either[Throwable, Unit] = {
        import pekko.pattern.ask
        import scala.util.control.NonFatal
        implicit val timeout: Timeout = Timeout(txConsumerSettings.commitTimeout)
        try {
          val drainCommandFutures =
            subSources.values.map(_.stageActor).map(ask(_, Drain(partitions, None, Drained)))
          implicit val ec: ExecutionContext = executionContext
          Await.result(Future.sequence(drainCommandFutures), timeout.duration)
          Right(())
        } catch {
          case e: InterruptedException =>
            // Do not swallow the interrupt. Restore the flag for the owning thread, but still
            // report the drain as failed so the stage fails as it did before.
            Thread.currentThread().interrupt()
            Left(e)
          case NonFatal(e) =>
            Left(e)
        }
      }
    }
  }