DeviceBridge/Startup.cs (130 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Reflection;
using System.Threading.Tasks;
using DeviceBridge.Common;
using DeviceBridge.Common.Authentication;
using DeviceBridge.Models;
using DeviceBridge.Providers;
using DeviceBridge.Services;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Mvc.Authorization;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using NLog;
using Polly;
using Polly.Extensions.Http;
using Swashbuckle.AspNetCore.SwaggerGen;
namespace DeviceBridge
{
/// <summary>
/// ASP.NET Core startup class: registers the bridge's services in the dependency-injection container
/// and builds the HTTP request pipeline.
/// </summary>
public class Startup
{
    private static readonly Logger _logger = LogManager.GetCurrentClassLogger();

    /// <summary>Initializes a new instance of the <see cref="Startup"/> class.</summary>
    /// <param name="configuration">The configuration.</param>
    public Startup(IConfiguration configuration)
    {
        Configuration = configuration;
    }

    /// <summary>Gets the configuration passed in by the host when constructing this class.</summary>
    public IConfiguration Configuration { get; }

    /// <summary>This method gets called by the runtime. Use this method to add services to the container.</summary>
    /// <remarks>
    /// Secrets (ID scope, SAS key, SQL connection string) are read through a <c>SecretsProvider</c> built from the
    /// <c>KV_URL</c> environment variable. <c>MAX_POOL_SIZE</c>, <c>DEVICE_CONNECTION_BATCH_SIZE</c> and
    /// <c>DEVICE_CONNECTION_BATCH_INTERVAL_MS</c>, when set and non-empty, override the corresponding defaults.
    /// </remarks>
    /// <param name="services">The services.</param>
    public void ConfigureServices(IServiceCollection services)
    {
        _logger.Info("Configuring services");
        string kvUrl = Environment.GetEnvironmentVariable("KV_URL");

        // Build cache from Key Vault. ConfigureServices is synchronous, so the async lookups are blocked on;
        // GetAwaiter().GetResult() rethrows the original exception rather than wrapping it in an AggregateException.
        var secretsService = new SecretsProvider(kvUrl);
        var idScope = secretsService.GetIdScopeAsync(_logger).GetAwaiter().GetResult();
        var sasKey = secretsService.GetIotcSasKeyAsync(_logger).GetAwaiter().GetResult();
        var sqlConnectionString = Utils.GetSqlConnectionString(_logger, secretsService);

        // Override defaults
        uint maxPoolSize = GetUIntFromEnvironment("MAX_POOL_SIZE", ConnectionManager.DeafultMaxPoolSize);
        uint rampupBatchSize = GetUIntFromEnvironment("DEVICE_CONNECTION_BATCH_SIZE", SubscriptionScheduler.DefaultConnectionBatchSize);
        uint rampupBatchIntervalMs = GetUIntFromEnvironment("DEVICE_CONNECTION_BATCH_INTERVAL_MS", SubscriptionScheduler.DefaultConnectionBatchIntervalMs);

        _logger.SetProperty("idScope", idScope);
        _logger.SetProperty("cv", Guid.NewGuid()); // CV for all background operations

        services.AddHttpContextAccessor();

        // Start services
        services.AddSingleton<ISecretsProvider>(secretsService);
        services.AddSingleton(_logger);
        services.AddSingleton<IEncryptionService, EncryptionService>();
        services.AddSingleton<IStorageProvider>(provider => new StorageProvider(sqlConnectionString, provider.GetRequiredService<IEncryptionService>()));
        services.AddSingleton<IConnectionManager>(provider => new ConnectionManager(provider.GetRequiredService<Logger>(), idScope, sasKey, maxPoolSize, provider.GetRequiredService<IStorageProvider>()));
        services.AddSingleton<ISubscriptionCallbackFactory, SubscriptionCallbackFactory>();
        services.AddSingleton<IConnectionStatusSubscriptionService, ConnectionStatusSubscriptionService>();
        services.AddSingleton<IDataSubscriptionService, DataSubscriptionService>();
        services.AddSingleton<ISubscriptionScheduler>(provider => new SubscriptionScheduler(provider.GetRequiredService<Logger>(), provider.GetRequiredService<IConnectionManager>(), provider.GetRequiredService<IStorageProvider>(), provider.GetRequiredService<ISubscriptionCallbackFactory>(), rampupBatchSize, rampupBatchIntervalMs));
        services.AddSingleton<IBridgeService, BridgeService>();
        services.AddHttpClient("RetryClient").AddPolicyHandler(GetRetryPolicy(_logger));
        services.AddHostedService<ExpiredConnectionCleanupHostedService>();
        services.AddHostedService<SubscriptionStartupHostedService>();
        services.AddHostedService<SubscriptionSchedulerHostedService>();
        services.AddHostedService<HubCacheGcHostedService>();

        services.AddAuthentication(o =>
        {
            o.DefaultScheme = SchemesNamesConst.TokenAuthenticationDefaultScheme;
        })
        .AddScheme<TokenAuthenticationOptions, TokenAuthenticationHandler>(SchemesNamesConst.TokenAuthenticationDefaultScheme, o => { });

        // Every controller action requires an authenticated caller unless it opts out.
        services.AddControllers(options =>
        {
            options.Filters.Add(new AuthorizeFilter());
        });

        services.AddHealthChecks();

        services.AddSwaggerGen(options =>
        {
            // Set XML comments.
            var xmlFile = $"{Assembly.GetExecutingAssembly().GetName().Name}.xml";
            var xmlPath = Path.Combine(AppContext.BaseDirectory, xmlFile);
            options.IncludeXmlComments(xmlPath);
            options.CustomOperationIds(apiDesc => apiDesc.TryGetMethodInfo(out MethodInfo methodInfo) ? methodInfo.Name : null);

            // Type mappers for custom serialization.
            options.MapType(typeof(DeviceSubscriptionType), () => DeviceSubscriptionType.Schema);
            options.MapType(typeof(DeviceTwin), () => DeviceTwin.Schema);
        });
    }

    /// <summary>This method gets called by the runtime. Use this method to configure the HTTP request pipeline.</summary>
    /// <param name="app">The application.</param>
    /// <param name="env">The env.</param>
    /// <param name="lifetime">The lifetime.</param>
    /// <param name="connectionManager">The connection manager.</param>
    /// <remarks>
    /// NOTE(review): <paramref name="lifetime"/> and <paramref name="connectionManager"/> are not used in the body.
    /// Because the runtime resolves them when it calls this method, they are constructed at startup. Confirm whether
    /// that eager construction is intentional before removing them.
    /// </remarks>
    public void Configure(IApplicationBuilder app, IWebHostEnvironment env, IHostApplicationLifetime lifetime, IConnectionManager connectionManager)
    {
        if (env.IsDevelopment())
        {
            app.UseDeveloperExceptionPage();
        }

        app.UseSwagger(c =>
        {
            c.SerializeAsV2 = true;
        });

        app.UseMiddleware<RequestLoggingMiddleware>();
        app.UseMiddleware<ExceptionHandlingMiddleware>();
        app.UseRouting();
        app.UseAuthentication();
        app.UseAuthorization();
        app.UseEndpoints(endpoints =>
        {
            endpoints.MapControllers();
            endpoints.MapHealthChecks("/health");
        });
    }

    /// <summary>Reads an unsigned base-10 integer from an environment variable.</summary>
    /// <param name="variableName">Name of the environment variable.</param>
    /// <param name="defaultValue">Value returned when the variable is unset or empty.</param>
    /// <returns>The parsed value, or <paramref name="defaultValue"/>.</returns>
    /// <exception cref="FormatException">The variable is set but is not a valid base-10 number.</exception>
    /// <exception cref="OverflowException">The value does not fit in a <see cref="uint"/>.</exception>
    private static uint GetUIntFromEnvironment(string variableName, uint defaultValue)
    {
        string rawValue = Environment.GetEnvironmentVariable(variableName);
        return string.IsNullOrEmpty(rawValue) ? defaultValue : Convert.ToUInt32(rawValue, 10);
    }

    /// <summary>
    /// <para>Gets the retry policy, used in HttpClient.</para>
    /// </summary>
    /// <remarks>
    /// Retries transient failures (HttpRequestException, 5XX, 408) and 429 responses up to <c>HTTP_RETRY_LIMIT</c> times.
    /// </remarks>
    /// <param name="logger">Logger used to record each retry.</param>
    /// <returns>An <see cref="IAsyncPolicy{TResult}"/> over <see cref="HttpResponseMessage"/>.</returns>
    private static IAsyncPolicy<HttpResponseMessage> GetRetryPolicy(Logger logger)
    {
        // Convert.ToInt32 returns 0 for a null string, so an unset HTTP_RETRY_LIMIT disables retries.
        int retryLimit = Convert.ToInt32(Environment.GetEnvironmentVariable("HTTP_RETRY_LIMIT"));

        // Handles HttpRequestException, 5XX, 408 and 429 status codes.
        return HttpPolicyExtensions
            .HandleTransientHttpError()
            .OrResult(msg => msg.StatusCode == (HttpStatusCode)429)
            .WaitAndRetryAsync(
                retryCount: retryLimit,
                sleepDurationProvider: (retryAttempt, outcome, context) => GetRetryDelay(logger, retryAttempt, outcome),
                onRetryAsync: (outcome, delay, retryAttempt, context) => Task.CompletedTask);
    }

    /// <summary>Logs a pending retry and computes how long to wait before it.</summary>
    /// <param name="logger">Logger used to record the retry.</param>
    /// <param name="retryAttempt">Retry attempt number supplied by Polly; used as the exponential backoff exponent.</param>
    /// <param name="outcome">The failed outcome: either a response or an exception.</param>
    /// <returns>
    /// The server's Retry-After interval (delta-seconds or HTTP-date, never negative) when present and valid,
    /// otherwise 2^<paramref name="retryAttempt"/> seconds.
    /// </returns>
    private static TimeSpan GetRetryDelay(Logger logger, int retryAttempt, DelegateResult<HttpResponseMessage> outcome)
    {
        HttpResponseMessage response = outcome.Result;
        if (response == null)
        {
            // The handled failure was an exception (e.g. HttpRequestException): there is no response,
            // hence no Retry-After header to honor.
            logger.Info($"HTTP client retrying after exception: {outcome.Exception?.Message}.");
            return TimeSpan.FromSeconds(Math.Pow(2, retryAttempt));
        }

        logger.Info($"HTTP client retrying: {response.RequestMessage}.");

        // Observe server Retry-After if applicable. The header may be delta-seconds or an HTTP-date; a date in the
        // past would give a negative delay, which the wait cannot accept, so clamp to zero.
        var retryAfter = response.Headers.RetryAfter;
        if (retryAfter?.Delta is TimeSpan delta)
        {
            return delta < TimeSpan.Zero ? TimeSpan.Zero : delta;
        }

        if (retryAfter?.Date is DateTimeOffset retryAt)
        {
            TimeSpan untilRetryAt = retryAt - DateTimeOffset.UtcNow;
            return untilRetryAt < TimeSpan.Zero ? TimeSpan.Zero : untilRetryAt;
        }

        return TimeSpan.FromSeconds(Math.Pow(2, retryAttempt));
    }
}
}