in linkis-spring-cloud-services/linkis-service-gateway/linkis-spring-cloud-gateway/src/main/java/org/apache/linkis/gateway/springcloud/websocket/SpringCloudGatewayWebsocketFilter.java [82:253]
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
  // Proxies WebSocket (ws/wss) requests through the Linkis gateway; any other request, or one
  // that is already routed, is handed to the next filter in the chain.
  //
  // For a WebSocket request the client session is wrapped in a
  // GatewayWebSocketSessionConnection and every inbound client frame is handled as follows:
  //   - if the login lookup fails, a "no login" message is pushed to the client (while the
  //     session is alive) and the session is closed;
  //   - PING/PONG frames become a PING frame with the same payload, recorded as a heartbeat and
  //     sent through sendMsg on the gateway session;
  //   - any other frame is decoded as a ServerEvent, parsed and routed like an HTTP request,
  //     then forwarded to the proxy session of the routed ServiceInstance; if no proxy session
  //     exists yet, a new backend WebSocket connection is opened and registered first.
  // Frames received from a backend proxy session are pushed into `receives`, which is what the
  // gateway session sends back to the client.
  //
  // @param exchange the current server exchange
  // @param chain the remaining gateway filter chain
  // @return the completion of the client-side send for WebSocket requests, otherwise the
  //     result of the downstream chain
  changeSchemeIfIsWebSocketUpgrade(websocketRoutingFilter, exchange);
  URI requestUrl = exchange.getRequiredAttribute(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR);
  String scheme = requestUrl.getScheme();
  if (!ServerWebExchangeUtils.isAlreadyRouted(exchange)
      && ("ws".equals(scheme) || "wss".equals(scheme))) {
    ServerWebExchangeUtils.setAlreadyRouted(exchange);
    HttpHeaders headers = exchange.getRequest().getHeaders();
    List<String> protocols = headers.get("Sec-WebSocket-Protocol");
    if (protocols != null) {
      // One header value may list several comma-separated sub-protocols; flatten and trim them.
      protocols =
          protocols.stream()
              .flatMap(header -> Arrays.stream(StringUtils.commaDelimitedListToStringArray(header)))
              .map(String::trim)
              .collect(Collectors.toList());
    }
    List<String> collectedProtocols = protocols;
    GatewayContext gatewayContext = getGatewayContext(exchange);
    return this.webSocketService.handleRequest(
        exchange,
        new WebSocketHandler() {
          @Override
          public Mono<Void> handle(WebSocketSession webClientSocketSession) {
            GatewayWebSocketSessionConnection gatewayWebSocketSession =
                getGatewayWebSocketSessionConnection(
                    GatewaySSOUtils.getLoginUsername(gatewayContext), webClientSocketSession);
            FluxSinkListener<WebSocketMessage> fluxSinkListener =
                new FluxSinkListener<WebSocketMessage>() {
                  private FluxSink<WebSocketMessage> fluxSink = null;

                  @Override
                  public void setFluxSink(FluxSink<WebSocketMessage> fluxSink) {
                    this.fluxSink = fluxSink;
                  }

                  @Override
                  public void next(WebSocketMessage webSocketMessage) {
                    if (fluxSink != null) {
                      fluxSink.next(webSocketMessage);
                    }
                    GatewaySSOUtils.updateLastAccessTime(gatewayContext);
                  }

                  @Override
                  public void complete() {
                    if (fluxSink != null) {
                      fluxSink.complete();
                    }
                  }
                };
            // Messages pushed through fluxSinkListener end up here and are sent to the client.
            Flux<WebSocketMessage> receives =
                Flux.create(
                    sink -> {
                      fluxSinkListener.setFluxSink(sink);
                    });
            gatewayWebSocketSession
                .receive()
                // Keep each frame's payload alive past the inbound callback; every branch below
                // that consumes or drops the frame must balance this with release().
                .doOnNext(WebSocketMessage::retain)
                .map(
                    t -> {
                      String user;
                      try {
                        user = GatewaySSOUtils.getLoginUsername(gatewayContext);
                      } catch (Throwable e) {
                        // The frame is dropped; release it so the retained payload is not leaked.
                        t.release();
                        if (gatewayWebSocketSession.isAlive()) {
                          String message =
                              Message.response(
                                  Message.noLogin(e.getMessage())
                                      .$less$less(gatewayContext.getRequest().getRequestURI()));
                          fluxSinkListener.next(
                              getWebSocketMessage(
                                  gatewayWebSocketSession.bufferFactory(), message));
                        }
                        return gatewayWebSocketSession.close();
                      }
                      if (t.getType() == WebSocketMessage.Type.PING
                          || t.getType() == WebSocketMessage.Type.PONG) {
                        // The payload is handed to pingMsg, so t itself is not released here.
                        WebSocketMessage pingMsg =
                            new WebSocketMessage(WebSocketMessage.Type.PING, t.getPayload());
                        gatewayWebSocketSession.heartbeat(pingMsg);
                        return sendMsg(exchange, gatewayWebSocketSession, pingMsg);
                      }
                      String json = t.getPayloadAsText();
                      t.release();
                      ServerEvent serverEvent = SocketServerEvent.getServerEvent(json);
                      // Re-use the gateway request object to parse/route this message like an
                      // HTTP request: the event data becomes the body, its method the URI.
                      ((SpringCloudGatewayHttpRequest) gatewayContext.getRequest())
                          .setRequestBody(SocketServerEvent.getMessageData(serverEvent));
                      ((SpringCloudGatewayHttpRequest) gatewayContext.getRequest())
                          .setRequestURI(serverEvent.getMethod());
                      parser.parse(gatewayContext);
                      if (gatewayContext.getResponse().isCommitted()) {
                        // The parser committed a response itself; send it on the gateway session.
                        return sendMsg(
                            exchange,
                            gatewayWebSocketSession,
                            ((WebsocketGatewayHttpResponse) gatewayContext.getResponse())
                                .getWebSocketMsg());
                      }
                      ServiceInstance serviceInstance = router.route(gatewayContext);
                      if (gatewayContext.getResponse().isCommitted()) {
                        // The router committed a response itself; send it on the gateway session.
                        return sendMsg(
                            exchange,
                            gatewayWebSocketSession,
                            ((WebsocketGatewayHttpResponse) gatewayContext.getResponse())
                                .getWebSocketMsg());
                      }
                      WebSocketSession webSocketProxySession =
                          getProxyWebSocketSession(gatewayWebSocketSession, serviceInstance);
                      if (webSocketProxySession != null) {
                        // A backend connection for this instance already exists; reuse it.
                        return sendMsg(exchange, webSocketProxySession, json);
                      }
                      URI uri = exchange.getRequest().getURI();
                      boolean encoded = ServerWebExchangeUtils.containsEncodedParts(uri);
                      String host;
                      int port;
                      if (StringUtils.isEmpty(serviceInstance.getInstance())) {
                        // No concrete instance was routed: ask the load balancer for one.
                        org.springframework.cloud.client.ServiceInstance service =
                            loadBalancer.choose(serviceInstance.getApplicationName());
                        if (service == null) {
                          throw new IllegalStateException(
                              "No available instance of service "
                                  + serviceInstance.getApplicationName()
                                  + " to forward the websocket message to");
                        }
                        host = service.getHost();
                        port = service.getPort();
                      } else {
                        // The routed instance string is parsed as "host:port".
                        String[] instanceInfo = serviceInstance.getInstance().split(":");
                        host = instanceInfo[0];
                        port = Integer.parseInt(instanceInfo[1]);
                      }
                      URI requestURI =
                          UriComponentsBuilder.fromUri(requestUrl)
                              .host(host)
                              .port(port)
                              .build(encoded)
                              .toUri();
                      HttpHeaders filtered =
                          HttpHeadersFilter.filterRequest(
                              getHeadersFilters(websocketRoutingFilter), exchange);
                      SpringCloudHttpUtils.addIgnoreTimeoutSignal(filtered);
                      // Open a new backend connection, register it as the proxy session, forward
                      // the current message and pipe backend frames back to the client.
                      return webSocketClient.execute(
                          requestURI,
                          filtered,
                          new WebSocketHandler() {
                            @Override
                            public Mono<Void> handle(WebSocketSession proxySession) {
                              setProxyWebSocketSession(
                                  user, serviceInstance, gatewayWebSocketSession, proxySession);
                              Mono<Void> proxySessionSend = sendMsg(exchange, proxySession, json);
                              proxySessionSend.subscribe();
                              return getProxyWebSocketSession(
                                      gatewayWebSocketSession, serviceInstance)
                                  .receive()
                                  .doOnNext(WebSocketMessage::retain)
                                  .doOnNext(fluxSinkListener::next)
                                  .then();
                            }

                            @Override
                            public List<String> getSubProtocols() {
                              return collectedProtocols;
                            }
                          });
                    })
                .doOnComplete(fluxSinkListener::complete)
                // Each mapped element is a Mono<Void>; subscribing triggers it (fire-and-forget).
                .doOnNext(Mono::subscribe)
                .subscribe();
            return gatewayWebSocketSession.send(receives);
          }

          @Override
          public List<String> getSubProtocols() {
            return collectedProtocols;
          }
        });
  } else {
    return chain.filter(exchange);
  }
}