src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureKeyVaultReference/AzureKeyVaultSecretProvider.cs (181 lines of code) (raw):

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
using Azure.Security.KeyVault.Secrets;
using Microsoft.Extensions.Configuration.AzureAppConfiguration.Extensions;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Extensions.Configuration.AzureAppConfiguration.AzureKeyVault
{
    /// <summary>
    /// Resolves Key Vault secret values and caches them by secret source id.
    /// It uses a <see cref="SecretClient"/> for the secret's vault when one is available.
    /// Otherwise it falls back to the configured secret resolver callback.
    /// It also tracks which cached secret is expected to need refreshing first.
    /// </summary>
    internal class AzureKeyVaultSecretProvider
    {
        private readonly AzureAppConfigurationKeyVaultOptions _keyVaultOptions;

        // Secret clients keyed by vault host name, compared case-insensitively.
        private readonly IDictionary<string, SecretClient> _secretClients;

        // Cached secrets keyed by the secret's source id.
        private readonly Dictionary<Uri, CachedKeyVaultSecret> _cachedKeyVaultSecrets;

        // The cached secret with the earliest RefreshAt, and that time; both are null when nothing is tracked.
        private Uri _nextRefreshSourceId;
        private DateTimeOffset? _nextRefreshTime;

        /// <summary>
        /// Creates a provider and registers any preconfigured secret clients, keyed by their vault host.
        /// </summary>
        /// <param name="keyVaultOptions">Key Vault options; a default instance is used when null.</param>
        public AzureKeyVaultSecretProvider(AzureAppConfigurationKeyVaultOptions keyVaultOptions = null)
        {
            _keyVaultOptions = keyVaultOptions ?? new AzureAppConfigurationKeyVaultOptions();
            _cachedKeyVaultSecrets = new Dictionary<Uri, CachedKeyVaultSecret>();
            _secretClients = new Dictionary<string, SecretClient>(StringComparer.OrdinalIgnoreCase);

            if (_keyVaultOptions.SecretClients != null)
            {
                foreach (SecretClient client in _keyVaultOptions.SecretClients)
                {
                    // Indexer assignment: a later client for the same host replaces an earlier one.
                    string keyVaultId = client.VaultUri.Host;
                    _secretClients[keyVaultId] = client;
                }
            }
        }

        /// <summary>
        /// Returns the value of the referenced secret.
        /// The cached value is used when it exists and is not yet due for refresh.
        /// Otherwise the secret is fetched through a <see cref="SecretClient"/> or the secret resolver callback.
        /// </summary>
        /// <param name="secretIdentifier">
        /// Identifies the secret.
        /// Its <c>SourceId</c> is the cache key and selects the vault client.
        /// Its <c>Name</c> and <c>Version</c> are passed to the client.
        /// </param>
        /// <param name="key">Configuration key; selects the per-key refresh interval and appears in log messages.</param>
        /// <param name="label">Configuration label, used in the debug log message.</param>
        /// <param name="logger">Logger for secret read/update messages.</param>
        /// <param name="cancellationToken">
        /// Passed to <c>SecretClient.GetSecretAsync</c>; the resolver callback is invoked without it.
        /// </param>
        /// <returns>The secret value.</returns>
        /// <exception cref="UnauthorizedAccessException">
        /// Thrown when no secret client can be found or created for the vault and no secret resolver is configured.
        /// </exception>
        /// <remarks>
        /// If retrieval throws, the exception propagates, but the cache is still updated in the <c>finally</c> block.
        /// A previously cached entry is kept.
        /// When a refresh interval applies, that entry's next refresh is pushed out using a backoff.
        /// NOTE(review): when no entry was cached before, a placeholder built by the parameterless
        /// <c>CachedKeyVaultSecret</c> constructor is cached instead. Whether later calls return that placeholder's
        /// value depends on the constructor's default <c>RefreshAt</c>. Confirm.
        /// </remarks>
        public async Task<string> GetSecretValue(KeyVaultSecretIdentifier secretIdentifier, string key, string label, Logger logger, CancellationToken cancellationToken)
        {
            string secretValue = null;

            if (_cachedKeyVaultSecrets.TryGetValue(secretIdentifier.SourceId, out CachedKeyVaultSecret cachedSecret) &&
                (!cachedSecret.RefreshAt.HasValue || DateTimeOffset.UtcNow < cachedSecret.RefreshAt.Value))
            {
                return cachedSecret.SecretValue;
            }

            SecretClient client = GetSecretClient(secretIdentifier.SourceId);

            if (client == null && _keyVaultOptions.SecretResolver == null)
            {
                throw new UnauthorizedAccessException("No key vault credential or secret resolver callback configured, and no matching secret client could be found.");
            }

            bool success = false;

            try
            {
                if (client != null)
                {
                    KeyVaultSecret secret = await client.GetSecretAsync(secretIdentifier.Name, secretIdentifier.Version, cancellationToken).ConfigureAwait(false);
                    logger.LogDebug(LogHelper.BuildKeyVaultSecretReadMessage(key, label));
                    logger.LogInformation(LogHelper.BuildKeyVaultSettingUpdatedMessage(key));
                    secretValue = secret.Value;
                }
                else if (_keyVaultOptions.SecretResolver != null)
                {
                    secretValue = await _keyVaultOptions.SecretResolver(secretIdentifier.SourceId).ConfigureAwait(false);
                }

                cachedSecret = new CachedKeyVaultSecret(secretValue, secretIdentifier.SourceId);
                success = true;
            }
            finally
            {
                // On failure, cachedSecret is still the previously cached entry (or null).
                // SetSecretInCache then records the failure against it.
                SetSecretInCache(secretIdentifier.SourceId, key, cachedSecret, success);
            }

            return secretValue;
        }

        /// <summary>
        /// Returns true when the earliest tracked secret refresh time has already passed.
        /// </summary>
        public bool ShouldRefreshKeyVaultSecrets()
        {
            return _nextRefreshTime.HasValue && _nextRefreshTime.Value < DateTimeOffset.UtcNow;
        }

        /// <summary>
        /// Evicts cached secrets that were last refreshed more than
        /// <c>RefreshConstants.MinimumSecretRefreshInterval</c> ago.
        /// It then recomputes which remaining secret is due for refresh first.
        /// </summary>
        public void ClearCache()
        {
            DateTimeOffset utcNow = DateTimeOffset.UtcNow;

            // Materialize the ids first: the dictionary cannot be modified while it is being enumerated.
            List<Uri> sourceIdsToRemove = _cachedKeyVaultSecrets
                .Where(secret => secret.Value.LastRefreshTime + RefreshConstants.MinimumSecretRefreshInterval < utcNow)
                .Select(secret => secret.Key)
                .ToList();

            foreach (Uri sourceId in sourceIdsToRemove)
            {
                _cachedKeyVaultSecrets.Remove(sourceId);
            }

            // Recompute unconditionally. If the cache is now empty, this resets the "next refresh" tracking to null.
            // Otherwise that tracking would keep pointing at an evicted entry,
            // and a stale past time could keep ShouldRefreshKeyVaultSecrets() returning true.
            UpdateNextRefreshableSecretFromCache();
        }

        /// <summary>
        /// Removes a single secret from the cache.
        /// If it was the next secret due for refresh, the next refresh is recomputed.
        /// </summary>
        /// <param name="sourceId">Source id of the secret to remove.</param>
        public void RemoveSecretFromCache(Uri sourceId)
        {
            _cachedKeyVaultSecrets.Remove(sourceId);

            if (sourceId == _nextRefreshSourceId)
            {
                UpdateNextRefreshableSecretFromCache();
            }
        }

        /// <summary>
        /// Returns the secret client registered for the secret's vault host.
        /// If none is registered and a credential is configured, it creates and registers one.
        /// </summary>
        /// <param name="secretUri">The secret's URI; only its host and authority are used.</param>
        /// <returns>A secret client, or null when none is registered and no credential is configured.</returns>
        private SecretClient GetSecretClient(Uri secretUri)
        {
            string keyVaultId = secretUri.Host;

            if (_secretClients.TryGetValue(keyVaultId, out SecretClient client))
            {
                return client;
            }

            if (_keyVaultOptions.Credential == null)
            {
                return null;
            }

            client = new SecretClient(
                new Uri(secretUri.GetLeftPart(UriPartial.Authority)),
                _keyVaultOptions.Credential,
                _keyVaultOptions.ClientOptions);

            _secretClients.Add(keyVaultId, client);

            return client;
        }

        /// <summary>
        /// Stores a secret in the cache after updating its refresh schedule.
        /// It also keeps the "next secret due for refresh" tracking up to date.
        /// </summary>
        /// <param name="sourceId">Cache key of the secret.</param>
        /// <param name="key">Configuration key, used to look up the refresh interval.</param>
        /// <param name="cachedSecret">Entry to store; a new empty entry is created when null.</param>
        /// <param name="success">Whether the most recent retrieval succeeded; controls reset vs. backoff.</param>
        private void SetSecretInCache(Uri sourceId, string key, CachedKeyVaultSecret cachedSecret, bool success = true)
        {
            cachedSecret = cachedSecret ?? new CachedKeyVaultSecret();

            UpdateCacheExpirationTimeForSecret(key, cachedSecret, success);
            _cachedKeyVaultSecrets[sourceId] = cachedSecret;

            if (sourceId == _nextRefreshSourceId)
            {
                // The tracked entry's time may have moved later, so another entry may now be earliest.
                UpdateNextRefreshableSecretFromCache();
            }
            else if (cachedSecret.RefreshAt.HasValue &&
                     (!_nextRefreshTime.HasValue || cachedSecret.RefreshAt.Value < _nextRefreshTime.Value))
            {
                _nextRefreshSourceId = sourceId;
                _nextRefreshTime = cachedSecret.RefreshAt.Value;
            }
        }

        /// <summary>
        /// Scans the cache for the entry with the earliest <c>RefreshAt</c> and records it as the next refresh.
        /// Both tracking fields are set to null when no cached entry has a <c>RefreshAt</c>.
        /// </summary>
        private void UpdateNextRefreshableSecretFromCache()
        {
            _nextRefreshSourceId = null;
            _nextRefreshTime = DateTimeOffset.MaxValue;

            foreach (KeyValuePair<Uri, CachedKeyVaultSecret> secret in _cachedKeyVaultSecrets)
            {
                if (secret.Value.RefreshAt.HasValue && secret.Value.RefreshAt.Value < _nextRefreshTime)
                {
                    _nextRefreshTime = secret.Value.RefreshAt;
                    _nextRefreshSourceId = secret.Key;
                }
            }

            // MaxValue is only a scan sentinel; leaving it in place would mean "nothing found".
            if (_nextRefreshTime == DateTimeOffset.MaxValue)
            {
                _nextRefreshTime = null;
            }
        }

        /// <summary>
        /// Updates a cached secret's next refresh time using the per-key refresh interval.
        /// The default refresh interval is used when the key has no entry.
        /// When the resulting interval is not positive, the entry is left unchanged.
        /// </summary>
        /// <param name="key">Configuration key used to look up the refresh interval.</param>
        /// <param name="cachedSecret">The entry to update.</param>
        /// <param name="success">
        /// True: reset the attempt count and schedule at now + interval.
        /// False: increment the attempt count (saturating at int.MaxValue) and schedule at now + backoff.
        /// </param>
        private void UpdateCacheExpirationTimeForSecret(string key, CachedKeyVaultSecret cachedSecret, bool success)
        {
            if (!_keyVaultOptions.SecretRefreshIntervals.TryGetValue(key, out TimeSpan cacheExpirationTime) &&
                _keyVaultOptions.DefaultSecretRefreshInterval.HasValue)
            {
                cacheExpirationTime = _keyVaultOptions.DefaultSecretRefreshInterval.Value;
            }

            if (cacheExpirationTime <= TimeSpan.Zero)
            {
                return;
            }

            if (success)
            {
                cachedSecret.RefreshAttempts = 0;
                cachedSecret.RefreshAt = DateTimeOffset.UtcNow.Add(cacheExpirationTime);
                return;
            }

            // Saturate rather than overflow the attempt counter.
            if (cachedSecret.RefreshAttempts < int.MaxValue)
            {
                cachedSecret.RefreshAttempts++;
            }

            cachedSecret.RefreshAt = DateTimeOffset.UtcNow.Add(cacheExpirationTime.CalculateBackoffTime(RefreshConstants.DefaultMinBackoff, RefreshConstants.DefaultMaxBackoff, cachedSecret.RefreshAttempts));
        }
    }
}