in http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/WebSocketClientBlueprint.scala [50:180]
/**
 * Assembles the complete client-side WebSocket layer for the given request.
 *
 * The layer is built in the network-facing orientation (byte logging, TLS adaptation, upgrade
 * handshake, framing, message stack) and flipped at the very end. The materialized value is the one
 * produced by [[handshake]], i.e. the future upgrade response.
 *
 * @param request  the WebSocket request that drives the upgrade handshake
 * @param settings client connection settings (byte logging flag, WebSocket settings, ...)
 * @param log      logging adapter passed on to the handshake and the WebSocket stack
 */
def apply(
    request: WebSocketRequest,
    settings: ClientConnectionSettings,
    log: LoggingAdapter): Http.WebSocketClientLayer = {
  // Optional logging of the unencrypted bytes, stacked on the plain TLS adapter.
  val loggedTransport =
    LogByteStringTools
      .logTLSBidiBySetting("client-plain-text", settings.logUnencryptedNetworkBytes)
      .reversed
      .atop(simpleTls)
  // Keep only the handshake's materialized value (the upgrade response future).
  val withUpgrade = loggedTransport.atopMat(handshake(request, settings, log))(Keep.right)
  val withFraming = withUpgrade.atop(WebSocket.framing)
  val withMessages =
    withFraming.atop(WebSocket.stack(serverSide = false, settings.websocketSettings, log = log))
  withMessages.reversed
}
/**
 * A bidi flow that injects and inspects the WS handshake and then goes out of the way. This BidiFlow
 * can only be materialized once, because the result promise and the one-time valve are created once
 * per call of this method rather than once per materialization.
 *
 * Outgoing direction: the rendered HTTP upgrade request is emitted first; application WebSocket bytes
 * are concatenated behind it and therefore held back until the valve has been opened.
 *
 * Incoming direction: the first HTTP response is parsed and validated against the request. After a
 * valid upgrade, any bytes following the response and all later network bytes are passed through
 * untouched.
 *
 * The materialized future is
 *  - completed with `ValidUpgrade` when the server accepted the upgrade,
 *  - completed with `InvalidUpgradeResponse` when a response was received but failed validation
 *    (the stream is failed as well in that case),
 *  - failed when the connection fails, when the response cannot be parsed, or when the stage stops
 *    before the handshake has been resolved. This guarantees that the future never stays pending.
 *
 * @param request  the WebSocket request (uri, extra headers, requested subprotocols)
 * @param settings client settings providing parser settings and the random factory for the key
 * @param log      logging adapter used by the request renderer and the header parser
 */
def handshake(
    request: WebSocketRequest,
    settings: ClientConnectionSettings,
    log: LoggingAdapter)
    : BidiFlow[ByteString, ByteString, ByteString, ByteString, Future[WebSocketUpgradeResponse]] = {
  import request._
  val result = Promise[WebSocketUpgradeResponse]()
  // Blocks the outgoing WebSocket bytes until the upgrade has succeeded (see `valve.open()` below).
  val valve = StreamUtils.OneTimeValve()
  // `subprotocol` may hold a comma separated list; split it into the individual trimmed names.
  val subprotocols: immutable.Seq[String] = subprotocol.toList.flatMap(_.split(",")).map(_.trim)
  val (initialRequest, key) =
    Handshake.Client.buildRequest(uri, extraHeaders, subprotocols, settings.websocketRandomFactory())
  val hostHeader = Host(uri.authority.normalizedFor(uri.scheme))
  val renderedInitialRequest =
    HttpRequestRendererFactory.renderStrict(RequestRenderingContext(initialRequest, hostHeader), settings, log)

  class UpgradeStage extends SimpleLinearGraphStage[ByteString] {
    override def createLogic(attributes: Attributes): GraphStageLogic =
      new GraphStageLogic(shape) with InHandler with OutHandler {
        // a special version of the parser which only parses one message and then reports the remaining data
        // if some is available
        val parser: HttpResponseParser =
          new HttpResponseParser(settings.parserSettings, HttpHeaderParser(settings.parserSettings, log)) {
            var first = true
            override def handleInformationalResponses = false
            override protected def parseMessage(input: ByteString, offset: Int): StateResult = {
              if (first) {
                try {
                  // If we're called recursively then that's a next message
                  first = false
                  super.parseMessage(input, offset)
                } catch {
                  // Specifically NotEnoughDataException, but that's not visible here
                  case t: SingletonException => {
                    // If parsing the first message fails, retry and treat it like the first message again.
                    first = true
                    throw t
                  }
                }
              } else {
                // Everything after the first message belongs to the WebSocket protocol: hand it out raw.
                emit(RemainingBytes(input.drop(offset)))
                terminate()
              }
            }
          }
        parser.setContextForNextResponse(HttpResponseParser.ResponseContext(HttpMethods.GET, None))

        override def onPush(): Unit = {
          parser.parseBytes(grab(in)) match {
            case NeedMoreData => pull(in)
            case ResponseStart(status, protocol, attributes, headers, entity, close) =>
              val response = new HttpResponse(status, headers, attributes, HttpEntity.Empty, protocol)
              Handshake.Client.validateResponse(response, subprotocols, key) match {
                case Right(NegotiatedWebSocketSettings(protocol)) =>
                  result.success(ValidUpgrade(response, protocol))
                  // From now on the stage is a plain pass-through for network bytes.
                  setHandler(in,
                    new InHandler {
                      override def onPush(): Unit = push(out, grab(in))
                    })
                  // Let the application's outgoing WebSocket bytes flow.
                  valve.open()
                  val parseResult = parser.onPull()
                  require(parseResult == ParserOutput.MessageEnd,
                    s"parseResult should be MessageEnd but was $parseResult")
                  // Bytes received in the same chunk as the response must not be lost.
                  parser.onPull() match {
                    case NeedMoreData          => pull(in)
                    case RemainingBytes(bytes) => push(out, bytes)
                    case other =>
                      throw new IllegalStateException(s"unexpected element of type ${other.getClass}")
                  }
                case Left(problem) =>
                  result.success(InvalidUpgradeResponse(response, s"WebSocket server at $uri returned $problem"))
                  failStage(new IllegalArgumentException(s"WebSocket upgrade did not finish because of '$problem'"))
              }
            case MessageStartError(statusCode, errorInfo) =>
              val failure =
                new IllegalStateException(s"Message failed with status code $statusCode; Error info: $errorInfo")
              // An exception thrown from onPush fails this stage but does not pass through
              // onUpstreamFailure, so the promise has to be failed explicitly.
              result.tryFailure(failure)
              throw failure
            case other =>
              val failure = new IllegalStateException(s"unexpected element of type ${other.getClass}")
              result.tryFailure(failure)
              throw failure
          }
        }
        override def onPull(): Unit = pull(in)
        setHandlers(in, out, this)
        override def onUpstreamFailure(ex: Throwable): Unit = {
          result.tryFailure(new RuntimeException("Connection failed.", ex))
          super.onUpstreamFailure(ex)
        }
        // Safety net: if the stage stops for any reason (upstream finished, downstream cancelled,
        // stage failure) before the handshake was resolved, fail the upgrade future instead of
        // leaving it pending forever. `tryFailure` is a no-op once the promise is completed.
        override def postStop(): Unit = {
          result.tryFailure(
            new RuntimeException("Connection was closed before the WebSocket upgrade completed."))
          super.postStop()
        }
      }
    override def toString = "UpgradeStage"
  }

  BidiFlow.fromGraph(GraphDSL.create() { implicit b =>
    import GraphDSL.Implicits._
    // Incoming network bytes: inspected by the UpgradeStage until the handshake is done.
    val networkIn = b.add(Flow[ByteString].via(new UpgradeStage))
    // Outgoing WebSocket bytes from the application side.
    val wsIn = b.add(Flow[ByteString])
    // The upgrade request, followed by the valve source which keeps the first Concat input busy
    // until the valve is opened (presumably `valve.source` completes on `open()` - see StreamUtils).
    val handshakeRequestSource = b.add(Source.single(renderedInitialRequest) ++ valve.source)
    val httpRequestBytesAndThenWSBytes = b.add(Concat[ByteString]())
    handshakeRequestSource ~> httpRequestBytesAndThenWSBytes
    wsIn.outlet ~> httpRequestBytesAndThenWSBytes
    BidiShape(
      networkIn.in,
      networkIn.out,
      wsIn.in,
      httpRequestBytesAndThenWSBytes.out)
  }).mapMaterializedValue(_ => result.future)
}