src/Transport/WebSocketTransport.cs (162 lines of code) (raw):

// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. namespace Microsoft.Azure.Amqp.Transport { using System; using System.Net; using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; /// <summary> /// Defines the web socket transport. /// </summary> public class WebSocketTransport : TransportBase { readonly WebSocket webSocket; readonly Uri uri; readonly EndPoint local; readonly EndPoint remote; ITransportMonitor usageMeter; /// <inheritdoc cref="TransportBase.LocalEndPoint"/> public override EndPoint LocalEndPoint => this.local; /// <inheritdoc cref="TransportBase.RemoteEndPoint"/> public override EndPoint RemoteEndPoint => this.remote; internal WebSocketTransport(WebSocket webSocket, Uri uri, EndPoint local, EndPoint remote) : base(WebSocketTransportSettings.WebSockets) { this.webSocket = webSocket; this.uri = uri; this.local = local; this.remote = remote; } /// <summary> /// Starts a write operation. /// </summary> /// <param name="args">The write arguments.</param> /// <returns>true if the write operation is pending, otherwise false.</returns> public sealed override bool WriteAsync(TransportAsyncCallbackArgs args) { ArraySegment<byte> buffer; ByteBuffer mergedBuffer = null; DateTime startTime = DateTime.UtcNow; if (args.Buffer != null) { buffer = new ArraySegment<byte>(args.Buffer, args.Offset, args.Count); } else { Fx.Assert(args.ByteBufferList != null, "Buffer list should not be null when buffer is null"); if (args.ByteBufferList.Count == 1) { ByteBuffer byteBuffer = args.ByteBufferList[0]; buffer = new ArraySegment<byte>(byteBuffer.Buffer, byteBuffer.Offset, byteBuffer.Length); } else { // Copy all buffers into one big buffer to avoid SSL overhead mergedBuffer = new ByteBuffer(args.Count, false, false); for (int i = 0; i < args.ByteBufferList.Count; ++i) { ByteBuffer byteBuffer = args.ByteBufferList[i]; Buffer.BlockCopy(byteBuffer.Buffer, byteBuffer.Offset, mergedBuffer.Buffer, mergedBuffer.Length, byteBuffer.Length); mergedBuffer.Append(byteBuffer.Length); } buffer = new ArraySegment<byte>(mergedBuffer.Buffer, 0, mergedBuffer.Length); } } Task task = this.webSocket.SendAsync(buffer, WebSocketMessageType.Binary, true, CancellationToken.None); if (task.IsCompleted) { this.OnWriteComplete(args, buffer, mergedBuffer, startTime); return false; } task.ContinueWith(static (t, s) => { var (transport, args, buffer, mergedBuffer, startTime) = (Tuple<WebSocketTransport, TransportAsyncCallbackArgs, ArraySegment<byte>, ByteBuffer, DateTime>) s; if (t.IsFaulted) { args.Exception = t.Exception?.InnerException; } else if (t.IsCanceled) { args.Exception = new OperationCanceledException(); } else { transport.OnWriteComplete(args, buffer, mergedBuffer, startTime); } args.CompletedCallback(args); }, Tuple.Create(this, args, buffer, mergedBuffer, startTime)); return true; } /// <summary> /// Starts a read operation. /// </summary> /// <param name="args">The read arguments.</param> /// <returns>true if the read operation is pending, otherwise false.</returns> public sealed override bool ReadAsync(TransportAsyncCallbackArgs args) { DateTime startTime = DateTime.UtcNow; ArraySegment<byte> buffer = new ArraySegment<byte>(args.Buffer, args.Offset, args.Count); Task<WebSocketReceiveResult> task = this.webSocket.ReceiveAsync(buffer, CancellationToken.None); if (task.IsCompleted) { this.OnReadComplete(args, task.Result.Count, startTime); return false; } task.ContinueWith(static (t, s) => { var (transport, args, startTime) = (Tuple<WebSocketTransport, TransportAsyncCallbackArgs, DateTime>) s; if (t.IsFaulted) { args.Exception = t.Exception?.InnerException; } else if (t.IsCanceled) { args.Exception = new OperationCanceledException(); } else { transport.OnReadComplete(args, t.Result.Count, startTime); } args.CompletedCallback(args); }, Tuple.Create(this, args, startTime)); return true; } /// <summary> /// Opens the object. /// </summary> /// <returns>true if open is completed, otherwise false.</returns> protected override bool OpenInternal() { return true; } /// <summary> /// Closes the object. /// </summary> /// <returns>true if close is completed, otherwise false.</returns> protected override bool CloseInternal() { Task task = webSocket.CloseAsync(WebSocketCloseStatus.Empty, string.Empty, CancellationToken.None); if (task.IsCompleted) { return true; } task.ContinueWith(static (t, s) => { var thisPtr = (WebSocketTransport) s; var exception = t.IsFaulted ? t.Exception?.InnerException : t.IsCanceled ? new OperationCanceledException() : null; thisPtr.CompleteClose(false, exception); }, this); return false; } /// <summary> /// Aborts the object. /// </summary> protected override void AbortInternal() { this.webSocket.Dispose(); } /// <summary> /// Sets a transport monitor for transport I/O operations. /// </summary> /// <param name="usageMeter">The transport monitor.</param> public override void SetMonitor(ITransportMonitor usageMeter) { this.usageMeter = usageMeter; } internal static bool MatchScheme(string scheme) { return string.Equals(scheme, WebSocketTransportSettings.WebSockets, StringComparison.OrdinalIgnoreCase) || string.Equals(scheme, WebSocketTransportSettings.SecureWebSockets, StringComparison.OrdinalIgnoreCase); } void OnWriteComplete(TransportAsyncCallbackArgs args, ArraySegment<byte> buffer, ByteBuffer byteBuffer, DateTime startTime) { args.BytesTransfered = buffer.Count; if (byteBuffer != null) { byteBuffer.Dispose(); } if (this.usageMeter != null) { this.usageMeter.OnTransportWrite(0, buffer.Count, 0, DateTime.UtcNow.Subtract(startTime).Ticks); } } void OnReadComplete(TransportAsyncCallbackArgs args, int count, DateTime startTime) { args.BytesTransfered = count; if (this.usageMeter != null) { this.usageMeter.OnTransportRead(0, count, 0, DateTime.UtcNow.Subtract(startTime).Ticks); } } } }