override def shape: FlowShape[WriteMessage, CommittableReadResult] = FlowShape.of()

in amqp/src/main/scala/org/apache/pekko/stream/connectors/amqp/impl/AmqpRpcFlowStage.scala [48:246]


  /** Flow shape of this stage: requests enter through `in`, committable replies leave through `out`. */
  override def shape: FlowShape[WriteMessage, CommittableReadResult] = FlowShape(in, out)

  /**
   * Attributes applied to this stage by default: the inherited ones, followed by the
   * stage name "AmqpRpcFlow" and the `ActorAttributes.IODispatcher` attribute.
   */
  override protected def initialAttributes: Attributes = {
    val inherited = super.initialAttributes
    val named = inherited.and(Attributes.name("AmqpRpcFlow"))
    named.and(ActorAttributes.IODispatcher)
  }

  /**
   * Creates the stage logic of the RPC flow together with its materialized value.
   *
   * Every element pulled from `in` is published with its `replyTo` property set to a
   * server-named, exclusive reply queue that is declared once the connection is up.
   * Deliveries from that queue are emitted on `out` as [[CommittableReadResult]]s, which
   * must be acked or nacked by downstream.
   *
   * The materialized future is completed with the name of the reply queue after the
   * consumer has been registered. It is failed with the cause passed to `onFailure`, or
   * with a generic "stage stopped unexpectedly" error if the stage stops before the
   * future was completed.
   *
   * @param inheritedAttributes attributes handed down by the materializer (not read here)
   * @return the stage logic and the future holding the reply-queue name
   */
  override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[String]) = {
    val streamCompletion = Promise[String]()
    (new GraphStageLogic(shape) with AmqpConnectorLogic {

        override val settings: AmqpWriteSettings = stage.writeSettings
        private val exchange = settings.exchange.getOrElse("")
        private val routingKey = settings.routingKey.getOrElse("")
        // Replies that arrived while `out` had no demand.
        private val queue = mutable.Queue[CommittableReadResult]()
        private var queueName: String = _
        // Replies pushed downstream for which no ack/nack has been sent to the broker yet.
        private var unackedMessages = 0
        // Replies still expected for requests that have already been published.
        private var outstandingMessages = 0

        override def whenConnected(): Unit = {

          pull(in)

          // Limit the number of unacknowledged deliveries the broker sends to this consumer.
          channel.basicQos(bufferSize)
          val consumerCallback = getAsyncCallback(handleDelivery)

          // Shared bookkeeping for ack and nack: send the acknowledgement to the broker,
          // update the counter, complete the stage if nothing is pending any more and
          // finally settle the caller's promise. Only invoked from the async callbacks
          // below, so the stage state is never touched from a foreign thread.
          // Any throwable raised on the way fails the promise instead of the stage.
          // NOTE(review): with `multiple = true` the broker settles every delivery up to
          // the tag, yet the counter drops by one only - confirm this is intended.
          def settleDelivery(promise: Promise[Done])(sendToBroker: => Unit): Unit =
            try {
              sendToBroker
              unackedMessages -= 1
              if (unackedMessages == 0 && (isClosed(out) || (isClosed(
                  in) && queue.isEmpty && outstandingMessages == 0)))
                completeStage()
              promise.complete(Success(Done))
            } catch {
              case e: Throwable => promise.failure(e)
            }

          val commitCallback = getAsyncCallback[AckArguments] {
            case AckArguments(deliveryTag, multiple, promise) =>
              settleDelivery(promise)(channel.basicAck(deliveryTag, multiple))
          }
          val nackCallback = getAsyncCallback[NackArguments] {
            case NackArguments(deliveryTag, multiple, requeue, promise) =>
              settleDelivery(promise)(channel.basicNack(deliveryTag, multiple, requeue))
          }

          val amqpSourceConsumer = new DefaultConsumer(channel) {
            // Invoked by the AMQP client; the delivery is handed over to the stage through
            // the async callback rather than being processed on the calling thread.
            override def handleDelivery(consumerTag: String,
                envelope: Envelope,
                properties: BasicProperties,
                body: Array[Byte]): Unit =
              consumerCallback.invoke(
                new CommittableReadResult {
                  override val message = {
                    // With `reuseByteArray` the delivered array is wrapped without copying.
                    val byteString = if (settings.reuseByteArray)
                      ByteString.fromArrayUnsafe(body)
                    else
                      ByteString(body)
                    ReadResult(byteString, envelope, properties)
                  }

                  override def ack(multiple: Boolean): Future[Done] = {
                    val promise = Promise[Done]()
                    commitCallback.invoke(AckArguments(message.envelope.getDeliveryTag, multiple, promise))
                    promise.future
                  }

                  override def nack(multiple: Boolean, requeue: Boolean): Future[Done] = {
                    val promise = Promise[Done]()
                    nackCallback.invoke(NackArguments(message.envelope.getDeliveryTag, multiple, requeue, promise))
                    promise.future
                  }
                })

            override def handleCancel(consumerTag: String): Unit =
              // non consumer initiated cancel, for example happens when the queue has been deleted.
              shutdownCallback.invoke(
                new RuntimeException(s"Consumer $queueName with consumerTag $consumerTag shut down unexpectedly"))

            override def handleShutdownSignal(consumerTag: String, sig: ShutdownSignalException): Unit =
              // "Called when either the channel or the underlying connection has been shut down."
              shutdownCallback.invoke(
                new RuntimeException(s"Consumer $queueName with consumerTag $consumerTag shut down unexpectedly", sig))
          }

          // Create an exclusive queue with a randomly generated name for use as the replyTo portion of RPC
          queueName = channel
            .queueDeclare(
              "", // empty name: the broker generates one
              false, // durable
              true, // exclusive
              true, // autoDelete
              Collections.emptyMap())
            .getQueue

          // This overload consumes with explicit acknowledgements, hence the ack/nack callbacks above.
          channel.basicConsume(
            queueName,
            amqpSourceConsumer)
          streamCompletion.success(queueName)
        }

        // Emits the reply right away when downstream has demand, buffers it otherwise and
        // fails the stage once the buffer would grow beyond `bufferSize`.
        def handleDelivery(message: CommittableReadResult): Unit =
          if (isAvailable(out)) {
            pushMessage(message)
          } else if (queue.size + 1 > bufferSize) {
            onFailure(new RuntimeException(s"Reached maximum buffer size $bufferSize"))
          } else {
            queue.enqueue(message)
          }

        setHandler(
          out,
          new OutHandler {
            override def onPull(): Unit =
              if (queue.nonEmpty) {
                pushMessage(queue.dequeue())
              }

            // Keep the stage alive while pushed replies still await their ack/nack; the
            // ack/nack callbacks complete the stage once the counter reaches zero.
            override def onDownstreamFinish(cause: Throwable): Unit = {
              setKeepGoing(true)
              if (unackedMessages == 0) super.onDownstreamFinish(cause)
            }
          })

        // Pushes one reply downstream: it now awaits an ack/nack and is no longer outstanding.
        def pushMessage(message: CommittableReadResult): Unit = {
          push(out, message)
          unackedMessages += 1
          outstandingMessages -= 1
        }

        setHandler(
          in,
          new InHandler {
            // We don't want to finish since we're still waiting
            // on incoming messages from rabbit. However, if we
            // haven't processed a message yet, we do want to complete
            // so that we don't hang.
            override def onUpstreamFinish(): Unit = {
              setKeepGoing(true)
              if (queue.isEmpty && outstandingMessages == 0 && unackedMessages == 0) super.onUpstreamFinish()
            }

            // Same rule as for a regular finish: the failure is only propagated when no
            // reply is buffered, outstanding or unacknowledged; otherwise the stage keeps
            // running and `ex` is not forwarded.
            override def onUpstreamFailure(ex: Throwable): Unit = {
              setKeepGoing(true)
              if (queue.isEmpty && outstandingMessages == 0 && unackedMessages == 0)
                super.onUpstreamFailure(ex)
            }

            override def onPush(): Unit = {
              val elem = grab(in)
              val bytes = if (settings.reuseByteArray)
                elem.bytes.toArrayUnsafe()
              else
                elem.bytes.toArray
              // Replies are routed back to this stage through the private reply queue.
              val props = elem.properties.getOrElse(new BasicProperties()).builder.replyTo(queueName).build()
              channel.basicPublish(
                exchange,
                elem.routingKey.getOrElse(routingKey),
                elem.mandatory,
                elem.immediate,
                props,
                bytes)

              // The "expectedReplies" header overrides the stage-wide default for this
              // message. Header values are plain objects, so every boxed number is accepted
              // (not only java.lang.Integer); values with a fraction or beyond the Int range
              // are narrowed by `intValue`. Anything that is not a number fails the stage.
              val expectedResponses: Int =
                Option(props.getHeaders).flatMap(headers => Option(headers.get("expectedReplies"))) match {
                  case Some(count: java.lang.Number) => count.intValue
                  case Some(other) =>
                    throw new ClassCastException(
                      s"Header 'expectedReplies' must be a number but was of type ${other.getClass.getName}")
                  case None => responsesPerMessage
                }

              outstandingMessages += expectedResponses
              pull(in)
            }
          })

        // Makes sure the materialized future never stays pending after the stage is gone;
        // `tryFailure` is a no-op when the future has already been completed.
        override def postStop(): Unit = {
          streamCompletion.tryFailure(new RuntimeException("stage stopped unexpectedly"))
          super.postStop()
        }

        override def onFailure(ex: Throwable): Unit = {
          streamCompletion.tryFailure(ex)
          super.onFailure(ex)
        }
      }, streamCompletion.future)
  }