src/Microsoft.Azure.Relay/RelayedHttpListenerContext.cs (108 lines of code) (raw):
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
namespace Microsoft.Azure.Relay
{
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
/// <summary>
/// Provides access to the request and response objects representing a client request to a <see cref="HybridConnectionListener"/>.
/// This is modeled after System.Net.HttpListenerContext.
/// </summary>
public class RelayedHttpListenerContext : ITraceSource
{
static readonly TimeSpan AcceptTimeout = TimeSpan.FromSeconds(20);
string cachedToString;
internal RelayedHttpListenerContext(HybridConnectionListener listener, Uri requestUri, string trackingId, string method, IDictionary<string, string> requestHeaders)
{
this.Listener = listener;
this.TrackingContext = TrackingContext.Create(trackingId, requestUri);
this.Request = new RelayedHttpListenerRequest(requestUri, method, requestHeaders);
this.Response = new RelayedHttpListenerResponse(this);
this.FlowSubProtocol();
}
/// <summary>
/// Gets the <see cref="RelayedHttpListenerRequest"/> that represents a client's request for a resource.
/// </summary>
public RelayedHttpListenerRequest Request { get; }
/// <summary>
/// Gets the <see cref="RelayedHttpListenerResponse"/> object to control the response to the client's request.
/// </summary>
public RelayedHttpListenerResponse Response { get; }
/// <summary>
/// Gets the TrackingContext for this listener context.
/// </summary>
public TrackingContext TrackingContext { get; }
internal HybridConnectionListener Listener { get; }
/// <summary>
/// Returns a string that represents the current object. Includes a TrackingId for end to end correlation.
/// </summary>
public override string ToString()
{
return this.cachedToString ?? (this.cachedToString = nameof(RelayedHttpListenerContext) + "(" + this.TrackingContext + ")");
}
internal async Task<WebSocketStream> AcceptAsync(Uri rendezvousUri)
{
// Performance: Address Resolution (ARP) work-around: When we receive the control message from a TCP connection which hasn't had any
// outbound traffic for 2 minutes the ARP cache no longer has the MAC address required to ACK the control message. If we also begin
// connecting a new socket at exactly the same time there's a known race condition (insert link here) where ARP can only resolve one
// address at a time, which causes the loser of the race to have to retry after 3000ms. To avoid the 3000ms delay we just pause for
// a few ms here instead.
await Task.Delay(TimeSpan.FromMilliseconds(2)).ConfigureAwait(false);
var clientWebSocket = this.CreateWebSocket();
// If we are accepting a sub-protocol handle that here
var subProtocol = this.Response.Headers[HybridConnectionConstants.Headers.SecWebSocketProtocol];
if (!string.IsNullOrEmpty(subProtocol))
{
clientWebSocket.Options.AddSubProtocol(subProtocol);
}
using (var cancelSource = new CancellationTokenSource(AcceptTimeout))
{
await clientWebSocket.ConnectAsync(rendezvousUri, cancelSource.Token).ConfigureAwait(false);
}
var webSocketStream = new WebSocketStream(clientWebSocket.WebSocket, this.TrackingContext);
return webSocketStream;
}
internal async Task RejectAsync(Uri rendezvousUri)
{
IClientWebSocket clientWebSocket = null;
try
{
if (this.Response.StatusCode == HttpStatusCode.Continue)
{
this.Response.StatusCode = HttpStatusCode.BadRequest;
this.Response.StatusDescription = "Rejected by user code";
}
// Add the status code/description to the URI query string
int requiredCapacity = rendezvousUri.OriginalString.Length + 50 + this.Response.StatusDescription.Length;
var stringBuilder = new StringBuilder(rendezvousUri.OriginalString, requiredCapacity);
stringBuilder.AppendFormat("&{0}={1}", HybridConnectionConstants.StatusCode, (int)this.Response.StatusCode);
stringBuilder.AppendFormat("&{0}={1}", HybridConnectionConstants.StatusDescription, WebUtility.UrlEncode(this.Response.StatusDescription));
Uri rejectUri = new Uri(stringBuilder.ToString());
clientWebSocket = this.CreateWebSocket();
using (var cancelSource = new CancellationTokenSource(AcceptTimeout))
{
await clientWebSocket.ConnectAsync(rejectUri, cancelSource.Token).ConfigureAwait(false);
}
}
catch (Exception e) when (!Fx.IsFatal(e))
{
WebException webException;
HttpWebResponse httpWebResponse;
if (e is WebSocketException &&
(webException = e.InnerException as WebException) != null &&
(httpWebResponse = webException.Response as HttpWebResponse) != null &&
httpWebResponse.StatusCode == HttpStatusCode.Gone)
{
// status code of "Gone" is expected when rejecting a client request
return;
}
RelayEventSource.Log.HandledExceptionAsWarning(this, e);
}
finally
{
clientWebSocket?.WebSocket?.Abort();
}
}
IClientWebSocket CreateWebSocket()
{
var clientWebSocket = this.Listener.ClientWebSocketFactory.Create();
clientWebSocket.Options.SetBuffer(this.Listener.ConnectionBufferSize, this.Listener.ConnectionBufferSize);
DefaultWebProxy.ConfigureProxy(clientWebSocket.Options, this.Listener.Proxy);
clientWebSocket.Options.KeepAliveInterval = this.Listener.KeepAliveInterval;
return clientWebSocket;
}
void FlowSubProtocol()
{
// By default use the first sub-protocol (if present)
string subProtocol = this.Request.Headers[HybridConnectionConstants.Headers.SecWebSocketProtocol];
if (!string.IsNullOrEmpty(subProtocol))
{
int separatorIndex = subProtocol.IndexOf(',');
if (separatorIndex >= 0)
{
// more than one sub-protocol in headers, only use the first.
subProtocol = subProtocol.Substring(0, separatorIndex);
}
this.Response.Headers[HybridConnectionConstants.Headers.SecWebSocketProtocol] = subProtocol;
}
}
}
}