private def producerLogic()

in jakartams/src/main/scala/org/apache/pekko/stream/connectors/jakartams/impl/JmsProducerStage.scala [78:261]


  /**
   * Creates the stage logic that sends incoming envelopes through a pool of JMS producers and emits
   * them downstream, in arrival order, once their send has completed.
   *
   * @param inheritedAttributes attributes of the materialized stage; used to look up the mandatory
   *                            `SupervisionStrategy` and passed to `executionContext` in `preStart`.
   * @return a timer-enabled graph stage logic with handlers installed on this stage's `in` and `out` ports.
   */
  private def producerLogic(inheritedAttributes: Attributes) =
    new TimerGraphStageLogic(shape) with JmsProducerConnector with GraphStageCompanion with StageLogging {

      // The *OnGraphStage overrides below expose this logic's materializer, destination and timer
      // operations under the names expected by the inherited members they override.
      final override def graphStageMaterializer: Materializer = materializer

      final override def graphStageDestination: Destination = destination

      final override def scheduleOnceOnGraphStage(timerKey: Any, delay: FiniteDuration): Unit =
        scheduleOnce(timerKey, delay)

      final override def isTimerActiveOnGraphStage(timerKey: Any): Boolean = isTimerActive(timerKey)

      final override def cancelTimerOnGraphStage(timerKey: Any): Unit = cancelTimer(timerKey)

      /*
       * NOTE: the following code is heavily inspired by org.apache.pekko.stream.impl.fusing.MapAsync
       *
       * To get a condensed view of what the buffers and handler behavior are about, have a look there too.
       */

      // supervision decider taken from the stage attributes; consulted by handleFailure.
      private lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider

      // the current connection epoch. Reconnects increment this epoch by 1.
      private var currentJmsProducerEpoch = 0

      // available producers for sending messages. Initially full, but might contain fewer elements if
      // messages are currently in-flight.
      private val jmsProducers: Buffer[JmsMessageProducer] = Buffer(settings.sessionCount, settings.sessionCount)

      // in-flight messages with the producers that were used to send them.
      private val inFlightMessages: Buffer[Holder[E]] =
        Buffer(settings.sessionCount, settings.sessionCount)

      protected val destination: Destination = stage.destination
      protected val jmsSettings: JmsProducerSettings = settings

      // Sets the execution context first, then runs the inherited start-up and kicks off session creation.
      override def preStart(): Unit = {
        ec = executionContext(inheritedAttributes)
        super.preStart()
        initSessionAsync()
      }

      // Adds a producer for the newly opened session, tagged with the current epoch, to the pool.
      // The outcome of the block (including any exception, captured by Try) is handed to sessionOpened.
      override protected def onSessionOpened(jmsSession: JmsProducerSession): Unit =
        sessionOpened(Try {
          jmsProducers.enqueue(JmsMessageProducer(jmsSession, settings, currentJmsProducerEpoch))
          // startup situation: while producer pool was empty, the out port might have pulled. If so, pull from in port.
          // Note that a message might be already in-flight; that's fine since this stage pre-fetches messages from
          // upstream anyway to increase throughput once the stream is started.
          if (isAvailable(out)) pullIfNeeded()
        })

      // Empties the producer pool and bumps the epoch, so producers handed back later by sends that
      // started before the failure are recognised as stale in sendCompletedCB and not re-pooled.
      override protected def connectionFailed(ex: Throwable): Unit = {
        jmsProducers.clear()
        currentJmsProducerEpoch += 1
        super.connectionFailed(ex)
      }

      setHandler(out,
        new OutHandler {
          // downstream demand: emit the head of the in-flight queue if it has completed.
          override def onPull(): Unit = pushNextIfPossible()

          override def onDownstreamFinish(cause: Throwable): Unit = publishAndCompleteStage()
        })

      setHandler(
        in,
        new InHandler {
          // Complete right away only when nothing is in flight; otherwise pushNextIfPossible completes
          // the stage once the in-flight queue has drained and `in` is closed.
          override def onUpstreamFinish(): Unit = if (inFlightMessages.isEmpty) publishAndCompleteStage()

          override def onUpstreamFailure(ex: Throwable): Unit = {
            publishAndFailStage(ex)
          }

          override def onPush(): Unit = {
            val elem: E = grab(in)
            elem match {
              case _: JmsPassThrough[_] =>
                // pass-through elements are not sent: their holder is completed immediately, but they
                // still go through the in-flight queue so that they are emitted in arrival order.
                val holder = new Holder[E](NotYetThere)
                inFlightMessages.enqueue(holder)
                holder(Success(elem))
                pushNextIfPossible()
              case m: JmsEnvelope[_] =>
                // create a holder object to capture the in-flight message, and enqueue it to preserve message order
                val holder = new Holder[E](NotYetThere)
                inFlightMessages.enqueue(holder)
                sendWithRetries(SendAttempt(m.asInstanceOf[E], holder))
              case other =>
                // `{}` is the logging template placeholder; without it the element is not rendered.
                log.warning("unhandled element {}", other)
            }

            // immediately ask for the next element if producers are available.
            pullIfNeeded()
          }
        })

      // Moves the connector to the stopping state, closes the sessions and the connection that was
      // held by the previous state, then completes the stage.
      private def publishAndCompleteStage(): Unit = {
        val previous = updateState(InternalConnectionState.JmsConnectorStopping(Success(Done)))
        closeSessions()
        closeConnectionAsync(JmsConnector.connection(previous))
        completeStage()
      }

      // SendAttempt timer keys are the retries scheduled by nextTryOrFail; any other key is
      // delegated to the inherited timer handling.
      override def onTimer(timerKey: Any): Unit = timerKey match {
        case s: SendAttempt[E @unchecked] => sendWithRetries(s)
        case _                            => super.onTimer(timerKey)
      }

      // Sends the envelope of `send` with a producer taken from the pool. The send runs in a Future and
      // its outcome is reported back to the stage through sendCompletedCB. When the pool is empty the
      // attempt is not made and goes straight to nextTryOrFail with RetrySkippedOnMissingConnection.
      private def sendWithRetries(send: SendAttempt[E]): Unit = {
        import send._
        if (jmsProducers.nonEmpty) {
          val jmsProducer: JmsMessageProducer = jmsProducers.dequeue()
          Future(jmsProducer.send(envelope)).andThen {
            case tried => sendCompletedCB.invoke((send, tried, jmsProducer))
          }
        } else {
          nextTryOrFail(send, RetrySkippedOnMissingConnection)
        }
      }

      /**
       * Schedules another attempt for `send`, or fails its holder with `ex` once retries are exhausted.
       *
       * A negative `maxRetries` never exhausts. Once a computed delay equals `maxBackoff`, the attempt
       * is flagged as `backoffMaxed` and later retries use `maxBackoff` without calling `waitTime`.
       *
       * @param send the attempt that did not succeed
       * @param ex   the failure stored in the holder when no further retry is scheduled
       */
      def nextTryOrFail(send: SendAttempt[E], ex: Throwable): Unit = {
        import send._
        import settings.sendRetrySettings._
        if (maxRetries < 0 || attempt + 1 <= maxRetries) {
          val nextAttempt = attempt + 1
          val delay = if (backoffMaxed) maxBackoff else waitTime(nextAttempt)
          val backoffNowMaxed = backoffMaxed || delay == maxBackoff
          // the updated SendAttempt is the timer key; onTimer picks it up and retries the send.
          scheduleOnce(send.copy(attempt = nextAttempt, backoffMaxed = backoffNowMaxed), delay)
        } else {
          holder(Failure(ex))
          handleFailure(ex, holder)
        }
      }

      // Async callback invoked with the outcome of a send started in sendWithRetries.
      private val sendCompletedCB = getAsyncCallback[(SendAttempt[E], Try[Unit], JmsMessageProducer)] {
        case (send, outcome, jmsProducer) =>
          // same epoch indicates that the producer belongs to the current alive connection.
          if (jmsProducer.epoch == currentJmsProducerEpoch) jmsProducers.enqueue(jmsProducer)

          import send._

          outcome match {
            case Success(_) =>
              holder(Success(send.envelope))
              pushNextIfPossible()
            case Failure(t: jms.JMSException) =>
              // JMS exceptions go through the retry logic.
              nextTryOrFail(send, t)
            case Failure(t) =>
              // any other failure is not retried: fail the holder and consult supervision.
              holder(Failure(t))
              handleFailure(t, holder)
          }
      }

      override def postStop(): Unit = finishStop()

      // Pulls from upstream when a producer and an in-flight slot are both free and no pull is pending.
      private def pullIfNeeded(): Unit =
        if (jmsProducers.nonEmpty // only pull if a producer is available in the pool.
          && !inFlightMessages.isFull // and a place is available in the in-flight queue.
          && !hasBeenPulled(in))
          tryPull(in)

      // Emits the oldest in-flight message if it has completed and downstream has demand; completes
      // the stage when the queue is empty and upstream is closed.
      private def pushNextIfPossible(): Unit =
        if (inFlightMessages.isEmpty) {
          // no messages in flight, are we about to complete?
          if (isClosed(in)) publishAndCompleteStage() else pullIfNeeded()
        } else if (inFlightMessages.peek().elem eq NotYetThere) {
          // next message to be produced is still not there, we need to wait.
          pullIfNeeded()
        } else if (isAvailable(out)) {
          val holder = inFlightMessages.dequeue()
          holder.elem match {
            case Success(elem) =>
              push(out, elem)
              pullIfNeeded() // Ask for the next element.

            case Failure(ex) => handleFailure(ex, holder)
          }
        }

      // Applies the supervision directive for `ex`: Stop fails the stage, any other directive carries
      // on with the next in-flight message.
      private def handleFailure(ex: Throwable, holder: Holder[E]): Unit =
        holder.supervisionDirectiveFor(decider, ex) match {
          case Supervision.Stop => failStage(ex) // fail only if supervision asks for it.
          case _                => pushNextIfPossible()
        }
    }