src/Microsoft.Azure.SignalR.Common/ServiceConnections/Internal/WebSocketsTransport.cs (312 lines of code) (raw):
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
#if !NETCOREAPP
using System.Diagnostics;
#endif
using System.IO.Pipelines;
using System.Net.WebSockets;
#if !NETCOREAPP
using System.Runtime.InteropServices;
#endif
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Extensions.Logging;
namespace Microsoft.Azure.SignalR.Connections.Client.Internal;
/// <summary>
/// Client-side websocket transport that exposes a <see cref="ClientWebSocket"/> as an <see cref="IDuplexPipe"/>:
/// bytes written to <see cref="Output"/> are sent over the socket, and bytes received from the socket are readable
/// from <see cref="Input"/>.
/// Copied from the aspnetcore repo (client WebSocketsTransport). TODO: refactor.
/// </summary>
internal partial class WebSocketsTransport : IDuplexPipe
{
    /// <summary>
    /// Options used for both pipes of the transport/application pair. Continuations run on the thread pool,
    /// the synchronization context is not captured, and thresholds of 0 mean writers are never paused.
    /// </summary>
    public static PipeOptions DefaultOptions = new PipeOptions(writerScheduler: PipeScheduler.ThreadPool, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false, pauseWriterThreshold: 0, resumeWriterThreshold: 0);

    private readonly WebSocketMessageType _webSocketMessageType = WebSocketMessageType.Binary;

    private readonly ClientWebSocket _webSocket;

    private readonly IAccessTokenProvider _accessTokenProvider;

    private readonly ILogger _logger;

    // How long to wait for the send loop after the receive loop ends before aborting the socket.
    // Stays TimeSpan.Zero when no connection options were supplied.
    private readonly TimeSpan _closeTimeout;

    // Set right before the socket is aborted during shutdown; errors observed afterwards are expected and not reported.
    private volatile bool _aborted;

    private IDuplexPipe _application;

    private IDuplexPipe _transport;

    /// <summary>Data received from the websocket. Only available after <see cref="StartAsync"/> succeeds.</summary>
    public PipeReader Input => _transport.Input;

    /// <summary>Data to send over the websocket. Only available after <see cref="StartAsync"/> succeeds.</summary>
    public PipeWriter Output => _transport.Output;

    /// <summary>The send/receive processing loop; an already-completed task until <see cref="StartAsync"/> succeeds.</summary>
    internal Task Running { get; private set; } = Task.CompletedTask;

    /// <summary>
    /// Creates the transport and applies <paramref name="connectionOptions"/> to the underlying <see cref="ClientWebSocket"/>.
    /// </summary>
    /// <param name="connectionOptions">
    /// Optional socket settings (headers, cookies, client certificates, credentials, proxy, extra configuration, close timeout).
    /// When null, the close timeout is <see cref="TimeSpan.Zero"/>.
    /// </param>
    /// <param name="loggerFactory">Factory used to create this transport's logger; must not be null.</param>
    /// <param name="accessTokenProvider">Optional provider whose token is sent as a bearer <c>Authorization</c> header when connecting.</param>
    /// <exception cref="ArgumentNullException"><paramref name="loggerFactory"/> is null.</exception>
    public WebSocketsTransport(WebSocketConnectionOptions connectionOptions,
                               ILoggerFactory loggerFactory,
                               IAccessTokenProvider accessTokenProvider)
    {
        _logger = (loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory))).CreateLogger<WebSocketsTransport>();
        _webSocket = new ClientWebSocket();

        // Issue in ClientWebSocket prevents user-agent being set - https://github.com/dotnet/corefx/issues/26627
        //_webSocket.Options.SetRequestHeader("User-Agent", Constants.UserAgentHeader.ToString());

        if (connectionOptions != null)
        {
            if (connectionOptions.Headers != null)
            {
                foreach (var header in connectionOptions.Headers)
                {
                    _webSocket.Options.SetRequestHeader(header.Key, header.Value);
                }
            }

            if (connectionOptions.Cookies != null)
            {
                _webSocket.Options.Cookies = connectionOptions.Cookies;
            }

            if (connectionOptions.ClientCertificates != null)
            {
                _webSocket.Options.ClientCertificates.AddRange(connectionOptions.ClientCertificates);
            }

            if (connectionOptions.Credentials != null)
            {
                _webSocket.Options.Credentials = connectionOptions.Credentials;
            }

            if (connectionOptions.Proxy != null)
            {
                _webSocket.Options.Proxy = connectionOptions.Proxy;
            }

            if (connectionOptions.UseDefaultCredentials != null)
            {
                _webSocket.Options.UseDefaultCredentials = connectionOptions.UseDefaultCredentials.Value;
            }

            // User callback runs last so it can override any of the settings above.
            connectionOptions.WebSocketConfiguration?.Invoke(_webSocket.Options);

            _closeTimeout = connectionOptions.CloseTimeout;
        }

        // Set this header so the server auth middleware will set an Unauthorized instead of Redirect status code
        // See: https://github.com/aspnet/Security/blob/ff9f145a8e89c9756ea12ff10c6d47f2f7eb345f/src/Microsoft.AspNetCore.Authentication.Cookies/Events/CookieAuthenticationEvents.cs#L42
        _webSocket.Options.SetRequestHeader("X-Requested-With", "XMLHttpRequest");

        // Ignore the HttpConnectionOptions access token provider. We were given an updated delegate from the HttpConnection.
        _accessTokenProvider = accessTokenProvider;
    }

    /// <summary>
    /// Connects the websocket to <paramref name="url"/> and starts the send/receive loops, exposed through <see cref="Running"/>.
    /// The <c>http</c> and <c>https</c> schemes are mapped to <c>ws</c> and <c>wss</c>.
    /// </summary>
    /// <param name="url">The endpoint to connect to.</param>
    /// <param name="cancellationToken">Cancels the connect handshake.</param>
    /// <exception cref="ArgumentNullException"><paramref name="url"/> is null.</exception>
    /// <remarks>
    /// If obtaining the access token or connecting fails, the socket is disposed before the exception propagates.
    /// Connect failures are rethrown via <c>WrapAsAzureSignalRException</c>.
    /// </remarks>
    public async Task StartAsync(Uri url, CancellationToken cancellationToken = default)
    {
#if NET6_0_OR_GREATER
        ArgumentNullException.ThrowIfNull(url);
#else
        if (url == null)
        {
            throw new ArgumentNullException(nameof(url));
        }
#endif

        var resolvedUrl = ResolveWebSocketsUrl(url);

        string accessToken = null;

        // We don't need to capture to a local because we never change this delegate.
        if (_accessTokenProvider != null)
        {
            try
            {
                accessToken = await _accessTokenProvider.ProvideAsync().ConfigureAwait(false);
            }
            catch
            {
                // Release the socket on this failure path too, matching the ConnectAsync failure handling below.
                _webSocket.Dispose();
                throw;
            }

            if (!string.IsNullOrEmpty(accessToken))
            {
                _webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}");
            }
        }

        Log.StartTransport(_logger, _webSocketMessageType, resolvedUrl);

        try
        {
            await _webSocket.ConnectAsync(resolvedUrl, cancellationToken).ConfigureAwait(false);
        }
        catch (Exception e)
        {
            _webSocket.Dispose();
            throw e.WrapAsAzureSignalRException(accessToken);
        }

        Log.StartedTransport(_logger);

        // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa)
        var options = DefaultOptions;
        var pair = DuplexPipe.CreateConnectionPair(options, options);

        _transport = pair.Transport;
        _application = pair.Application;

        // TODO: Handle TCP connection errors
        // https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251
        Running = ProcessSocketAsync(_webSocket);
    }

    /// <summary>
    /// Stops the transport: completes the transport side of the pipes, cancels the pending application read to start
    /// the shutdown of the send loop, waits for <see cref="Running"/>, and finally disposes the socket.
    /// Only logs and returns if <see cref="StartAsync"/> never completed successfully.
    /// </summary>
    public async Task StopAsync()
    {
        Log.TransportStopping(_logger);

        if (_application == null)
        {
            // We never started
            return;
        }

        _transport.Output.Complete();
        _transport.Input.Complete();

        // Cancel any pending reads from the application, this should start the entire shutdown process
        _application.Input.CancelPendingRead();

        try
        {
            await Running.ConfigureAwait(false);
        }
        catch (Exception ex)
        {
            // The send/receive loops report their own errors (the receive loop completes the application
            // output with the exception), so only log here.
            Log.TransportStopped(_logger, ex);
            return;
        }
        finally
        {
            _webSocket.Dispose();
        }

        Log.TransportStopped(_logger, null);
    }

    /// <summary>Returns false once the socket is aborted, closed, or has already sent its close frame.</summary>
    private static bool WebSocketCanSend(ClientWebSocket ws)
    {
        return !(ws.State == WebSocketState.Aborted ||
                 ws.State == WebSocketState.Closed ||
                 ws.State == WebSocketState.CloseSent);
    }

    /// <summary>Maps <c>http</c> to <c>ws</c> and <c>https</c> to <c>wss</c>; any other scheme is left unchanged.</summary>
    private static Uri ResolveWebSocketsUrl(Uri url)
    {
        var uriBuilder = new UriBuilder(url);
        if (url.Scheme == "http")
        {
            uriBuilder.Scheme = "ws";
        }
        else if (url.Scheme == "https")
        {
            uriBuilder.Scheme = "wss";
        }

        return uriBuilder.Uri;
    }

    /// <summary>
    /// Runs the receive and send loops concurrently. When either finishes first, it drives the other one to completion
    /// (aborting the socket if needed). The socket is disposed when this method returns.
    /// </summary>
    private async Task ProcessSocketAsync(ClientWebSocket socket)
    {
        using (socket)
        {
            // Begin sending and receiving.
            var receiving = StartReceiving(socket);
            var sending = StartSending(socket);

            // Wait for send or receive to complete
            var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false);

            if (trigger == receiving)
            {
                // We're waiting for the application to finish and there are 2 things it could be doing
                // 1. Waiting for application data
                // 2. Waiting for a websocket send to complete

                // Cancel the application so that ReadAsync yields
                _application.Input.CancelPendingRead();

                using var delayCts = new CancellationTokenSource();
                var resultTask = await Task.WhenAny(sending, Task.Delay(_closeTimeout, delayCts.Token)).ConfigureAwait(false);

                if (resultTask != sending)
                {
                    _aborted = true;

                    // Abort the websocket if we're stuck in a pending send to the client
                    socket.Abort();
                }
                else
                {
                    // Cancel the timeout
                    delayCts.Cancel();
                }
            }
            else
            {
                // We're waiting on the websocket to close and there are 2 things it could be doing
                // 1. Waiting for websocket data
                // 2. Waiting on a flush to complete (backpressure being applied)

                _aborted = true;

                // Abort the websocket if we're stuck in a pending receive from the client
                socket.Abort();

                // Cancel any pending flush so that we can quit
                _application.Output.CancelPendingFlush();
            }
        }
    }

    /// <summary>
    /// Logs a received close frame and throws <see cref="InvalidOperationException"/> unless the peer closed with
    /// <see cref="WebSocketCloseStatus.NormalClosure"/>.
    /// </summary>
    private void HandleCloseFrame(ClientWebSocket socket)
    {
        Log.WebSocketClosed(_logger, socket.CloseStatus);

        if (socket.CloseStatus != WebSocketCloseStatus.NormalClosure)
        {
            throw new InvalidOperationException($"Websocket closed with error: {socket.CloseStatus}.");
        }
    }

    /// <summary>
    /// Copies frames received from the socket into the application's output pipe until a close frame arrives, the
    /// application stops consuming, or an error occurs. Unless the socket was deliberately aborted, errors complete
    /// the application output with the exception.
    /// </summary>
    private async Task StartReceiving(ClientWebSocket socket)
    {
        try
        {
            while (true)
            {
#if NETCOREAPP
                // Do a 0 byte read so that idle connections don't allocate a buffer when waiting for a read
                var result = await socket.ReceiveAsync(Memory<byte>.Empty, CancellationToken.None).ConfigureAwait(false);

                if (result.MessageType == WebSocketMessageType.Close)
                {
                    HandleCloseFrame(socket);
                    return;
                }
#endif
                var memory = _application.Output.GetMemory(2048);
#if NETCOREAPP
                var receiveResult = await socket.ReceiveAsync(memory, CancellationToken.None).ConfigureAwait(false);
#else
                var isArray = MemoryMarshal.TryGetArray<byte>(memory, out var arraySegment);
                Debug.Assert(isArray);

                // Exceptions are handled by the catch blocks of this method.
                var receiveResult = await socket.ReceiveAsync(arraySegment, CancellationToken.None).ConfigureAwait(false);
#endif
                // Checked on every target: on NETCOREAPP a close can arrive between the 0-byte read and this read,
                // and on other targets this is the only close check.
                if (receiveResult.MessageType == WebSocketMessageType.Close)
                {
                    HandleCloseFrame(socket);
                    return;
                }

                Log.MessageReceived(_logger, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage);

                _application.Output.Advance(receiveResult.Count);

                var flushResult = await _application.Output.FlushAsync().ConfigureAwait(false);

                // We canceled in the middle of applying back pressure
                // or if the consumer is done
                if (flushResult.IsCanceled || flushResult.IsCompleted)
                {
                    break;
                }
            }
        }
        catch (OperationCanceledException)
        {
            Log.ReceiveCanceled(_logger);
        }
        catch (Exception ex)
        {
            if (!_aborted)
            {
                _application.Output.Complete(ex);
            }
        }
        finally
        {
            // We're done writing
            _application.Output.Complete();

            Log.ReceiveStopped(_logger);
        }
    }

    /// <summary>
    /// Sends data written by the application over the socket until the application completes, the read is canceled,
    /// the socket can no longer send, or a send fails. It then sends a close frame if the socket is still open:
    /// InternalServerError if the loop threw, otherwise NormalClosure.
    /// </summary>
    private async Task StartSending(ClientWebSocket socket)
    {
        Exception error = null;

        try
        {
            while (true)
            {
                var result = await _application.Input.ReadAsync().ConfigureAwait(false);
                var buffer = result.Buffer;

                // Get a frame from the application

                try
                {
                    if (result.IsCanceled)
                    {
                        break;
                    }

                    if (!buffer.IsEmpty)
                    {
                        try
                        {
                            Log.ReceivedFromApp(_logger, buffer.Length);

                            if (WebSocketCanSend(socket))
                            {
                                await socket.SendAsync(buffer, _webSocketMessageType).ConfigureAwait(false);
                            }
                            else
                            {
                                break;
                            }
                        }
                        catch (Exception ex)
                        {
                            // Send failures after a deliberate abort are expected and not logged.
                            if (!_aborted)
                            {
                                Log.ErrorSendingMessage(_logger, ex);
                            }
                            break;
                        }
                    }
                    else if (result.IsCompleted)
                    {
                        break;
                    }
                }
                finally
                {
                    _application.Input.AdvanceTo(buffer.End);
                }
            }
        }
        catch (Exception ex)
        {
            error = ex;
        }
        finally
        {
            if (WebSocketCanSend(socket))
            {
                try
                {
                    // We're done sending, send the close frame to the client if the websocket is still open
                    await socket.CloseOutputAsync(error != null ? WebSocketCloseStatus.InternalServerError : WebSocketCloseStatus.NormalClosure, "", CancellationToken.None).ConfigureAwait(false);
                }
                catch (Exception ex)
                {
                    Log.ClosingWebSocketFailed(_logger, ex);
                }
            }

            _application.Input.Complete();

            Log.SendStopped(_logger);
        }
    }
}