in mqtt-streaming/src/main/scala/org/apache/pekko/stream/connectors/mqtt/streaming/scaladsl/MqttSession.scala [526:653]
/**
 * Builds the flow that turns server-side MQTT commands for one connection into encoded bytes.
 *
 * Each materialization creates its own shared kill switch (named after `serverSessionId`) and
 * watches `serverConnector`. When the flow terminates for any reason other than the watched
 * connector terminating, `ServerConnector.ConnectionLost(connectionId)` is sent to the connector.
 * If the connector does terminate, the resulting `WatchedActorTerminatedException` is recovered
 * so the flow completes instead of failing.
 *
 * Handled commands:
 *  - `ConnAck`: handed to the server connector, which replies with a source of forwarding
 *    instructions (ConnAck, PingResp, Publish, PubRel); each instruction is encoded and emitted.
 *    When that source completes the shared kill switch is shut down; when it fails the kill
 *    switch is aborted with the same failure.
 *  - `SubAck`, `UnsubAck`, `PubAck`, `PubRec`, `PubComp`: routed by packet id to the matching
 *    packet router (see `routedAckSource`).
 *  - any other command fails the stream with an `IllegalStateException`.
 *
 * Up to `settings.commandParallelism` commands are processed concurrently via `flatMapMerge`,
 * whose merged output is not ordered across commands. Empty byte strings (used as "nothing to
 * send" markers) are filtered out, and every emitted frame is decoded for logging; stream
 * failures are logged at debug level.
 *
 * @param connectionId identifies the client connection these commands belong to
 */
override def commandFlow[A](connectionId: ByteString): CommandFlow[A] = {

  /*
   * Common tail for acknowledgements routed by packet id (SubAck, UnsubAck, PubAck, PubRec,
   * PubComp). The caller has already sent the routing request carrying `reply`. This completes
   * the command's optional `completed` promise with the routing outcome and, once the router
   * replies successfully, emits `encoded`. A `RemotePacketRouter.CannotRoute` failure becomes a
   * single empty ByteString, which the outer flow filters out.
   */
  def routedAckSource[R](completed: Option[Promise[Done]], reply: Promise[R])(
      encoded: => ByteString): Source[ByteString, NotUsed] = {
    reply.future.onComplete { result =>
      completed.foreach(_.complete(result.map(_ => Done)))
    }
    Source
      .future(reply.future.map(_ => encoded))
      .recover {
        case _: RemotePacketRouter.CannotRoute => ByteString.empty
      }
  }

  Flow
    .lazyFutureFlow { () =>
      val killSwitch = KillSwitches.shared("command-kill-switch-" + serverSessionId)
      Future.successful(
        Flow[Command[A]]
          .watch(serverConnector.toClassic)
          .watchTermination() {
            case (_, terminated) =>
              // Report a lost connection unless the stream ended because the watched
              // connector itself terminated.
              terminated.onComplete {
                case Failure(_: WatchedActorTerminatedException) =>
                case _ =>
                  serverConnector ! ServerConnector.ConnectionLost(connectionId)
              }
              NotUsed
          }
          .via(killSwitch.flow)
          .flatMapMerge(
            settings.commandParallelism,
            {
              case Command(cp: ConnAck, _, _) =>
                val reply = Promise[Source[ClientConnection.ForwardConnAckCommand, NotUsed]]()
                serverConnector ! ServerConnector.ConnAckReceivedLocally(connectionId, cp, reply)
                Source.futureSource(
                  reply.future.map(
                    _.map {
                      case ClientConnection.ForwardConnAck =>
                        cp.encode(ByteString.newBuilder).result()
                      case ClientConnection.ForwardPingResp =>
                        pingRespBytes
                      case ClientConnection.ForwardPublish(publish, packetId) =>
                        publish.encode(ByteString.newBuilder, packetId).result()
                      case ClientConnection.ForwardPubRel(packetId) =>
                        PubRel(packetId).encode(ByteString.newBuilder).result()
                    }.mapError {
                      case ServerConnector.PingFailed => ActorMqttServerSession.PingFailed
                    }.watchTermination() { (_, done) =>
                      // The forwarding source governs this connection: when it ends, stop the
                      // command stream through the shared kill switch.
                      done.onComplete {
                        case Success(_) => killSwitch.shutdown()
                        case Failure(t) => killSwitch.abort(t)
                      }
                    }))

              case Command(cp: SubAck, completed, _) =>
                val reply = Promise[Publisher.ForwardSubAck.type]()
                publisherPacketRouter ! RemotePacketRouter.RouteViaConnection(
                  connectionId,
                  cp.packetId,
                  Publisher.SubAckReceivedLocally(reply),
                  reply)
                routedAckSource(completed, reply)(cp.encode(ByteString.newBuilder).result())

              case Command(cp: UnsubAck, completed, _) =>
                val reply = Promise[Unpublisher.ForwardUnsubAck.type]()
                unpublisherPacketRouter ! RemotePacketRouter.RouteViaConnection(
                  connectionId,
                  cp.packetId,
                  Unpublisher.UnsubAckReceivedLocally(reply),
                  reply)
                routedAckSource(completed, reply)(cp.encode(ByteString.newBuilder).result())

              case Command(cp: PubAck, completed, _) =>
                val reply = Promise[Consumer.ForwardPubAck.type]()
                consumerPacketRouter ! RemotePacketRouter.RouteViaConnection(
                  connectionId,
                  cp.packetId,
                  Consumer.PubAckReceivedLocally(reply),
                  reply)
                routedAckSource(completed, reply)(cp.encode(ByteString.newBuilder).result())

              case Command(cp: PubRec, completed, _) =>
                val reply = Promise[Consumer.ForwardPubRec.type]()
                consumerPacketRouter ! RemotePacketRouter.RouteViaConnection(
                  connectionId,
                  cp.packetId,
                  Consumer.PubRecReceivedLocally(reply),
                  reply)
                routedAckSource(completed, reply)(cp.encode(ByteString.newBuilder).result())

              case Command(cp: PubComp, completed, _) =>
                val reply = Promise[Consumer.ForwardPubComp.type]()
                consumerPacketRouter ! RemotePacketRouter.RouteViaConnection(
                  connectionId,
                  cp.packetId,
                  Consumer.PubCompReceivedLocally(reply),
                  reply)
                routedAckSource(completed, reply)(cp.encode(ByteString.newBuilder).result())

              case c: Command[A] =>
                throw new IllegalStateException(s"$c is not a server command")
            })
          // The watched connector terminating ends the flow normally; the empty marker is
          // dropped by the filter below.
          .recover {
            case _: WatchedActorTerminatedException => ByteString.empty
          }
          .filter(_.nonEmpty)
          .log("server-commandFlow", _.iterator.decodeControlPacket(settings.maxPacketSize)) // we decode here so we can see the generated packet id
          .withAttributes(ActorAttributes.logLevels(onFailure = Logging.DebugLevel)))
    }
}