tools/test-proxy/Azure.Sdk.Tools.TestProxy/Startup.cs (298 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using System;
using System.Threading;
using System.Threading.Tasks;
using System.IO;
using System.Text.RegularExpressions;
using Azure.Sdk.Tools.TestProxy.Common;
using Microsoft.Extensions.Logging;
using System.Reflection;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Azure.Sdk.Tools.TestProxy.Store;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.CommandLine;
using Azure.Sdk.Tools.TestProxy.CommandOptions;
using System.Text.Json;
using Microsoft.Extensions.Logging.Console;
using Microsoft.Extensions.Options;
using Azure.Sdk.Tools.TestProxy.Common.AutoShutdown;
namespace Azure.Sdk.Tools.TestProxy
{
[ExcludeFromCodeCoverage]
public sealed class Startup
{
    /// <summary>Number of requests recorded so far; reported by the periodic status line.</summary>
    internal static int RequestsRecorded;

    /// <summary>Number of requests played back so far; reported by the periodic status line.</summary>
    internal static int RequestsPlayedBack;

    private static bool _insecure;

    /// <summary>The <c>Insecure</c> flag of the <see cref="StartOptions"/> the server was started with.</summary>
    internal static bool Insecure => _insecure;

    /// <summary>
    /// Activated by the host through <c>UseStartup&lt;Startup&gt;()</c>; the configuration argument is currently unused.
    /// </summary>
    public Startup(IConfiguration configuration) { }

    /// <summary>Root directory that assets.json paths and recordings are resolved against.</summary>
    public static string TargetLocation;

    /// <summary>Resolver used to look up asset store implementations by name.</summary>
    public static StoreResolver Resolver;

    /// <summary>Store resolved from <c>DefaultOptions.StoragePlugin</c>, falling back to "GitStore".</summary>
    public static IAssetsStore DefaultStore;

    /// <summary>Raw command-line arguments handed to <see cref="Main"/>; echoed in argument errors.</summary>
    public static string[] storedArgs;

    /// <summary>
    /// Picks the storage root: an explicit <paramref name="storageLocation"/> wins, then the
    /// TEST_PROXY_FOLDER environment variable, then the current working directory.
    /// </summary>
    private static string resolveRepoLocation(string storageLocation = null)
    {
        var envValue = Environment.GetEnvironmentVariable("TEST_PROXY_FOLDER");
        return storageLocation ?? envValue ?? Directory.GetCurrentDirectory();
    }

    /// <summary>
    /// test-proxy
    /// </summary>
    /// <param name="args">CommandLineParser arguments. In server mode use double dash '--' and everything after that becomes additional arguments to Host.CreateDefaultBuilder. Ex. -- arg1 value1 arg2 value2 </param>
    public static async Task Main(string[] args = null)
    {
        // Treat a missing argument array as "no arguments" so neither the parser nor the
        // error message in Run ever has to deal with null.
        args ??= Array.Empty<string>();
        storedArgs = args;

        var rootCommand = OptionsGenerator.GenerateCommandLineOptions(Run);
        var resultCode = await rootCommand.InvokeAsync(args);
        Environment.Exit(resultCode);
    }

    /// <summary>
    /// Command handler invoked by the generated root command: resolves the storage location and
    /// asset store, then dispatches on the concrete options type (config/push/reset/restore/start).
    /// </summary>
    /// <param name="commandObj">The parsed options object; expected to be a <see cref="DefaultOptions"/>.</param>
    /// <returns>Always 0; failures surface as exceptions.</returns>
    /// <exception cref="ArgumentException">The parsed object is not a <see cref="DefaultOptions"/>.</exception>
    /// <exception cref="NotImplementedException">The "config create" verb was requested.</exception>
    private static async Task<int> Run(object commandObj)
    {
        var semanticVersion = System.Reflection.Assembly.GetExecutingAssembly()
            .GetCustomAttribute<AssemblyInformationalVersionAttribute>()?.InformationalVersion;
        System.Console.WriteLine($"Running proxy version is Azure.Sdk.Tools.TestProxy {semanticVersion}");

        new GitProcessHandler().VerifyGitMinVersion();

        // Validate up front so an unexpected object yields the descriptive ArgumentException rather
        // than an InvalidCastException/NullReferenceException from an unconditional cast.
        if (!(commandObj is DefaultOptions defaultOptions))
        {
            throw UnrecognizedArguments();
        }

        TargetLocation = resolveRepoLocation(defaultOptions.StorageLocation);
        Resolver = new StoreResolver();
        DefaultStore = Resolver.ResolveStore(defaultOptions.StoragePlugin ?? "GitStore");

        string assetsJson;
        switch (commandObj)
        {
            case ConfigLocateOptions configOptions:
                DefaultStore.SetStoreExceptionMode(false);
                assetsJson = RecordingHandler.GetAssetsJsonLocation(configOptions.AssetsJsonPath, TargetLocation);
                System.Console.WriteLine(await DefaultStore.GetPath(assetsJson));
                break;
            case ConfigShowOptions configOptions:
                DefaultStore.SetStoreExceptionMode(false);
                assetsJson = RecordingHandler.GetAssetsJsonLocation(configOptions.AssetsJsonPath, TargetLocation);
                using (var f = File.OpenRead(assetsJson))
                {
                    using var json = JsonDocument.Parse(f);
                    System.Console.WriteLine(JsonSerializer.Serialize(json, new JsonSerializerOptions { WriteIndented = true }));
                }
                break;
            case ConfigCreateOptions configOptions:
                DefaultStore.SetStoreExceptionMode(false);
                assetsJson = RecordingHandler.GetAssetsJsonLocation(configOptions.AssetsJsonPath, TargetLocation);
                throw new NotImplementedException("Interactive creation of assets.json feature is not yet implemented.");
            case ConfigOptions configOptions:
                System.Console.WriteLine("Config verb requires a subcommand after the \"config\" verb.\n\nCorrect Usage: \"Azure.Sdk.Tools.TestProxy config locate|show|create -a path/to/assets.json\"");
                break;
            case StartOptions startOptions:
                StartServer(startOptions);
                break;
            case PushOptions pushOptions:
                DefaultStore.SetStoreExceptionMode(false);
                assetsJson = RecordingHandler.GetAssetsJsonLocation(pushOptions.AssetsJsonPath, TargetLocation);
                await DefaultStore.Push(assetsJson);
                break;
            case ResetOptions resetOptions:
                DefaultStore.SetStoreExceptionMode(false);
                assetsJson = RecordingHandler.GetAssetsJsonLocation(resetOptions.AssetsJsonPath, TargetLocation);
                await DefaultStore.Reset(assetsJson);
                break;
            case RestoreOptions restoreOptions:
                DefaultStore.SetStoreExceptionMode(false);
                assetsJson = RecordingHandler.GetAssetsJsonLocation(restoreOptions.AssetsJsonPath, TargetLocation);
                await DefaultStore.Restore(assetsJson);
                break;
            case DefaultOptions defaultOpts:
                // Any DefaultOptions that is not one of the verbs above: start the server with default settings.
                StartServer(new StartOptions()
                {
                    AdditionalArgs = Array.Empty<string>(),
                    StorageLocation = defaultOpts.StorageLocation,
                    StoragePlugin = defaultOpts.StoragePlugin,
                    Insecure = false,
                    AutoShutdownTime = -1,
                    Dump = false
                });
                break;
            default:
                throw UnrecognizedArguments();
        }

        return 0;
    }

    // Builds the error reported when the parsed command object cannot be handled; tolerates a null storedArgs.
    private static ArgumentException UnrecognizedArguments() =>
        new ArgumentException($"Unable to parse the argument set: {string.Join(" ", storedArgs ?? Array.Empty<string>())}");

    /// <summary>
    /// Starts the proxy web host and blocks until it shuts down. A status line with the
    /// recorded/played-back counters is printed once per second while the host runs.
    /// </summary>
    private static void StartServer(StartOptions startOptions)
    {
        _insecure = startOptions.Insecure;

        // Disable the cache used by the static Regex methods.
        Regex.CacheSize = 0;

        using var statusThreadCts = new CancellationTokenSource();
        var statusThread = PrintStatus(
            // InvariantCulture: the ':' in a custom format string is otherwise replaced by the culture's time separator.
            () => $"[{DateTime.UtcNow.ToString("HH:mm:ss", System.Globalization.CultureInfo.InvariantCulture)}] Recorded: {RequestsRecorded}\tPlayed Back: {RequestsPlayedBack}",
            newLine: true, statusThreadCts.Token);

        try
        {
            var app = CreateHostBuilder(startOptions).Build();
            ConfigureAutoShutdown(app, startOptions);
            if (startOptions.Dump)
            {
                DumpConfiguration(app);
            }

            app.Run();
        }
        finally
        {
            // The status thread is a foreground thread; stop it even when building or running the host
            // fails, otherwise it would keep printing (and keep the process alive) after the failure.
            statusThreadCts.Cancel();
            statusThread.Join();
        }
    }

    // Creates the generic host: web defaults with this class as the startup type, custom logging providers, and Kestrel limits.
    private static IHostBuilder CreateHostBuilder(StartOptions startOptions)
    {
        var hostArgs = (startOptions.AdditionalArgs ?? Array.Empty<string>()).ToArray();

        return Host.CreateDefaultBuilder(hostArgs).ConfigureWebHostDefaults(
            builder =>
                builder.UseStartup<Startup>()
                // ripped directly from implementation of ConfigureWebDefaults@https://github.dev/dotnet/aspnetcore/blob/a779227cc2694a50b074a097889ed9e80d15cd77/src/DefaultBuilder/src/WebHost.cs#L176
                .ConfigureLogging((hostBuilder, loggingBuilder) =>
                {
                    loggingBuilder.ClearProviders();
                    loggingBuilder.AddConfiguration(hostBuilder.Configuration.GetSection("Logging"));
                    if (!startOptions.UniversalOutput)
                    {
                        // Route Error and above to stderr unless universal (single-stream) output was requested.
                        loggingBuilder.AddConsole(options =>
                        {
                            options.LogToStandardErrorThreshold = LogLevel.Error;
                        });
                    }
                    loggingBuilder.AddSimpleConsole(options =>
                    {
                        options.TimestampFormat = "[HH:mm:ss] ";
                    });
                    loggingBuilder.AddDebug();
                    loggingBuilder.AddEventSourceLogger();
                })
                .ConfigureKestrel(kestrelServerOptions =>
                {
                    kestrelServerOptions.ConfigureEndpointDefaults(lo => lo.Protocols = HttpProtocols.Http1);
                    // default minimum rate is 240 bytes per second with 5 second grace period. Bumping to 50bps with a graceperiod of 20 seconds.
                    kestrelServerOptions.Limits.MinRequestBodyDataRate = new MinDataRate(bytesPerSecond: 50, gracePeriod: TimeSpan.FromSeconds(20));
                }));
    }

    // When AutoShutdownTime is above -1, enables auto-shutdown with that timeout (seconds) and starts the first countdown.
    private static void ConfigureAutoShutdown(IHost app, StartOptions startOptions)
    {
        var shutdownService = app.Services.GetRequiredService<ShutdownConfiguration>();
        if (startOptions.AutoShutdownTime > -1)
        {
            shutdownService.EnableAutoShutdown = true;
            shutdownService.TimeoutInSeconds = startOptions.AutoShutdownTime;
            // start the first iteration of the shutdown timer
            app.Services.GetRequiredService<ShutdownTimer>().ResetTimer();
        }
    }

    // Prints every resolved configuration key/value pair to stdout.
    private static void DumpConfiguration(IHost app)
    {
        var config = app.Services?.GetService<IConfiguration>();
        System.Console.WriteLine("Dumping Resolved Configuration Values:");
        if (config == null)
        {
            return;
        }

        foreach (var c in config.AsEnumerable())
        {
            System.Console.WriteLine(c.Key + " = " + c.Value);
        }
    }

    /// <summary>
    /// Registers CORS ("DefaultPolicy": any header/method/origin, all headers exposed), MVC controllers
    /// (with <see cref="EmptyBodyFormatter"/>), views, Razor pages, and the singleton services.
    /// The <see cref="RecordingHandler"/> is built from <see cref="TargetLocation"/>, <see cref="DefaultStore"/>
    /// and <see cref="Resolver"/>, which <c>Run</c> assigns before starting the server.
    /// </summary>
    public void ConfigureServices(IServiceCollection services)
    {
        services.AddCors(options =>
        {
            options.AddPolicy(name: "DefaultPolicy",
                builder =>
                {
                    builder.AllowAnyHeader()
                        .AllowAnyMethod()
                        .AllowAnyOrigin()
                        .WithExposedHeaders("*");
                });
        });
        services.AddControllers(options =>
        {
            options.InputFormatters.Add(new EmptyBodyFormatter());
        });
        services.AddControllersWithViews();
        services.AddRazorPages();

        var singletonRecordingHandler = new RecordingHandler(
            TargetLocation,
            store: DefaultStore,
            storeResolver: Resolver
        );
        services.AddSingleton<RecordingHandler>(singletonRecordingHandler);
        services.AddSingleton<ShutdownConfiguration>();
        services.AddSingleton<ShutdownTimer>();
    }

    /// <summary>
    /// Builds the request pipeline: CORS, exception and shutdown-timer middleware, the
    /// record/playback branches, and finally attribute-routed controllers.
    /// </summary>
    public void Configure(IApplicationBuilder app, IWebHostEnvironment env, ILoggerFactory loggerFactory)
    {
        if (env.IsDevelopment())
        {
            app.UseDeveloperExceptionPage();
        }
        app.UseCors("DefaultPolicy");
        app.UseMiddleware<HttpExceptionMiddleware>();
        app.UseMiddleware<ShutdownTimerMiddleware>();
        DebugLogger.ConfigureLogger(loggerFactory);
        // Branches must be registered before the main routing so recording traffic is diverted first.
        MapRecording(app);
        app.UseRouting();
        app.UseEndpoints(endpoints => endpoints.MapControllers());
    }

    // Route requests with header x-recording-mode = X to X.HandleRequest
    // These are requests to be recorded or played back.
    private void MapRecording(IApplicationBuilder app)
    {
        foreach (var controller in new[] { "playback", "record" })
        {
            app.MapWhen(
                context =>
                    controller.Equals(
                        GetRecordingMode(context),
                        StringComparison.OrdinalIgnoreCase),
                branch =>
                {
                    // Every path on this branch falls through to <controller>.HandleRequest.
                    branch.UseRouting();
                    branch.UseEndpoints(
                        endpoints => endpoints.MapFallbackToController(
                            "{*path}", "HandleRequest", controller));
                });
        }
    }

    // Returns the x-recording-mode header value, or null when the header is absent or has more than one value.
    private static string GetRecordingMode(HttpContext context)
    {
        if (!context.Request.Headers.TryGetValue("x-recording-mode", out var values) || values.Count != 1)
        {
            return null;
        }
        return values[0];
    }

    // Runs on a dedicated thread instead of the ThreadPool so the status keeps printing even when the pool is starved.
    // The wait uses the token's WaitHandle (a kernel wait) rather than Task.Delay(...).Wait(): Task.Delay completes via a
    // timer callback that itself needs a ThreadPool thread, which would reintroduce the starvation dependency.
    private static Thread PrintStatus(Func<object> status, bool newLine, CancellationToken token, int intervalSeconds = 1)
    {
        var thread = new Thread(() =>
        {
            var interval = TimeSpan.FromSeconds(intervalSeconds);
            bool needsExtraNewline = false;
            while (!token.IsCancellationRequested)
            {
                // Returns early when the token is cancelled; the status is still printed one last time.
                token.WaitHandle.WaitOne(interval);

                var obj = status();
                if (newLine)
                {
                    System.Console.WriteLine(obj);
                }
                else
                {
                    System.Console.Write(obj);
                    needsExtraNewline = true;
                }
            }

            if (needsExtraNewline)
            {
                System.Console.WriteLine();
            }
            System.Console.WriteLine();
        });
        thread.Start();
        return thread;
    }
}
}