def create()

in http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/FrameHandler.scala [37:200]


  /**
   * Builds a flow that runs every incoming frame event (or frame error) through a fresh
   * `HandlerStage` and emits the `Output` elements that stage produces.
   *
   * @param server forwarded unchanged to the `HandlerStage` constructor
   * @return a flow from `FrameEventOrError` to `Output` that materializes `NotUsed`
   */
  def create(server: Boolean): Flow[FrameEventOrError, Output, NotUsed] = {
    val handlerStage = new HandlerStage(server)
    Flow[FrameEventOrError].via(handlerStage)
  }

  /**
   * Graph stage that consumes frame events and emits `Output` elements: message data parts
   * (terminated by `MessageEnd`), direct answers to ping frames, and close notifications.
   *
   * The behavior is organised as a state machine in which every state is an `InHandler`
   * installed on `in`; state transitions are performed with `setHandler(in, ...)`.
   *
   * @param server when `false`, a frame start whose header carries a mask is answered with a
   *               protocol error (presumably `true` on the server side of a connection; confirm
   *               against callers)
   */
  private class HandlerStage(server: Boolean) extends GraphStage[FlowShape[FrameEventOrError, Output]] {

    // Ports are named after the stage's simple class name.
    val in = Inlet[FrameEventOrError](Logging.simpleName(this) + ".in")
    val out = Outlet[Output](Logging.simpleName(this) + ".out")
    override val shape = FlowShape(in, out)

    override def toString: String = s"HandlerStage(server=$server)"

    override def createLogic(attributes: Attributes): GraphStageLogic =
      new GraphStageLogic(shape) with OutHandler {
        // The logic itself serves as the out-handler; input handling starts in the idle state.
        setHandler(out, this)
        setHandler(in, IdleHandler)

        /** Downstream demand is translated one-to-one into upstream demand. */
        override def onPull(): Unit = pull(in)

        /**
         * Common base of all frame-consuming states. Dispatches each grabbed element to
         * `handleFrameStart` / `handleFrameData` and hosts the helpers shared between states.
         */
        private trait FrameHandler extends InHandler {
          def handleFrameData(data: FrameData): Unit
          def handleFrameStart(start: FrameStart): Unit

          override def onPush(): Unit = grab(in) match {
            case start: FrameStart => handleFrameStart(start)
            case data: FrameData   => handleFrameData(data)
            case FrameError(ex)    => failStage(ex)
          }

          /**
           * Reacts to a completely received control frame.
           *
           * `nextHandler` is installed first; the `Close` branch then replaces it with
           * `WaitForPeerTcpClose`, and `closeWithCode` replaces it with `CloseAfterPeerClosed`.
           */
          def handleControlFrame(opcode: Opcode, data: ByteString, nextHandler: InHandler): Unit = {
            setHandler(in, nextHandler)
            opcode match {
              case Opcode.Ping =>
                // answer with a Pong carrying the same payload
                publishDirectResponse(FrameEvent.fullFrame(Opcode.Pong, None, data, fin = true))
              case Opcode.Pong =>
                // ignore unsolicited Pong frame, just ask for the next event
                pull(in)
              case Opcode.Close =>
                setHandler(in, WaitForPeerTcpClose)
                push(out, PeerClosed.parse(data))
              case Opcode.Other(_) =>
                closeWithCode(Protocol.CloseCodes.ProtocolError, "Unsupported opcode")
              case other =>
                failStage(
                  new IllegalStateException(
                    s"unexpected message of type [${other.getClass.getName}] when expecting ControlFrame"))
            }
          }

          /** Shorthand for `closeWithCode` with the `ProtocolError` close code and an empty reason. */
          def pushProtocolError(): Unit = closeWithCode(Protocol.CloseCodes.ProtocolError)

          /**
           * Switches to `CloseAfterPeerClosed` and emits an `ActivelyCloseWithCode` element
           * carrying the given code and reason.
           */
          def closeWithCode(closeCode: Int, reason: String = ""): Unit = {
            val closeRequest = ActivelyCloseWithCode(Some(closeCode), reason)
            setHandler(in, CloseAfterPeerClosed)
            push(out, closeRequest)
          }

          /**
           * Starts accumulating a control frame whose payload is not completely contained in
           * `start`. `nextHandler` is the state to return to once the frame is complete.
           */
          def collectControlFrame(start: FrameStart, nextHandler: InHandler): Unit = {
            require(!start.isFullMessage)
            val collector = new ControlFrameDataHandler(start.header.opcode, start.data, nextHandler)
            setHandler(in, collector)
            pull(in)
          }

          /**
           * Emits a message part. The last part of a message is followed by `MessageEnd`, and
           * once both have been emitted the stage returns to the idle state.
           */
          def publishMessagePart(part: MessageDataPart): Unit =
            if (!part.last) push(out, part)
            else emitMultiple(out, Iterator(part, MessageEnd), () => setHandler(in, IdleHandler))

          /** Emits `frame` wrapped in a `DirectAnswer`. */
          def publishDirectResponse(frame: FrameStart): Unit = push(out, DirectAnswer(frame))
        }

        /**
         * State base for points where a frame start is expected. Validates the header, deals with
         * control frames itself and hands every other frame start to `handleRegularFrameStart`.
         */
        private trait ControlFrameStartHandler extends FrameHandler {
          def handleRegularFrameStart(start: FrameStart): Unit

          override def handleFrameStart(start: FrameStart): Unit = {
            val header = start.header

            // The checks are evaluated lazily and in this order.
            def maskedWhileNotServer: Boolean = header.mask.isDefined && !server
            def reservedBitSet: Boolean = header.rsv1 || header.rsv2 || header.rsv3
            // a control frame is rejected when it is longer than 125 or does not have `fin` set
            def invalidControlFrame: Boolean = header match {
              case FrameHeader(op, _, length, fin, _, _, _) => op.isControl && (length > 125 || !fin)
            }

            if (maskedWhileNotServer) pushProtocolError()
            else if (reservedBitSet) pushProtocolError()
            else if (invalidControlFrame) pushProtocolError()
            else if (header.opcode.isControl) {
              if (start.isFullMessage) handleControlFrame(header.opcode, start.data, this)
              else collectControlFrame(start, this)
            } else handleRegularFrameStart(start)
          }

          override def handleFrameData(data: FrameData): Unit =
            throw new IllegalStateException("Expected FrameStart")
        }

        /** Initial state, also re-entered after every completed message. */
        private object IdleHandler extends ControlFrameStartHandler {

          /** Installs `newHandler` as the current state and lets it process `start` right away. */
          def setAndHandleFrameStartWith(newHandler: ControlFrameStartHandler, start: FrameStart): Unit = {
            setHandler(in, newHandler)
            newHandler.handleFrameStart(start)
          }

          override def handleRegularFrameStart(start: FrameStart): Unit =
            start.header.opcode match {
              case Opcode.Binary if start.isFullMessage =>
                // a complete binary message is published directly, without a dedicated handler
                publishMessagePart(BinaryMessagePart(start.data, last = true))
              case Opcode.Binary =>
                setAndHandleFrameStartWith(new BinaryMessageHandler, start)
              case Opcode.Text =>
                // text always goes through a TextMessageHandler, complete message or not
                setAndHandleFrameStartWith(new TextMessageHandler, start)
              case _ =>
                pushProtocolError()
            }
        }

        /**
         * State while a data message is being received. The first frame start must carry
         * `expectedOpcode`; all later frame starts must be continuations.
         */
        private abstract class MessageHandler(expectedOpcode: Opcode) extends ControlFrameStartHandler {
          var expectFirstHeader = true
          var finSeen = false
          def createMessagePart(data: ByteString, last: Boolean): MessageDataPart

          override def handleRegularFrameStart(start: FrameStart): Unit = {
            val opcode = start.header.opcode
            val isExpectedFirstFrame = expectFirstHeader && opcode == expectedOpcode
            val isContinuation = opcode == Opcode.Continuation

            if (!(isExpectedFirstFrame || isContinuation)) pushProtocolError()
            else {
              expectFirstHeader = false
              // once a frame with `fin` has been seen the flag stays set
              finSeen = finSeen || start.header.fin
              publish(start)
            }
          }

          override def handleFrameData(data: FrameData): Unit = publish(data)

          /**
           * Converts the event's payload into a message part and publishes it. Any non-fatal
           * exception raised on the way closes with the `InconsistentData` close code.
           */
          def publish(part: FrameEvent): Unit =
            try {
              // the message ends with the last part of a frame that had `fin` set
              val isLastPart = finSeen && part.lastPart
              publishMessagePart(createMessagePart(part.data, last = isLastPart))
            } catch {
              case NonFatal(_) => closeWithCode(Protocol.CloseCodes.InconsistentData)
            }
        }

        /** Message state producing `BinaryMessagePart`s. */
        private class BinaryMessageHandler extends MessageHandler(Opcode.Binary) {
          override def createMessagePart(data: ByteString, last: Boolean): MessageDataPart =
            BinaryMessagePart(data, last)
        }

        /** Message state producing `TextMessagePart`s; owns one decoder per message. */
        private class TextMessageHandler extends MessageHandler(Opcode.Text) {
          val decoder = Utf8Decoder.create()

          override def createMessagePart(data: ByteString, last: Boolean): MessageDataPart = {
            // `.get` throws when decoding did not succeed; `publish` handles that exception
            val text = decoder.decode(data, endOfInput = last).get
            TextMessagePart(text, last)
          }
        }

        /**
         * State that accumulates the remaining payload of a control frame and, once the last
         * part has arrived, hands the complete payload to `handleControlFrame`.
         */
        private class ControlFrameDataHandler(
            opcode: Opcode, initialData: ByteString, nextHandler: InHandler) extends FrameHandler {
          private var collected = initialData

          override def handleFrameData(data: FrameData): Unit = {
            collected ++= data.data
            if (!data.lastPart) pull(in)
            else handleControlFrame(opcode, collected, nextHandler)
          }

          override def handleFrameStart(start: FrameStart): Unit =
            throw new IllegalStateException("Expected FrameData")
        }

        /**
         * State entered by `closeWithCode`: everything is dropped except a frame start with the
         * `Close` opcode, which is reported downstream as `PeerClosed`.
         */
        private object CloseAfterPeerClosed extends InHandler {
          override def onPush(): Unit = grab(in) match {
            case start: FrameStart if start.header.opcode == Opcode.Close =>
              setHandler(in, WaitForPeerTcpClose)
              push(out, PeerClosed.parse(start.data))
            case _ =>
              pull(in) // ignore all other data
          }
        }

        /** Final state after a close frame was received: incoming elements are discarded. */
        private object WaitForPeerTcpClose extends InHandler {
          override def onPush(): Unit = pull(in) // ignore
        }
      }
  }