sources/Google.Solutions.Testing.Apis/Net/InprocHttpProxy.cs (206 lines of code) (raw):

// // Copyright 2020 Google LLC // // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. // using System; using System.Collections.Generic; using System.IO; using System.Net; using System.Net.Http.Headers; using System.Net.Sockets; using System.Text; using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; namespace Google.Solutions.Testing.Apis.Net { /// <summary> /// Simple implementation of a HTTP proxy that can be used in tests. /// </summary> public class InProcessHttpProxy : IDisposable { private static readonly Regex ConnectRequestPattern = new Regex(@"^CONNECT ([a-zA-Z0-9\.*]+):(\d+) HTTP/1.1"); private static readonly Regex GetRequestPattern = new Regex(@"^GET (.*) HTTP/1.1"); // NB. Avoid reusing the same port twice in the same process. private static ushort nextProxyPort = 3128; private readonly CancellationTokenSource cancellation = new CancellationTokenSource(); private readonly LinkedList<string> connectionTargets = new LinkedList<string>(); private readonly TcpListener listener; private readonly Dictionary<string, string> staticFiles = new Dictionary<string, string>(); public IEnumerable<string> ConnectionTargets => this.connectionTargets; public ushort Port { get; } private void DispatchRequests() { while (!this.cancellation.IsCancellationRequested) { var socket = new NetworkStream(this.listener.AcceptSocket(), true); var _ = DispatchRequestAsync(socket).ConfigureAwait(false); } } private static string ReadLine(Stream stream) { var buffer = new StringBuilder(); while (true) { var b = stream.ReadByte(); if (b == -1 || b == (byte)'\n') { return buffer.ToString(); } else if (b == (byte)'\r') { } else { buffer.Append((char)b); } } } private async Task DispatchRequestAsync(NetworkStream clientStream) { using (clientStream) { var firstLine = ReadLine(clientStream); if (ConnectRequestPattern.Match(firstLine) is Match matchConnect && matchConnect.Success) { // // Read headers. // var headers = new Dictionary<string, string>(); string line; while (!string.IsNullOrEmpty((line = ReadLine(clientStream)))) { var parts = line.Split(':'); headers.Add(parts[0].ToLower(), parts[1].Trim()); } this.connectionTargets.AddLast(matchConnect.Groups[1].Value); await DispatchRequestAsync( matchConnect.Groups[1].Value, ushort.Parse(matchConnect.Groups[2].Value), headers, clientStream) .ConfigureAwait(true); } else if (GetRequestPattern.Match(firstLine) is Match getMatch && getMatch.Success && this.staticFiles.TryGetValue(getMatch.Groups[1].Value, out var responseBody)) { var response = Encoding.ASCII.GetBytes( "HTTP/1.1 200 OK\r\n" + $"Content-Length: {responseBody.Length}\r\n" + $"Content-Type: application/x-ns-proxy-autoconfig\r\n" + "\r\n" + responseBody); clientStream.Write(response, 0, response.Length); } else { var error = Encoding.ASCII.GetBytes($"HTTP /1.1 400 Bad Request"); clientStream.Write(error, 0, error.Length); } } } protected virtual async Task DispatchRequestAsync( string server, ushort serverPort, IDictionary<string, string> headers, NetworkStream clientStream) { // // Send response. // var response = Encoding.ASCII.GetBytes($"HTTP/1.1 200 OK\r\n\r\n"); clientStream.Write(response, 0, response.Length); clientStream.Flush(); // // Relay streams. // using (var client = new TcpClient(server, serverPort)) { var serverStream = client.GetStream(); await Task.WhenAll( clientStream.CopyToAsync(serverStream), serverStream.CopyToAsync(clientStream)) .ConfigureAwait(false); } } public InProcessHttpProxy(ushort port) { this.Port = port; this.listener = new TcpListener(new IPEndPoint(IPAddress.Loopback, port)); this.listener.Start(); _ = Task.Run(() => DispatchRequests()); } public InProcessHttpProxy() : this(nextProxyPort++) { } public void AddStaticFile(string path, string body) { this.staticFiles.Add(path, body); } public void Dispose() { Dispose(true); GC.SuppressFinalize(this); } protected virtual void Dispose(bool disposing) { if (disposing) { this.cancellation.Cancel(); this.listener.Stop(); } } } public class InProcessAuthenticatingHttpProxy : InProcessHttpProxy { public string Realm { get; set; } = "default"; public NetworkCredential Credential { get; set; } public InProcessAuthenticatingHttpProxy( ushort port, NetworkCredential credential) : base(port) { this.Credential = credential; } public InProcessAuthenticatingHttpProxy( NetworkCredential credential) : base() { this.Credential = credential; } protected override async Task DispatchRequestAsync( string server, ushort serverPort, IDictionary<string, string> headers, NetworkStream clientStream) { if (headers.TryGetValue("proxy-authorization", out var proxyAuthHeader)) { var proxyAuth = AuthenticationHeaderValue.Parse(proxyAuthHeader); if (proxyAuth.Scheme.ToLower() != "basic") { SendUnauthenticatedError(clientStream); return; } var credentials = Encoding.ASCII.GetString( Convert.FromBase64String(proxyAuth.Parameter)).Split(':'); if (credentials.Length != 2 || credentials[0] != this.Credential.UserName || credentials[1] != this.Credential.Password) { SendUnauthenticatedError(clientStream); } else { await base.DispatchRequestAsync(server, serverPort, headers, clientStream) .ConfigureAwait(true); } } else { SendUnauthenticatedError(clientStream); } } private void SendUnauthenticatedError(NetworkStream clientStream) { var response = Encoding.ASCII.GetBytes( "HTTP/1.1 407 Proxy Authentication Required\r\n" + $"Proxy-Authenticate: Basic realm={this.Realm}\r\n" + "\r\n"); clientStream.Write(response, 0, response.Length); clientStream.Flush(); } } }