tools/test-proxy/Azure.Sdk.Tools.TestProxy/RecordingHandler.cs (1,022 lines of code) (raw):
using Azure.Core;
using Azure.Sdk.Tools.TestProxy.Common;
using Azure.Sdk.Tools.TestProxy.Common.Exceptions;
using Azure.Sdk.Tools.TestProxy.Store;
using Azure.Sdk.Tools.TestProxy.Transforms;
using Azure.Sdk.Tools.TestProxy.Vendored;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Extensions.Primitives;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
namespace Azure.Sdk.Tools.TestProxy
{
public class RecordingHandler
{
#region constructor and common variables
public string ContextDirectory;
public bool HandleRedirects = true;
private const string SkipRecordingHeaderKey = "x-recording-skip";
private const string SkipRecordingRequestBody = "request-body";
private const string SkipRecordingRequestResponse = "request-response";
public IAssetsStore Store;
public StoreResolver Resolver;
private static readonly string[] s_excludedRequestHeaders = new string[] {
// Only applies to request between client and proxy
// TODO, we need to handle this properly, there are tests that actually test proxy functionality.
"Host",
"Proxy-Connection",
};
public HttpClient BaseRedirectableClient = Startup.Insecure ?
new HttpClient(new HttpClientHandler() { ServerCertificateCustomValidationCallback = (_, _, _, _) => true })
{
Timeout = TimeSpan.FromSeconds(600),
} :
new HttpClient()
{
Timeout = TimeSpan.FromSeconds(600)
};
public HttpClient BaseRedirectlessClient = Startup.Insecure ?
new HttpClient(new HttpClientHandler() { AllowAutoRedirect = false, ServerCertificateCustomValidationCallback = (_, _, _, _) => true })
{
Timeout = TimeSpan.FromSeconds(600),
} :
new HttpClient(new HttpClientHandler() { AllowAutoRedirect = false })
{
Timeout = TimeSpan.FromSeconds(600)
};
public HttpClient RedirectlessClient;
public HttpClient RedirectableClient;
public SanitizerDictionary SanitizerRegistry = new SanitizerDictionary();
public List<ResponseTransform> Transforms { get; set; }
public RecordMatcher Matcher { get; set; }
public readonly ConcurrentDictionary<string, ModifiableRecordSession> RecordingSessions
= new ConcurrentDictionary<string, ModifiableRecordSession>();
public readonly ConcurrentDictionary<string, ModifiableRecordSession> InMemorySessions
= new ConcurrentDictionary<string, ModifiableRecordSession>();
public readonly ConcurrentDictionary<string, ModifiableRecordSession> PlaybackSessions
= new ConcurrentDictionary<string, ModifiableRecordSession>();
public readonly ConcurrentDictionary<string, ConcurrentQueue<AuditLogItem>> AuditSessions
= new ConcurrentDictionary<string, ConcurrentQueue<AuditLogItem>>();
/// <summary>
/// This exists to grab any sessions that might be ongoing. I don't have any evidence that sessions are being left behind, but
/// we have to have this to be certain!
/// </summary>
/// <returns></returns>
public List<ConcurrentQueue<AuditLogItem>> RetrieveOngoingAuditLogs()
{
List<ConcurrentQueue<AuditLogItem>> results = new List<ConcurrentQueue<AuditLogItem>>();
if (PlaybackSessions.Keys.Any())
{
foreach(var value in PlaybackSessions.Values)
{
results.Add(value.AuditLog);
}
}
if (RecordingSessions.Keys.Any())
{
foreach (var value in RecordingSessions.Values)
{
results.Add(value.AuditLog);
}
}
if (InMemorySessions.Keys.Any())
{
foreach (var value in InMemorySessions.Values)
{
results.Add(value.AuditLog);
}
}
return results;
}
public RecordingHandler(string targetDirectory, IAssetsStore store = null, StoreResolver storeResolver = null)
{
ContextDirectory = targetDirectory;
SetDefaultExtensions().Wait();
Store = store;
if (store == null)
{
Store = new NullStore();
}
Resolver = storeResolver;
if (Resolver == null)
{
Resolver = new StoreResolver();
}
}
#endregion
#region recording functionality
public async Task StopRecording(string sessionId, IDictionary<string, string> variables = null, bool saveRecording = true)
{
var id = Guid.NewGuid().ToString();
DebugLogger.LogTrace($"RECORD STOP BEGIN {id}.");
if (!RecordingSessions.TryRemove(sessionId, out var recordingSession))
{
return;
}
recordingSession.AuditLog.Enqueue(new AuditLogItem(sessionId, $"Stopping recording for {sessionId}."));
if (!AuditSessions.TryAdd(sessionId, recordingSession.AuditLog))
{
DebugLogger.LogError($"Unable to save audit log for {sessionId}");
}
var sanitizers = await SanitizerRegistry.GetSanitizers(recordingSession);
await recordingSession.Session.Sanitize(sanitizers);
if (variables != null)
{
foreach (var kvp in variables)
{
recordingSession.Session.Variables[kvp.Key] = kvp.Value;
}
}
if (saveRecording)
{
if (String.IsNullOrEmpty(recordingSession.Path))
{
if (!InMemorySessions.TryAdd(sessionId, recordingSession))
{
throw new HttpException(HttpStatusCode.InternalServerError, $"Unexpectedly failed to add new in-memory session under id {sessionId}.");
}
}
else
{
// Create directories above file if they don't already exist
WriteToDisk(recordingSession);
}
}
DebugLogger.LogTrace($"RECORD STOP END {id}.");
}
public void WriteToDisk(ModifiableRecordSession recordingSession)
{
// Create directories above file if they don't already exist
var directory = Path.GetDirectoryName(recordingSession.Path);
if (!String.IsNullOrEmpty(directory))
{
Directory.CreateDirectory(directory);
}
using var stream = System.IO.File.Create(recordingSession.Path);
var options = new JsonWriterOptions { Indented = true, Encoder = RecordEntry.WriterOptions.Encoder };
var writer = new Utf8JsonWriter(stream, options);
recordingSession.Session.Serialize(writer);
writer.Flush();
stream.Write(Encoding.UTF8.GetBytes(Environment.NewLine));
}
/// <summary>
/// Entrypoint handling an an optional parameter assets.json. If present, a restore option either MUST run or MAY run depending on if we're running in playback or adding new recordings.
/// </summary>
/// <param name="assetsJson">The absolute path to the targeted assets.json.</param>
/// <param name="forceCheckout">If this is set to true, a restore MUST be run. Otherwise, we just need to ensure that the current assets Tag is selected.</param>
/// <returns></returns>
private async Task RestoreAssetsJson(string assetsJson = null, bool forceCheckout = false)
{
if (!string.IsNullOrWhiteSpace(assetsJson))
{
await this.Store.Restore(assetsJson);
}
}
public async Task StartRecordingAsync(string sessionId, HttpResponse outgoingResponse, string assetsJson = null)
{
var id = Guid.NewGuid().ToString();
DebugLogger.LogTrace($"RECORD START BEGIN {id}.");
var auditEntry = new AuditLogItem(id, $"Starting record for path {sessionId}, which will return recordingId {id}.");
await RestoreAssetsJson(assetsJson, false);
var session = new ModifiableRecordSession(new RecordSession(), SanitizerRegistry, id)
{
Path = !string.IsNullOrWhiteSpace(sessionId) ? (await GetRecordingPath(sessionId, assetsJson)) : String.Empty,
Client = null
};
session.AuditLog.Enqueue(auditEntry);
if (!RecordingSessions.TryAdd(id, session))
{
throw new HttpException(HttpStatusCode.InternalServerError, $"Unexpectedly failed to add new recording session under id {id}.");
}
DebugLogger.LogTrace($"RECORD START END {id}.");
outgoingResponse.Headers.Append("x-recording-id", id);
}
public async Task HandleRecordRequestAsync(string recordingId, HttpRequest incomingRequest, HttpResponse outgoingResponse)
{
if (!RecordingSessions.TryGetValue(recordingId, out var session))
{
throw new HttpException(HttpStatusCode.BadRequest, $"There is no active recording session under id {recordingId}.");
}
RecordEntry noBodyEntry = RecordingHandler.CreateNoBodyRecordEntry(incomingRequest);
session.AuditLog.Enqueue(new AuditLogItem(recordingId, noBodyEntry.RequestUri, noBodyEntry.RequestMethod.ToString()));
var sanitizers = await SanitizerRegistry.GetSanitizers(session);
DebugLogger.LogRequestDetails(incomingRequest, sanitizers);
(RecordEntry entry, byte[] requestBody) = await CreateEntryAsync(incomingRequest).ConfigureAwait(false);
var upstreamRequest = CreateUpstreamRequest(incomingRequest, requestBody);
HttpResponseMessage upstreamResponse = null;
// The experience around Content-Length is a bit weird in .NET. We're using the .NET native HttpClient class to send our requests. This comes with
// some automagic.
//
// If an incoming request...
// ...has a Content-Length 0 header, and no body. We should send along the Content-Length: 0 header with the upstreamrequest.
// ...has no Content-Length header, and no body. We _should not_ send along the Content-Length: 0 header.
// ...has no Content-Length header, a 0 length body, but a TransferEncoding header with value "chunked". We _should_ allow any other Content headers to stick around.
//
// The .NET http client is a bit weird about attaching the Content-Length header though. If you HAVE the .Content property defined, a Content-Length
// header WILL be added. This is due to the fact that on send, the client considers a populated Client property as having a body, even if it's zero length.
if (incomingRequest.ContentLength == null)
{
if(!incomingRequest.Headers["Transfer-Encoding"].ToString().Split(' ').Select(x => x.Trim()).Contains("chunked"))
{
upstreamRequest.Content = null;
}
}
if (HandleRedirects)
{
upstreamResponse = await (session.Client ?? RedirectableClient).SendAsync(upstreamRequest).ConfigureAwait(false);
}
else
{
upstreamResponse = await (session.Client ?? RedirectlessClient).SendAsync(upstreamRequest).ConfigureAwait(false);
}
byte[] body = Array.Empty<byte>();
// HEAD requests do NOT have a body regardless of the value of the Content-Length header
if (incomingRequest.Method.ToUpperInvariant() != "HEAD")
{
body = CompressionUtilities.DecompressBody((MemoryStream)await upstreamResponse.Content.ReadAsStreamAsync().ConfigureAwait(false), upstreamResponse.Content.Headers);
}
entry.Response.Body = body.Length == 0 ? null : body;
entry.StatusCode = (int)upstreamResponse.StatusCode;
EntryRecordMode mode = GetRecordMode(incomingRequest);
if (mode != EntryRecordMode.DontRecord)
{
await session.Session.EntryLock.WaitAsync();
try
{
session.AuditLog.Enqueue(new AuditLogItem(recordingId, $"Lock obtained. Adding entry {entry.RequestMethod} {entry.RequestUri} to session {recordingId}."));
session.Session.Entries.Add(entry);
}
finally
{
session.Session.EntryLock.Release();
}
Interlocked.Increment(ref Startup.RequestsRecorded);
}
if (mode == EntryRecordMode.RecordWithoutRequestBody)
{
entry.Request.Body = null;
}
outgoingResponse.StatusCode = (int)upstreamResponse.StatusCode;
foreach (var header in upstreamResponse.Headers.Concat(upstreamResponse.Content.Headers))
{
var values = new StringValues(header.Value.ToArray());
outgoingResponse.Headers.Append(header.Key, values);
entry.Response.Headers.Add(header.Key, values);
}
outgoingResponse.Headers.Remove("Transfer-Encoding");
if (entry.Response.Body?.Length > 0)
{
var bodyData = CompressionUtilities.CompressBody(entry.Response.Body, entry.Response.Headers);
if (entry.Response.Headers.ContainsKey("Content-Length")){
outgoingResponse.ContentLength = bodyData.Length;
}
await outgoingResponse.Body.WriteAsync(bodyData).ConfigureAwait(false);
}
}
public static EntryRecordMode GetRecordMode(HttpRequest request)
{
EntryRecordMode mode = EntryRecordMode.Record;
if (request.Headers.TryGetValue(SkipRecordingHeaderKey, out var values))
{
if (values.Count != 1)
{
throw new HttpException(
HttpStatusCode.BadRequest,
$"'{SkipRecordingHeaderKey}' should contain a single value set to either '{SkipRecordingRequestBody}' or " +
$"'{SkipRecordingRequestResponse}'");
}
string skipMode = values.First();
if (skipMode.Equals(SkipRecordingRequestResponse, StringComparison.OrdinalIgnoreCase))
{
mode = EntryRecordMode.DontRecord;
}
else if (skipMode.Equals(SkipRecordingRequestBody, StringComparison.OrdinalIgnoreCase))
{
mode = EntryRecordMode.RecordWithoutRequestBody;
}
else
{
throw new HttpException(
HttpStatusCode.BadRequest,
$"{skipMode} is not a supported value for header '{SkipRecordingHeaderKey}'." +
$"It should be either omitted from the request headers, or set to either '{SkipRecordingRequestBody}' " +
$"or '{SkipRecordingRequestResponse}'");
}
}
return mode;
}
public HttpRequestMessage CreateUpstreamRequest(HttpRequest incomingRequest, byte[] incomingBody)
{
var upstreamRequest = new HttpRequestMessage();
upstreamRequest.RequestUri = GetRequestUri(incomingRequest);
upstreamRequest.Method = new HttpMethod(incomingRequest.Method);
upstreamRequest.Content = new ReadOnlyMemoryContent(incomingBody);
foreach (var header in incomingRequest.Headers)
{
IEnumerable<string> values = header.Value;
// can't handle PROXY_CONNECTION right now.
if (s_excludedRequestHeaders.Contains(header.Key, StringComparer.OrdinalIgnoreCase))
{
continue;
}
if (!header.Key.StartsWith("x-recording"))
{
if (upstreamRequest.Headers.TryAddWithoutValidation(header.Key, values))
{
continue;
}
if (!upstreamRequest.Content.Headers.TryAddWithoutValidation(header.Key, values))
{
throw new HttpException(
HttpStatusCode.BadRequest,
$"Encountered an unexpected exception while mapping a content header during upstreamRequest creation. Header: \"{header.Key}\". Value: \"{String.Join(",", values)}\""
);
}
}
if (header.Key == "x-recording-upstream-host-header")
{
upstreamRequest.Headers.Host = header.Value;
}
}
return upstreamRequest;
}
#endregion
#region playback functionality
public async Task StartPlaybackAsync(string sessionId, HttpResponse outgoingResponse, RecordingType mode = RecordingType.FilePersisted, string assetsPath = null)
{
var id = Guid.NewGuid().ToString();
DebugLogger.LogTrace($"PLAYBACK START BEGIN {id}.");
ModifiableRecordSession session = new ModifiableRecordSession(SanitizerRegistry, id);
var auditEntry = new AuditLogItem(id, $"Starting playback for path {sessionId}, which will return recordingId {id}.");
if (mode == RecordingType.InMemory)
{
if (!InMemorySessions.TryGetValue(sessionId, out session))
{
throw new HttpException(HttpStatusCode.BadRequest, $"There is no in-memory session with id {sessionId} available for playback retrieval.");
}
session.SourceRecordingId = sessionId;
}
else
{
await RestoreAssetsJson(assetsPath, true);
var path = await GetRecordingPath(sessionId, assetsPath);
var base64path = Convert.ToBase64String(Encoding.UTF8.GetBytes(path));
outgoingResponse.Headers.Append("x-base64-recording-file-location", base64path);
if (!File.Exists(path))
{
throw new TestRecordingMismatchException($"Recording file path {path} does not exist.");
}
using var stream = System.IO.File.OpenRead(path);
using var doc = await JsonDocument.ParseAsync(stream).ConfigureAwait(false);
session = new ModifiableRecordSession(RecordSession.Deserialize(doc.RootElement), SanitizerRegistry, id)
{
Path = path
};
session.AuditLog.Enqueue(auditEntry);
}
if (!PlaybackSessions.TryAdd(id, session))
{
throw new HttpException(HttpStatusCode.InternalServerError, $"Unexpectedly failed to add new playback session under id {id}.");
}
outgoingResponse.Headers.Append("x-recording-id", id);
var json = JsonSerializer.Serialize(session.Session.Variables);
outgoingResponse.Headers.Append("Content-Type", "application/json");
// Write to the response
await outgoingResponse.WriteAsync(json);
DebugLogger.LogTrace($"PLAYBACK START END {id}.");
}
public async Task StopPlayback(string recordingId, bool purgeMemoryStore = false)
{
var id = Guid.NewGuid().ToString();
DebugLogger.LogTrace($"PLAYBACK STOP BEGIN {id}.");
// obtain the playbacksession so we can get grab a lock on it. if there is a streaming response we will HAVE TO WAIT for that to complete
// before we finish
if (!PlaybackSessions.TryGetValue(recordingId, out var session))
{
throw new HttpException(HttpStatusCode.BadRequest, $"There is no active playback session under recording id {recordingId}.");
}
await session.Session.EntryLock.WaitAsync();
try
{
if (!PlaybackSessions.TryRemove(recordingId, out var removedSession))
{
throw new HttpException(HttpStatusCode.BadRequest, $"There is no active playback session under recording id {recordingId}.");
}
session.AuditLog.Enqueue(new AuditLogItem(recordingId, $"Lock obtained, stopping playback for {recordingId}."));
if (!AuditSessions.TryAdd(recordingId, session.AuditLog))
{
DebugLogger.LogError($"Unable to save audit log for {recordingId}");
}
if (!String.IsNullOrEmpty(session.SourceRecordingId) && purgeMemoryStore)
{
if (!InMemorySessions.TryRemove(session.SourceRecordingId, out var inMemorySession))
{
throw new HttpException(HttpStatusCode.InternalServerError, $"Unexpectedly failed to remove in-memory session {session.SourceRecordingId}.");
}
Interlocked.Add(ref Startup.RequestsRecorded, -1 * inMemorySession.Session.Entries.Count);
GC.Collect();
}
DebugLogger.LogTrace($"PLAYBACK STOP END {id}.");
}
finally
{
session.Session.EntryLock.Release();
}
}
public async Task HandlePlaybackRequest(string recordingId, HttpRequest incomingRequest, HttpResponse outgoingResponse)
{
if (!PlaybackSessions.TryGetValue(recordingId, out var session))
{
throw new HttpException(HttpStatusCode.BadRequest, $"There is no active playback session under recording id {recordingId}.");
}
RecordEntry noBodyEntry = RecordingHandler.CreateNoBodyRecordEntry(incomingRequest);
session.AuditLog.Enqueue(new AuditLogItem(recordingId, noBodyEntry.RequestUri, noBodyEntry.RequestMethod.ToString()));
var sanitizers = await SanitizerRegistry.GetSanitizers(session);
if (!session.IsSanitized)
{
// we don't need to re-sanitize with recording-applicable sanitizers every time. just the very first one
await session.Session.EntryLock.WaitAsync();
try
{
if (!session.IsSanitized)
{
session.AuditLog.Enqueue(new AuditLogItem(recordingId, $"In 'one-time' sanitization for {recordingId}. I am applying {sanitizers.Count} sanitizers."));
await session.Session.Sanitize(sanitizers, false);
session.IsSanitized = true;
}
}
finally
{
session.Session.EntryLock.Release();
session.AuditLog.Enqueue(new AuditLogItem(recordingId, $"Finished 'one-time' sanitization for {recordingId}. I applied {sanitizers.Count} sanitizers."));
}
}
DebugLogger.LogRequestDetails(incomingRequest, sanitizers);
var entry = (await CreateEntryAsync(incomingRequest).ConfigureAwait(false)).Item1;
await session.Session.EntryLock.WaitAsync();
try
{
// Session may be removed later, but only after response has been fully written
var match = session.Session.Lookup(entry, session.CustomMatcher ?? Matcher, sanitizers, remove: false, sessionId: recordingId);
foreach (ResponseTransform transform in Transforms.Concat(session.AdditionalTransforms))
{
transform.Transform(incomingRequest, match);
}
outgoingResponse.StatusCode = match.StatusCode;
foreach (var header in match.Response.Headers)
{
outgoingResponse.Headers.Append(header.Key, header.Value.ToArray());
}
outgoingResponse.Headers.Remove("Transfer-Encoding");
if (match.Response.Body?.Length > 0)
{
var bodyData = CompressionUtilities.CompressBody(match.Response.Body, match.Response.Headers);
if (match.Response.Headers.ContainsKey("Content-Length"))
{
outgoingResponse.ContentLength = bodyData.Length;
}
session.AuditLog.Enqueue(new AuditLogItem(recordingId, $"Beginning body write for {recordingId}."));
await WriteBodyBytes(bodyData, session.PlaybackResponseTime, outgoingResponse);
}
Interlocked.Increment(ref Startup.RequestsPlayedBack);
// Only remove session once body has been written, to minimize probability client retries but test-proxy has already removed the session
var remove = true;
// If request contains "x-recording-remove: false", then request is not removed from session after playback.
// Used by perf tests to play back the same request multiple times.
if (incomingRequest.Headers.TryGetValue("x-recording-remove", out var removeHeader))
{
remove = bool.Parse(removeHeader);
}
if (remove)
{
session.AuditLog.Enqueue(new AuditLogItem(recordingId, $"Now popping entry {entry.RequestMethod} {match.RequestUri} from entries for {recordingId}."));
await session.Session.Remove(match, shouldLock: false);
}
}
finally
{
session.Session.EntryLock.Release();
}
}
public byte[][] GetBatches(byte[] bodyData, int batchCount)
{
if (bodyData.Length == 0 || bodyData.Length < batchCount)
{
var result = new byte[1][];
result[0] = bodyData;
return result;
}
int chunkLength = bodyData.Length / batchCount;
int remainder = (bodyData.Length % batchCount);
var batches = new byte[batchCount + (remainder > 0 ? 1 : 0)][];
for(int i = 0; i < batches.Length; i++)
{
var calculatedChunkLength = ((i == batches.Length - 1) && (batches.Length > 1) && (remainder > 0)) ? remainder : chunkLength;
var batch = new byte[calculatedChunkLength];
Array.Copy(bodyData, i * chunkLength, batch, 0, calculatedChunkLength);
batches[i] = batch;
}
return batches;
}
public async Task WriteBodyBytes(byte[] bodyData, int playbackResponseTime, HttpResponse outgoingResponse)
{
if (playbackResponseTime > 0)
{
int batchCount = 10;
int sleepLength = playbackResponseTime / batchCount;
byte[][] chunks = GetBatches(bodyData, batchCount);
for(int i = 0; i < chunks.Length; i++)
{
var chunk = chunks[i];
await outgoingResponse.Body.WriteAsync(chunk).ConfigureAwait(false);
if (i != chunks.Length - 1)
{
await Task.Delay(sleepLength);
}
}
}
else
{
await outgoingResponse.Body.WriteAsync(bodyData).ConfigureAwait(false);
}
}
public static async Task<(RecordEntry, byte[])> CreateEntryAsync(HttpRequest request)
{
var entry = CreateNoBodyRecordEntry(request);
byte[] bytes = await ReadAllBytes(request.Body).ConfigureAwait(false);
entry.Request.Body = CompressionUtilities.DecompressBody(bytes, request.Headers);
return (entry, bytes);
}
public static RecordEntry CreateNoBodyRecordEntry(HttpRequest request)
{
var entry = new RecordEntry();
entry.RequestUri = GetRequestUri(request).AbsoluteUri;
entry.RequestMethod = new RequestMethod(request.Method);
foreach (var header in request.Headers)
{
if (IncludeHeader(header.Key))
{
entry.Request.Headers.Add(header.Key, header.Value.ToArray());
}
}
return entry;
}
#endregion
#region SetRecordingOptions and store functionality
public static string GetAssetsJsonLocation(string pathToAssetsJson, string contextDirectory)
{
if (pathToAssetsJson == null)
{
return null;
}
var path = pathToAssetsJson;
if (!Path.IsPathFullyQualified(pathToAssetsJson))
{
path = Path.Join(contextDirectory, pathToAssetsJson);
}
return path.Replace("\\", "/");
}
public async Task Restore(string pathToAssetsJson)
{
var resultingPath = await Store.Restore(pathToAssetsJson);
ContextDirectory = resultingPath;
}
public void SetRecordingOptions(IDictionary<string, object> options = null, string sessionId = null)
{
if (options != null)
{
if (options.Keys.Count == 0)
{
throw new HttpException(HttpStatusCode.BadRequest, "At least one key is expected in the body being passed to SetRecordingOptions.");
}
if (options.TryGetValue("HandleRedirects", out var handleRedirectsObj))
{
var handleRedirectsString = $"{handleRedirectsObj}";
if (bool.TryParse(handleRedirectsString, out var handleRedirectsBool))
{
HandleRedirects = handleRedirectsBool;
}
else if (handleRedirectsString.Equals("0", StringComparison.OrdinalIgnoreCase))
{
HandleRedirects = false;
}
else if (handleRedirectsString.Equals("1", StringComparison.OrdinalIgnoreCase))
{
HandleRedirects = true;
}
else
{
throw new HttpException(HttpStatusCode.BadRequest, $"The value of key \"HandleRedirects\" MUST be castable to a valid boolean value. Unparsable Value: \"{handleRedirectsString}\".");
}
}
if (options.TryGetValue("ContextDirectory", out var sourceDirectoryObj))
{
var newSourceDirectory = sourceDirectoryObj.ToString();
if (!string.IsNullOrWhiteSpace(newSourceDirectory))
{
SetRecordingDirectory(newSourceDirectory);
}
else
{
throw new HttpException(HttpStatusCode.BadRequest, "Users must provide a valid value to the key \"ContextDirectory\" in the recording options dictionary.");
}
}
if (options.TryGetValue("AssetsStore", out var assetsStoreObj))
{
var newAssetsStoreIdentifier = assetsStoreObj.ToString();
if (!string.IsNullOrWhiteSpace(newAssetsStoreIdentifier))
{
SetAssetsStore(newAssetsStoreIdentifier);
}
else
{
throw new HttpException(HttpStatusCode.BadRequest, "Users must provide a valid value when providing the key \"AssetsStore\" in the recording options dictionary.");
}
}
if (options.TryGetValue("Transport", out var transportConventions))
{
if (transportConventions != null)
{
try
{
string transportObject;
if (transportConventions is JsonElement je)
{
transportObject = je.ToString();
}
else
{
throw new Exception("'Transport' object was not a JsonElement");
}
var serializerOptions = new JsonSerializerOptions
{
ReadCommentHandling = JsonCommentHandling.Skip,
AllowTrailingCommas = true,
};
var customizations = JsonSerializer.Deserialize<TransportCustomizations>(transportObject, serializerOptions);
SetTransportOptions(customizations, sessionId);
}
catch (HttpException)
{
throw;
}
catch (Exception e)
{
throw new HttpException(HttpStatusCode.BadRequest, $"Unable to deserialize the contents of the \"Transport\" key. Visible object: {transportConventions}. Json Deserialization Error: {e.Message}");
}
}
else
{
throw new HttpException(HttpStatusCode.BadRequest, "Users must provide a valid value when providing the key \"Transport\" in the recording options dictionary.");
}
}
}
else
{
throw new HttpException(HttpStatusCode.BadRequest, "When setting recording options, the request body is expected to be non-null and of type Dictionary<string, string>.");
}
}
public X509Certificate2 GetValidationCert(TransportCustomizations settings)
{
try
{
var span = new ReadOnlySpan<char>(settings.TLSValidationCert.ToCharArray());
return PemReader.LoadCertificate(span, null, PemReader.KeyType.Auto, true);
}
catch (Exception e)
{
throw new HttpException(HttpStatusCode.BadRequest, $"Unable to instantiate a valid cert from the value provided in Transport settings key \"TLSValidationCert\". Value: \"{settings.TLSValidationCert}\". Message: \"{e.Message}\".");
}
}
public HttpClientHandler GetTransport(bool allowAutoRedirect, TransportCustomizations customizations, bool insecure = false)
{
var clientHandler = new HttpClientHandler()
{
AllowAutoRedirect = allowAutoRedirect
};
if (customizations.Certificates != null)
{
foreach (var certPair in customizations.Certificates)
{
try
{
var cert = X509Certificate2.CreateFromPem(certPair.PemValue, certPair.PemKey);
cert = new X509Certificate2(cert.Export(X509ContentType.Pfx));
clientHandler.ClientCertificates.Add(cert);
}
catch (Exception e)
{
throw new HttpException(HttpStatusCode.BadRequest, $"Unable to instantiate a new X509 certificate from the provided value and key. Failure Message: \"{e.Message}\".");
}
}
}
if (customizations.TLSValidationCert != null && !insecure)
{
var ledgerCert = GetValidationCert(customizations);
X509Chain certificateChain = new();
certificateChain.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck;
certificateChain.ChainPolicy.RevocationFlag = X509RevocationFlag.ExcludeRoot;
certificateChain.ChainPolicy.VerificationFlags = X509VerificationFlags.AllowUnknownCertificateAuthority;
certificateChain.ChainPolicy.VerificationTime = DateTime.Now;
certificateChain.ChainPolicy.UrlRetrievalTimeout = new TimeSpan(0, 0, 0);
certificateChain.ChainPolicy.ExtraStore.Add(ledgerCert);
clientHandler.ServerCertificateCustomValidationCallback = (HttpRequestMessage httpRequestMessage, X509Certificate2 cert, X509Chain x509Chain, SslPolicyErrors sslPolicyErrors) =>
{
if (!string.IsNullOrWhiteSpace(customizations.TSLValidationCertHost) && httpRequestMessage.RequestUri.Host != customizations.TSLValidationCertHost)
{
if (sslPolicyErrors == SslPolicyErrors.None)
{
return true;
}
return false;
}
else
{
bool isChainValid = certificateChain.Build(cert);
if (!isChainValid) return false;
var isCertSignedByTheTlsCert = certificateChain.ChainElements.Cast<X509ChainElement>()
.Any(x => x.Certificate.Thumbprint == ledgerCert.Thumbprint);
return isCertSignedByTheTlsCert;
}
};
}
else if (insecure)
{
clientHandler.ServerCertificateCustomValidationCallback = (_, _, _, _) => true;
}
return clientHandler;
}
public void SetTransportOptions(TransportCustomizations customizations, string sessionId)
{
var timeoutSpan = TimeSpan.FromSeconds(600);
// this will look a bit strange until we take care of #3488 due to the fact that this AllowAutoRedirect customizable from two places
if (!string.IsNullOrWhiteSpace(sessionId))
{
var customizedClientHandler = GetTransport(customizations.AllowAutoRedirect, customizations);
if (RecordingSessions.TryGetValue(sessionId, out var recordingSession))
{
recordingSession.Client = new HttpClient(customizedClientHandler)
{
Timeout = timeoutSpan
};
}
if (customizations.PlaybackResponseTime > 0)
{
if (PlaybackSessions.TryGetValue(sessionId, out var playbackSession))
{
playbackSession.PlaybackResponseTime = customizations.PlaybackResponseTime;
}
else
{
throw new HttpException(HttpStatusCode.BadRequest, $"Unable to set a transport customization on a recording session that is not active. Id: \"{sessionId}\"");
}
}
}
else
{
// after #3488 we will swap to a single client instead of both of these
var redirectableCustomizedHandler = GetTransport(true, customizations, Startup.Insecure);
var redirectlessCustomizedHandler = GetTransport(false, customizations, Startup.Insecure);
RedirectableClient = new HttpClient(redirectableCustomizedHandler)
{
Timeout = timeoutSpan
};
RedirectlessClient = new HttpClient(redirectlessCustomizedHandler)
{
Timeout = timeoutSpan
};
}
}
public void SetAssetsStore(string assetsStoreId)
{
Store = Resolver.ResolveStore(assetsStoreId);
}
public void SetRecordingDirectory(string targetDirectory)
{
try
{
// Given that it is perfectly valid to pass a directory that does not yet exist, we cannot get the file attributes to "properly"
// determine if an incoming path is a valid one via <attr>.HasFlag(FileAttributes.Directory). We can shorthand this by checking
// for a file extension.
if (Path.GetExtension(targetDirectory) != String.Empty)
{
targetDirectory = Path.GetDirectoryName(targetDirectory);
}
if (!String.IsNullOrEmpty(targetDirectory))
{
Directory.CreateDirectory(targetDirectory);
}
ContextDirectory = targetDirectory;
}
catch (Exception ex)
{
throw new HttpException(HttpStatusCode.BadRequest, $"Unable set proxy context to target directory \"{targetDirectory}\". Unhandled exception was: \"{ex.Message}\".");
}
}
#endregion
#region utility and common-use functions
public ModifiableRecordSession GetActiveSession(string recordingId)
{
if (PlaybackSessions.TryGetValue(recordingId, out var playbackSession))
{
return playbackSession;
}
if (RecordingSessions.TryGetValue(recordingId, out var recordingSession))
{
return recordingSession;
}
if (InMemorySessions.TryGetValue(recordingId, out var inMemSession))
{
return inMemSession;
}
throw new HttpException(HttpStatusCode.BadRequest, $"{recordingId} is not an active session for either record or playback. Check the value being passed and try again.");
}
public async Task<string> UnregisterSanitizer(string sanitizerId, string recordingId = null)
{
if (!string.IsNullOrWhiteSpace(recordingId))
{
var session = GetActiveSession(recordingId);
session.AuditLog.Enqueue(new AuditLogItem(recordingId, $"Starting unregister of {sanitizerId}."));
return await SanitizerRegistry.Unregister(sanitizerId, session);
}
return await SanitizerRegistry.Unregister(sanitizerId);
}
public async Task<List<string>> RegisterSanitizers(List<RecordedTestSanitizer> sanitizers, string recordingId = null)
{
var registrations = new List<string>();
if (!string.IsNullOrWhiteSpace(recordingId))
{
var session = GetActiveSession(recordingId);
await session.Session.EntryLock.WaitAsync();
try
{
foreach (var sanitizer in sanitizers)
{
session.AuditLog.Enqueue(new AuditLogItem(recordingId, $"Starting registration of sanitizer {sanitizer.GetType()}"));
registrations.Add(await SanitizerRegistry.Register(session, sanitizer, shouldLock: false));
}
}
finally
{
session.Session.EntryLock.Release();
}
}
else
{
await SanitizerRegistry.SessionSanitizerLock.WaitAsync();
try {
foreach (var sanitizer in sanitizers)
{
registrations.Add(await SanitizerRegistry.Register(sanitizer, shouldLock: false));
}
}
finally
{
SanitizerRegistry.SessionSanitizerLock.Release();
}
}
return registrations;
}
public async Task<string> RegisterSanitizer(RecordedTestSanitizer sanitizer, string recordingId = null)
{
if (!string.IsNullOrWhiteSpace(recordingId))
{
var session = GetActiveSession(recordingId);
session.AuditLog.Enqueue(new AuditLogItem(recordingId, $"Starting registration of sanitizer {sanitizer.GetType()}"));
return await SanitizerRegistry.Register(session, sanitizer);
}
return await SanitizerRegistry.Register(sanitizer);
}
public void AddTransformToRecording(string recordingId, ResponseTransform transform)
{
if (!PlaybackSessions.TryGetValue(recordingId, out var session))
{
throw new HttpException(HttpStatusCode.BadRequest, $"{recordingId} is not an active playback session. Check the value being passed and try again.");
}
session.AdditionalTransforms.Add(transform);
}
public void SetMatcherForRecording(string recordingId, RecordMatcher matcher)
{
if (!PlaybackSessions.TryGetValue(recordingId, out var session))
{
throw new HttpException(HttpStatusCode.BadRequest, $"{recordingId} is not an active playback session. Check the value being passed and try again.");
}
session.CustomMatcher = matcher;
}
public async Task SetDefaultExtensions(string recordingId = null)
{
if (recordingId != null)
{
if (PlaybackSessions.TryGetValue(recordingId, out var playbackSession))
{
playbackSession.ResetExtensions(SanitizerRegistry);
}
if (RecordingSessions.TryGetValue(recordingId, out var recordSession))
{
recordSession.ResetExtensions(SanitizerRegistry);
}
if (InMemorySessions.TryGetValue(recordingId, out var inMemSession))
{
inMemSession.ResetExtensions(SanitizerRegistry);
}
}
else
{
var countPlayback = PlaybackSessions.Count;
var countInMem = InMemorySessions.Count;
var countRecording = RecordingSessions.Count;
var countTotal = countPlayback + countInMem + countRecording;
if (countTotal > 0)
{
StringBuilder sb = new StringBuilder();
sb.Append($"There are a total of {countTotal} active sessions. Remove these sessions before hitting Admin/Reset." + Environment.NewLine);
if (countPlayback > 0)
{
sb.Append("Active Playback Sessions: [");
lock (PlaybackSessions)
{
sb.Append(string.Join(", ", PlaybackSessions.Keys.ToArray()));
}
sb.Append("]. ");
}
if (countInMem > 0)
{
sb.Append("Active InMem Sessions: [");
lock (InMemorySessions)
{
sb.Append(string.Join(", ", InMemorySessions.Keys.ToArray()));
}
sb.Append("]. ");
}
if (countRecording > 0)
{
sb.Append($"{countRecording} Active Recording Sessions: [");
lock (RecordingSessions)
{
sb.Append(string.Join(", ", RecordingSessions.Keys.ToArray()));
}
sb.Append("]. ");
}
throw new HttpException(HttpStatusCode.BadRequest, sb.ToString());
}
await SanitizerRegistry.ResetSessionSanitizers();
Transforms = new List<ResponseTransform>
{
new StorageRequestIdTransform(),
new ClientIdTransform(),
new HeaderTransform("Retry-After", "0")
{
Condition = new ApplyCondition
{
ResponseHeader = new HeaderCondition
{
Key = "Retry-After"
}
}
}
};
Matcher = new RecordMatcher();
RedirectableClient = BaseRedirectableClient;
RedirectlessClient = BaseRedirectlessClient;
}
}
public async Task<string> GetRecordingPath(string file, string assetsPath = null)
{
var normalizedFileName = file.Replace('\\', '/');
if (String.IsNullOrWhiteSpace(file))
{
throw new HttpException(HttpStatusCode.BadRequest, $"Recording file value of {file} is invalid. Try again with a populated filename.");
}
var path = file;
// if an assets.json is provided, we have a bit of work to do here.
if (!string.IsNullOrWhiteSpace(assetsPath))
{
var contextDirectory = await Store.GetPath(assetsPath);
if (Path.IsPathFullyQualified(file)) {
throw new HttpException(
HttpStatusCode.BadRequest,
$"The path provided in the recording file value {file} is fully qualified. This is not allowed when an assets.json is provided."
);
}
path = Path.Join(contextDirectory, file);
}
// otherwise, it's a basic restore like we're used to
else
{
if (!Path.IsPathFullyQualified(file))
{
path = Path.Join(ContextDirectory, file);
}
}
return (path + (!path.EndsWith(".json") ? ".json" : String.Empty));
}
public static string GetHeader(HttpRequest request, string name, bool allowNulls = false)
{
if (!request.Headers.TryGetValue(name, out var value))
{
if (allowNulls)
{
return null;
}
throw new HttpException(HttpStatusCode.BadRequest, $"Expected header {name} is not populated in request.");
}
return value;
}
public static Uri GetRequestUri(HttpRequest request)
{
// Instead of obtaining the Path of the request from request.Path, we use this
// more complicated method obtaining the raw string from the httpcontext. Unfortunately,
// The native request functions implicitly decode the Path value. EG: "aa%27bb" is decoded into 'aa'bb'.
// Using the RawTarget PREVENTS this automatic decode. We still lean on the URI constructors
// to give us some amount of safety, but note that we explicitly disable escaping in that combination.
var rawTarget = request.HttpContext.Features.Get<IHttpRequestFeature>().RawTarget;
var hostValue = GetHeader(request, "x-recording-upstream-base-uri");
// it is easy to forget the x-recording-upstream-base-uri value
if (string.IsNullOrWhiteSpace(hostValue))
{
throw new HttpException(HttpStatusCode.BadRequest, $"The value present in header 'x-recording-upstream-base-uri' is not a valid hostname: {hostValue}.");
}
// The host value from the header should include scheme and port. EG:
// https://portal.azure.com/
// http://localhost:8080
// http://user:pass@localhost:8080/ <-- this should be EXTREMELY rare given it's extremely insecure
//
// The value from rawTarget is the _exact_ "rest of the URI" WITHOUT auto-decoding (as specified above) and could look like:
// ///request
// /hello/world?query=blah
// ""
// //hello/world
//
// We cannot use a URIBuilder to combine the hostValue and the rawTarget, as doing so results in auto-decoding of escaped
// characters that will BREAK the request that we actually wish to make.
//
// Given these limitations, and safe in the knowledge of both sides of this operation. We trim the trailing / off of the host,
// and string concatenate them together.
var rawUri = hostValue.TrimEnd('/') + rawTarget;
return new Uri(rawUri);
}
private static bool IncludeHeader(string header)
{
return !header.Equals("Host", StringComparison.OrdinalIgnoreCase)
&& !header.StartsWith("x-recording-", StringComparison.OrdinalIgnoreCase);
}
private static async Task<byte[]> ReadAllBytes(Stream stream)
{
using var memory = new MemoryStream();
using (stream)
{
await stream.CopyToAsync(memory).ConfigureAwait(false);
}
return memory.Length == 0 ? null : memory.ToArray();
}
#endregion
}
}