src/Microsoft.Extensions.Configuration.AzureAppConfiguration/AzureKeyVaultReference/AzureKeyVaultSecretProvider.cs (181 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
using Azure.Security.KeyVault.Secrets;
using Microsoft.Extensions.Configuration.AzureAppConfiguration.Extensions;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Extensions.Configuration.AzureAppConfiguration.AzureKeyVault
{
    /// <summary>
    /// Resolves the values of Key Vault secrets referenced by configuration settings, caching each
    /// value by the secret's source id and tracking which cached secret is due to be refreshed first.
    /// </summary>
    internal class AzureKeyVaultSecretProvider
    {
        private readonly AzureAppConfigurationKeyVaultOptions _keyVaultOptions;

        // Secret clients keyed by Key Vault host name, compared case-insensitively.
        private readonly IDictionary<string, SecretClient> _secretClients;

        // Cached secrets (successful or failed lookups) keyed by secret source id.
        private readonly Dictionary<Uri, CachedKeyVaultSecret> _cachedKeyVaultSecrets;

        // Source id and refresh time of the cached secret that is due to be refreshed first,
        // or null when no refresh is scheduled.
        private Uri _nextRefreshSourceId;
        private DateTimeOffset? _nextRefreshTime;

        /// <summary>
        /// Creates a provider for the given Key Vault options.
        /// </summary>
        /// <param name="keyVaultOptions">
        /// Options supplying the credential, pre-registered secret clients, secret resolver callback and
        /// refresh intervals. When null, a default <see cref="AzureAppConfigurationKeyVaultOptions"/> is used.
        /// </param>
        public AzureKeyVaultSecretProvider(AzureAppConfigurationKeyVaultOptions keyVaultOptions = null)
        {
            _keyVaultOptions = keyVaultOptions ?? new AzureAppConfigurationKeyVaultOptions();
            _cachedKeyVaultSecrets = new Dictionary<Uri, CachedKeyVaultSecret>();
            _secretClients = new Dictionary<string, SecretClient>(StringComparer.OrdinalIgnoreCase);

            if (_keyVaultOptions.SecretClients != null)
            {
                // Index registered clients by vault host; a later client for the same host replaces an earlier one.
                foreach (SecretClient client in _keyVaultOptions.SecretClients)
                {
                    string keyVaultId = client.VaultUri.Host;
                    _secretClients[keyVaultId] = client;
                }
            }
        }

        /// <summary>
        /// Returns the value of the referenced secret. The value is served from the cache while the cached
        /// entry has no refresh time, or while its refresh time has not yet passed.
        /// </summary>
        /// <param name="secretIdentifier">Identifies the secret. Its <c>SourceId</c> is the cache key and selects the vault.</param>
        /// <param name="key">Configuration key of the setting. Used to look up the refresh interval and for logging.</param>
        /// <param name="label">Label of the setting. Used for logging.</param>
        /// <param name="logger">Receives the secret-read and setting-updated messages when a secret client is used.</param>
        /// <param name="cancellationToken">Passed to <c>SecretClient.GetSecretAsync</c>. It is not passed to the secret resolver callback.</param>
        /// <returns>The secret value.</returns>
        /// <exception cref="UnauthorizedAccessException">
        /// No secret client is available for the vault and no secret resolver callback is configured.
        /// </exception>
        /// <remarks>
        /// Once a lookup is attempted, its outcome is always written back to the cache. On failure, the
        /// previous entry (or an empty placeholder) is re-cached with a backoff-based refresh time and the
        /// exception propagates.
        /// NOTE(review): when no positive refresh interval applies to <paramref name="key"/>, a placeholder
        /// cached after a failed first lookup gets no refresh time. It is then never treated as expired, so
        /// later calls appear to return its (presumably null) value from the cache instead of retrying.
        /// Confirm this is intended.
        /// </remarks>
        public async Task<string> GetSecretValue(KeyVaultSecretIdentifier secretIdentifier, string key, string label, Logger logger, CancellationToken cancellationToken)
        {
            string secretValue = null;

            if (_cachedKeyVaultSecrets.TryGetValue(secretIdentifier.SourceId, out CachedKeyVaultSecret cachedSecret) &&
                (!cachedSecret.RefreshAt.HasValue || DateTimeOffset.UtcNow < cachedSecret.RefreshAt.Value))
            {
                return cachedSecret.SecretValue;
            }

            SecretClient client = GetSecretClient(secretIdentifier.SourceId);

            if (client == null && _keyVaultOptions.SecretResolver == null)
            {
                throw new UnauthorizedAccessException("No key vault credential or secret resolver callback configured, and no matching secret client could be found.");
            }

            bool success = false;

            try
            {
                if (client != null)
                {
                    KeyVaultSecret secret = await client.GetSecretAsync(secretIdentifier.Name, secretIdentifier.Version, cancellationToken).ConfigureAwait(false);
                    logger.LogDebug(LogHelper.BuildKeyVaultSecretReadMessage(key, label));
                    logger.LogInformation(LogHelper.BuildKeyVaultSettingUpdatedMessage(key));
                    secretValue = secret.Value;
                }
                else if (_keyVaultOptions.SecretResolver != null)
                {
                    secretValue = await _keyVaultOptions.SecretResolver(secretIdentifier.SourceId).ConfigureAwait(false);
                }

                cachedSecret = new CachedKeyVaultSecret(secretValue, secretIdentifier.SourceId);
                success = true;
            }
            finally
            {
                // On failure, cachedSecret is still the previously cached entry (or null). It is re-cached
                // with a backoff-based refresh time, and the exception keeps propagating.
                SetSecretInCache(secretIdentifier.SourceId, key, cachedSecret, success);
            }

            return secretValue;
        }

        /// <summary>
        /// Returns true when the earliest scheduled secret refresh time has passed.
        /// </summary>
        public bool ShouldRefreshKeyVaultSecrets()
        {
            return _nextRefreshTime.HasValue && _nextRefreshTime.Value < DateTimeOffset.UtcNow;
        }

        /// <summary>
        /// Evicts cached secrets whose last refresh is older than
        /// <c>RefreshConstants.MinimumSecretRefreshInterval</c>, then recomputes which cached secret is
        /// due to be refreshed next.
        /// </summary>
        public void ClearCache()
        {
            DateTimeOffset utcNow = DateTimeOffset.UtcNow;

            // Collect the keys first rather than removing entries while enumerating the dictionary.
            List<Uri> sourceIdsToRemove = _cachedKeyVaultSecrets
                .Where(secret => secret.Value.LastRefreshTime + RefreshConstants.MinimumSecretRefreshInterval < utcNow)
                .Select(secret => secret.Key)
                .ToList();

            foreach (Uri sourceId in sourceIdsToRemove)
            {
                _cachedKeyVaultSecrets.Remove(sourceId);
            }

            // Recompute unconditionally, including when the cache is now empty. Otherwise the next-refresh
            // state could keep pointing at an evicted secret, and ShouldRefreshKeyVaultSecrets would keep
            // reporting a refresh that no cached secret needs.
            UpdateNextRefreshableSecretFromCache();
        }

        /// <summary>
        /// Removes a single secret from the cache. If it was the secret due to be refreshed next, the
        /// next-refresh state is recomputed.
        /// </summary>
        /// <param name="sourceId">Source id of the secret to evict.</param>
        public void RemoveSecretFromCache(Uri sourceId)
        {
            _cachedKeyVaultSecrets.Remove(sourceId);

            if (sourceId == _nextRefreshSourceId)
            {
                UpdateNextRefreshableSecretFromCache();
            }
        }

        /// <summary>
        /// Returns the secret client for the vault hosting <paramref name="secretUri"/>. Resolution order:
        /// <list type="bullet">
        /// <item>a registered client for that host;</item>
        /// <item>otherwise, a new client built from the configured credential and client options, which is
        /// remembered for later calls;</item>
        /// <item>otherwise <c>null</c>, when no credential is configured.</item>
        /// </list>
        /// </summary>
        private SecretClient GetSecretClient(Uri secretUri)
        {
            string keyVaultId = secretUri.Host;

            if (_secretClients.TryGetValue(keyVaultId, out SecretClient client))
            {
                return client;
            }

            if (_keyVaultOptions.Credential == null)
            {
                return null;
            }

            client = new SecretClient(
                new Uri(secretUri.GetLeftPart(UriPartial.Authority)),
                _keyVaultOptions.Credential,
                _keyVaultOptions.ClientOptions);

            _secretClients.Add(keyVaultId, client);

            return client;
        }

        /// <summary>
        /// Updates the refresh time of <paramref name="cachedSecret"/> and stores it under
        /// <paramref name="sourceId"/>. An empty placeholder is stored when <paramref name="cachedSecret"/> is null.
        /// The next-refresh state is then brought up to date.
        /// </summary>
        private void SetSecretInCache(Uri sourceId, string key, CachedKeyVaultSecret cachedSecret, bool success = true)
        {
            if (cachedSecret == null)
            {
                cachedSecret = new CachedKeyVaultSecret();
            }

            UpdateCacheExpirationTimeForSecret(key, cachedSecret, success);
            _cachedKeyVaultSecrets[sourceId] = cachedSecret;

            if (sourceId == _nextRefreshSourceId)
            {
                // The tracked secret's refresh time may have moved, so the earliest one must be recomputed.
                UpdateNextRefreshableSecretFromCache();
            }
            else if (cachedSecret.RefreshAt.HasValue &&
                (!_nextRefreshTime.HasValue || cachedSecret.RefreshAt.Value < _nextRefreshTime.Value))
            {
                _nextRefreshSourceId = sourceId;
                _nextRefreshTime = cachedSecret.RefreshAt.Value;
            }
        }

        /// <summary>
        /// Scans the cache for the secret with the earliest refresh time and records it as the next one to
        /// refresh. Both fields are set to null when no cached secret has a refresh time.
        /// </summary>
        private void UpdateNextRefreshableSecretFromCache()
        {
            Uri earliestSourceId = null;
            DateTimeOffset? earliestRefreshAt = null;

            foreach (KeyValuePair<Uri, CachedKeyVaultSecret> secret in _cachedKeyVaultSecrets)
            {
                DateTimeOffset? refreshAt = secret.Value.RefreshAt;

                if (refreshAt.HasValue && (!earliestRefreshAt.HasValue || refreshAt.Value < earliestRefreshAt.Value))
                {
                    earliestRefreshAt = refreshAt;
                    earliestSourceId = secret.Key;
                }
            }

            _nextRefreshSourceId = earliestSourceId;
            _nextRefreshTime = earliestRefreshAt;
        }

        /// <summary>
        /// Sets the refresh time of <paramref name="cachedSecret"/>.
        /// </summary>
        /// <remarks>
        /// The interval is the per-key entry in <c>SecretRefreshIntervals</c>, falling back to
        /// <c>DefaultSecretRefreshInterval</c>. When no positive interval applies, the secret is left unchanged.
        /// <list type="bullet">
        /// <item>On success: resets the attempt counter and schedules a refresh one interval from now.</item>
        /// <item>On failure: increments the attempt counter (saturating at <see cref="int.MaxValue"/>) and
        /// schedules a refresh using the backoff time.</item>
        /// </list>
        /// </remarks>
        private void UpdateCacheExpirationTimeForSecret(string key, CachedKeyVaultSecret cachedSecret, bool success)
        {
            if (!_keyVaultOptions.SecretRefreshIntervals.TryGetValue(key, out TimeSpan cacheExpirationTime) &&
                _keyVaultOptions.DefaultSecretRefreshInterval.HasValue)
            {
                cacheExpirationTime = _keyVaultOptions.DefaultSecretRefreshInterval.Value;
            }

            if (cacheExpirationTime <= TimeSpan.Zero)
            {
                return;
            }

            if (success)
            {
                cachedSecret.RefreshAttempts = 0;
                cachedSecret.RefreshAt = DateTimeOffset.UtcNow.Add(cacheExpirationTime);
                return;
            }

            // Saturate rather than overflow the attempt counter.
            if (cachedSecret.RefreshAttempts < int.MaxValue)
            {
                cachedSecret.RefreshAttempts++;
            }

            // NOTE(review): CalculateBackoffTime presumably derives a retry delay from the attempt count,
            // bounded by DefaultMinBackoff/DefaultMaxBackoff. Confirm against the extension's implementation.
            cachedSecret.RefreshAt = DateTimeOffset.UtcNow.Add(cacheExpirationTime.CalculateBackoffTime(RefreshConstants.DefaultMinBackoff, RefreshConstants.DefaultMaxBackoff, cachedSecret.RefreshAttempts));
        }
    }
}