in mqtt-streaming/src/main/scala/org/apache/pekko/stream/connectors/mqtt/streaming/scaladsl/MqttSession.scala [190:300]
/**
 * Builds the flow that turns locally issued client [[Command]]s into the encoded
 * bytes of the corresponding MQTT control packets.
 *
 * The flow is created lazily (`Flow.lazyFutureFlow`). Every command is first
 * announced to an actor (`clientConnector` or `consumerPacketRouter`) together with
 * a reply promise; the packet bytes are only emitted once that promise is fulfilled.
 * Commands are processed through `flatMapMerge` with a breadth of
 * `settings.commandParallelism`.
 *
 * @param connectionId identifier that is passed along with the messages sent to `clientConnector`
 * @tparam A type parameter of the incoming [[Command]]s (the `carry` value of a command
 *           is forwarded to the client connector)
 * @return the command flow for this connection
 */
private[streaming] override def commandFlow[A](connectionId: ByteString): CommandFlow[A] =
Flow
.lazyFutureFlow { () =>
// Shared kill switch: wired into the command stream below (`killSwitch.flow`) and
// triggered when the source created for a Connect command terminates.
val killSwitch = KillSwitches.shared("command-kill-switch-" + clientSessionId)
Future.successful(
Flow[Command[A]]
// Watch the client connector actor; its termination surfaces in the stream as a
// WatchedActorTerminatedException, which is handled further down.
.watch(clientConnector.toClassic)
.watchTermination() {
case (_, terminated) =>
terminated.onComplete {
// The connector itself is gone - there is nobody left to notify.
case Failure(_: WatchedActorTerminatedException) =>
// Any other termination (completion or failure) is reported as a lost connection.
case _ =>
clientConnector ! ClientConnector.ConnectionLost(connectionId)
}
NotUsed
}
.via(killSwitch.flow)
.flatMapMerge(
settings.commandParallelism,
{
// CONNECT: the connector replies with a source of forwarding instructions; each
// instruction is encoded into the bytes of the packet it stands for.
case Command(cp: Connect, _, carry) =>
val reply = Promise[Source[ClientConnector.ForwardConnectCommand, NotUsed]]()
clientConnector ! ClientConnector.ConnectReceivedLocally(connectionId, cp, carry, reply)
Source.futureSource(
reply.future.map(_.map {
case ClientConnector.ForwardConnect => cp.encode(ByteString.newBuilder).result()
case ClientConnector.ForwardPingReq => pingReqBytes
case ClientConnector.ForwardPublish(publish, packetId) =>
publish.encode(ByteString.newBuilder, packetId).result()
case ClientConnector.ForwardPubRel(packetId) =>
PubRel(packetId).encode(ByteString.newBuilder).result()
}.mapError {
// Translate internal failure markers into the failures exposed by ActorMqttClientSession.
case ClientConnector.ConnectFailed => ActorMqttClientSession.ConnectFailed
case Subscriber.SubscribeFailed => ActorMqttClientSession.SubscribeFailed
case ClientConnector.PingFailed => ActorMqttClientSession.PingFailed
}
// When the forwarded source ends, stop the command stream as well: a clean
// shutdown on success, an abort carrying the cause on failure.
.watchTermination() { (_, done) =>
done.onComplete {
case Success(_) => killSwitch.shutdown()
case Failure(t) => killSwitch.abort(t)
}
}))
// PUBACK: routed by packet id through the consumer packet router.
case Command(cp: PubAck, completed, _) =>
val reply = Promise[Consumer.ForwardPubAck.type]()
consumerPacketRouter ! RemotePacketRouter.Route(None,
cp.packetId,
Consumer.PubAckReceivedLocally(reply),
reply)
// Propagate the outcome (Done or the failure) to the `completed` promise, if any.
reply.future.onComplete { result =>
completed
.foreach(_.complete(result.map(_ => Done)))
}
// If the packet cannot be routed, emit an empty ByteString instead of failing;
// empty elements are dropped by the `filter(_.nonEmpty)` below.
Source.future(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())).recover {
case _: RemotePacketRouter.CannotRoute => ByteString.empty
}
// PUBREC: same handling as PUBACK, with the PubRec-specific router message.
case Command(cp: PubRec, completed, _) =>
val reply = Promise[Consumer.ForwardPubRec.type]()
consumerPacketRouter ! RemotePacketRouter.Route(None,
cp.packetId,
Consumer.PubRecReceivedLocally(reply),
reply)
reply.future.onComplete { result =>
completed
.foreach(_.complete(result.map(_ => Done)))
}
Source.future(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())).recover {
case _: RemotePacketRouter.CannotRoute => ByteString.empty
}
// PUBCOMP: same handling as PUBACK, with the PubComp-specific router message.
case Command(cp: PubComp, completed, _) =>
val reply = Promise[Consumer.ForwardPubComp.type]()
consumerPacketRouter ! RemotePacketRouter.Route(None,
cp.packetId,
Consumer.PubCompReceivedLocally(reply),
reply)
reply.future.onComplete { result =>
completed
.foreach(_.complete(result.map(_ => Done)))
}
Source.future(reply.future.map(_ => cp.encode(ByteString.newBuilder).result())).recover {
case _: RemotePacketRouter.CannotRoute => ByteString.empty
}
// SUBSCRIBE: the packet is encoded with the packet id carried by the connector's reply.
case Command(cp: Subscribe, _, carry) =>
val reply = Promise[Subscriber.ForwardSubscribe]()
clientConnector ! ClientConnector.SubscribeReceivedLocally(connectionId, cp, carry, reply)
Source.future(
reply.future.map(command => cp.encode(ByteString.newBuilder, command.packetId).result()))
// UNSUBSCRIBE: likewise encoded with the packet id carried by the reply.
case Command(cp: Unsubscribe, _, carry) =>
val reply = Promise[Unsubscriber.ForwardUnsubscribe]()
clientConnector ! ClientConnector.UnsubscribeReceivedLocally(connectionId, cp, carry, reply)
Source.future(
reply.future.map(command => cp.encode(ByteString.newBuilder, command.packetId).result()))
// DISCONNECT: encoded once the connector has acknowledged the request.
case Command(cp: Disconnect.type, _, _) =>
val reply = Promise[ClientConnector.ForwardDisconnect.type]()
clientConnector ! ClientConnector.DisconnectReceivedLocally(connectionId, reply)
Source.future(reply.future.map(_ => cp.encode(ByteString.newBuilder).result()))
// Any other command is not handled by the client side of the session.
case c: Command[A] => throw new IllegalStateException(s"$c is not a client command")
})
// Termination of the watched connector is turned into an empty element, which the
// filter right below removes.
.recover {
case _: WatchedActorTerminatedException => ByteString.empty
}
.filter(_.nonEmpty)
.log("client-commandFlow", _.iterator.decodeControlPacket(settings.maxPacketSize)) // we decode here so we can see the generated packet id
// Stream failures are logged at debug level.
.withAttributes(ActorAttributes.logLevels(onFailure = Logging.DebugLevel)))
}