in jakartams/src/main/scala/org/apache/pekko/stream/connectors/jakartams/impl/JmsProducerStage.scala [78:261]
private def producerLogic(inheritedAttributes: Attributes) =
  new TimerGraphStageLogic(shape) with JmsProducerConnector with GraphStageCompanion with StageLogging {

    final override def graphStageMaterializer: Materializer = materializer

    final override def graphStageDestination: Destination = destination

    final override def scheduleOnceOnGraphStage(timerKey: Any, delay: FiniteDuration): Unit =
      scheduleOnce(timerKey, delay)

    final override def isTimerActiveOnGraphStage(timerKey: Any): Boolean = isTimerActive(timerKey)

    final override def cancelTimerOnGraphStage(timerKey: Any): Unit = cancelTimer(timerKey)
    /*
     * NOTE: the following code is heavily inspired by org.apache.pekko.stream.impl.fusing.MapAsync
     *
     * To get a condensed view of what the buffers and handler behavior are about, have a look there too.
     */
    private lazy val decider = inheritedAttributes.mandatoryAttribute[SupervisionStrategy].decider

    // the current connection epoch. Reconnects increment this epoch by 1.
    private var currentJmsProducerEpoch = 0

    // available producers for sending messages. Initially full, but may contain fewer elements if
    // messages are currently in-flight.
    private val jmsProducers: Buffer[JmsMessageProducer] = Buffer(settings.sessionCount, settings.sessionCount)

    // in-flight messages, kept in upstream order; each holder is completed once its send finishes.
    private val inFlightMessages: Buffer[Holder[E]] =
      Buffer(settings.sessionCount, settings.sessionCount)
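    // The two buffers implement the MapAsync-style ordering trick: a Holder is enqueued in
    // upstream order *before* its asynchronous send starts and is completed whenever the send
    // finishes, so elements are emitted in order even when sends complete out of order.
    // A minimal sketch of the pattern (illustrative only; `sendAsync` is a made-up stand-in):
    //
    //   val inFlight = scala.collection.mutable.Queue.empty[Holder[String]]
    //   def onPush(msg: String): Unit = {
    //     val holder = new Holder[String](NotYetThere)
    //     inFlight.enqueue(holder)             // reserve the slot in arrival order
    //     sendAsync(msg).onComplete(holder(_)) // complete the slot later, possibly out of order
    //   }
    //   def emitNext(): Unit =
    //     while (inFlight.nonEmpty && (inFlight.head.elem ne NotYetThere))
    //       push(out, inFlight.dequeue().elem.get) // emit strictly in reserved order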
    protected val destination: Destination = stage.destination
    protected val jmsSettings: JmsProducerSettings = settings

    override def preStart(): Unit = {
      ec = executionContext(inheritedAttributes)
      super.preStart()
      initSessionAsync()
    }
    override protected def onSessionOpened(jmsSession: JmsProducerSession): Unit =
      sessionOpened(Try {
        jmsProducers.enqueue(JmsMessageProducer(jmsSession, settings, currentJmsProducerEpoch))
        // startup situation: while the producer pool was empty, the out port might have pulled. If so, pull from
        // the in port. Note that a message might already be in-flight; that's fine since this stage pre-fetches
        // messages from upstream anyway to increase throughput once the stream is started.
        if (isAvailable(out)) pullIfNeeded()
      })

    override protected def connectionFailed(ex: Throwable): Unit = {
      // discard producers of the failed connection and bump the epoch so that late completions
      // from the old connection do not put stale producers back into the pool.
      jmsProducers.clear()
      currentJmsProducerEpoch += 1
      super.connectionFailed(ex)
    }
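    // Illustrative timeline for the epoch mechanism (not part of the stage):
    //   epoch 0: producer A is dequeued, a send starts
    //   the connection fails -> jmsProducers.clear(); currentJmsProducerEpoch becomes 1
    //   reconnect            -> fresh producers are enqueued carrying epoch 1
    //   A's send completes   -> sendCompletedCB (below) sees A.epoch == 0 != 1 and drops A
    //                           instead of returning the stale producer to the pool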
    setHandler(out,
      new OutHandler {
        override def onPull(): Unit = pushNextIfPossible()

        override def onDownstreamFinish(cause: Throwable): Unit = publishAndCompleteStage()
      })
    setHandler(
      in,
      new InHandler {
        override def onUpstreamFinish(): Unit = if (inFlightMessages.isEmpty) publishAndCompleteStage()

        override def onUpstreamFailure(ex: Throwable): Unit =
          publishAndFailStage(ex)

        override def onPush(): Unit = {
          val elem: E = grab(in)
          elem match {
            case _: JmsPassThrough[_] =>
              // pass-through elements are never sent to the broker; complete the holder right away
              // so the element is emitted downstream in upstream order.
              val holder = new Holder[E](NotYetThere)
              inFlightMessages.enqueue(holder)
              holder(Success(elem))
              pushNextIfPossible()
            case m: JmsEnvelope[_] =>
              // create a holder object to capture the in-flight message, and enqueue it to preserve message order
              val holder = new Holder[E](NotYetThere)
              inFlightMessages.enqueue(holder)
              sendWithRetries(SendAttempt(m.asInstanceOf[E], holder))
            case other =>
              log.warning("unhandled element {}", other)
          }
          // immediately ask for the next element if producers are available.
          pullIfNeeded()
        }
      })
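    // The two element kinds handled above, from the user's point of view (a hedged sketch;
    // constructor usage assumed to follow the connector's public envelope API):
    //
    //   val msg  = JmsTextMessage("payload") // sent to the broker, then emitted downstream
    //   val pass = JmsPassThrough(context)   // never sent, but emitted downstream in order
    //
    // Both travel through inFlightMessages, so downstream sees them in upstream order.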
    private def publishAndCompleteStage(): Unit = {
      val previous = updateState(InternalConnectionState.JmsConnectorStopping(Success(Done)))
      closeSessions()
      closeConnectionAsync(JmsConnector.connection(previous))
      completeStage()
    }

    override def onTimer(timerKey: Any): Unit = timerKey match {
      case s: SendAttempt[E @unchecked] => sendWithRetries(s)
      case _                            => super.onTimer(timerKey)
    }
    private def sendWithRetries(send: SendAttempt[E]): Unit = {
      import send._
      if (jmsProducers.nonEmpty) {
        val jmsProducer: JmsMessageProducer = jmsProducers.dequeue()
        // run the (blocking) send on the stage's execution context and marshal the outcome
        // back onto the stage via the async callback.
        Future(jmsProducer.send(envelope)).andThen {
          case tried => sendCompletedCB.invoke((send, tried, jmsProducer))
        }
      } else {
        // no producer available (e.g. during a reconnect): treat as a retryable condition.
        nextTryOrFail(send, RetrySkippedOnMissingConnection)
      }
    }
    def nextTryOrFail(send: SendAttempt[E], ex: Throwable): Unit = {
      import send._
      import settings.sendRetrySettings._
      // maxRetries < 0 means retry indefinitely; otherwise schedule the next attempt with backoff.
      if (maxRetries < 0 || attempt + 1 <= maxRetries) {
        val nextAttempt = attempt + 1
        val delay = if (backoffMaxed) maxBackoff else waitTime(nextAttempt)
        val backoffNowMaxed = backoffMaxed || delay == maxBackoff
        scheduleOnce(send.copy(attempt = nextAttempt, backoffMaxed = backoffNowMaxed), delay)
      } else {
        // retries exhausted: record the failure and let the supervision decider decide.
        holder(Failure(ex))
        handleFailure(ex, holder)
      }
    }
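    // A worked example of the retry schedule, assuming waitTime(n) grows with the attempt number
    // and is capped at maxBackoff (waitTime is defined in SendRetrySettings, not shown here; the
    // durations below are made up for illustration):
    //
    //   maxBackoff = 1.second; waitTime(1) = 100.millis, waitTime(2) = 400.millis, waitTime(3) = 1.second
    //   attempt 0 fails -> scheduleOnce(delay = 100.millis), backoffMaxed = false
    //   attempt 1 fails -> scheduleOnce(delay = 400.millis), backoffMaxed = false
    //   attempt 2 fails -> scheduleOnce(delay = 1.second),   backoffMaxed = true
    //   later attempts reuse maxBackoff directly without recomputing waitTime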
    private val sendCompletedCB = getAsyncCallback[(SendAttempt[E], Try[Unit], JmsMessageProducer)] {
      case (send, outcome, jmsProducer) =>
        // same epoch indicates that the producer belongs to the current alive connection.
        if (jmsProducer.epoch == currentJmsProducerEpoch) jmsProducers.enqueue(jmsProducer)

        import send._
        outcome match {
          case Success(_) =>
            holder(Success(send.envelope))
            pushNextIfPossible()
          case Failure(t: jms.JMSException) =>
            // JMS-level failures are considered transient and retried.
            nextTryOrFail(send, t)
          case Failure(t) =>
            // anything else is handed to the supervision decider.
            holder(Failure(t))
            handleFailure(t, holder)
        }
    }
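    // getAsyncCallback is the GraphStage facility for safely re-entering stage logic from another
    // thread. A minimal sketch of the pattern used above (illustrative; the callback must only be
    // invoked after the stage has started):
    //
    //   val cb = getAsyncCallback[Int] { n => /* runs inside the stage, may touch its state */ }
    //   Future(blockingSend()).foreach(_ => cb.invoke(42)) // safe to call from any thread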
    override def postStop(): Unit = finishStop()

    private def pullIfNeeded(): Unit =
      if (jmsProducers.nonEmpty // only pull if a producer is available in the pool.
        && !inFlightMessages.isFull // and a place is available in the in-flight queue.
        && !hasBeenPulled(in))
        tryPull(in)
    private def pushNextIfPossible(): Unit =
      if (inFlightMessages.isEmpty) {
        // no messages in flight, are we about to complete?
        if (isClosed(in)) publishAndCompleteStage() else pullIfNeeded()
      } else if (inFlightMessages.peek().elem eq NotYetThere) {
        // next message to be produced is still not there, we need to wait.
        pullIfNeeded()
      } else if (isAvailable(out)) {
        val holder = inFlightMessages.dequeue()
        holder.elem match {
          case Success(elem) =>
            push(out, elem)
            pullIfNeeded() // ask for the next element.
          case Failure(ex) => handleFailure(ex, holder)
        }
      }
    private def handleFailure(ex: Throwable, holder: Holder[E]): Unit =
      holder.supervisionDirectiveFor(decider, ex) match {
        case Supervision.Stop => failStage(ex) // fail only if supervision asks for it.
        case _                => pushNextIfPossible() // otherwise skip the failed element and continue.
      }
  }
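// For context: this logic backs the connector's producer operators. A hedged sketch of how it is
// typically reached through the public API (operator and settings names assumed to mirror the
// Pekko Connectors JMS module):
//
//   import org.apache.pekko.stream.connectors.jakartams._
//   import org.apache.pekko.stream.connectors.jakartams.scaladsl.JmsProducer
//
//   val settings = JmsProducerSettings(system, connectionFactory).withQueue("orders")
//   val flow     = JmsProducer.flexiFlow[String](settings) // JmsEnvelope[String] in, same out, in order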