Notation.Plugin.AzureKeyVault/KeyVault/KeyVaultClient.cs (153 lines of code) (raw):
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Security.Cryptography.X509Certificates;
using Azure.Core;
using Azure.Security.KeyVault.Certificates;
using Azure.Security.KeyVault.Keys.Cryptography;
using Azure.Security.KeyVault.Secrets;
using Notation.Plugin.AzureKeyVault.Certificate;
using Notation.Plugin.Protocol;
[assembly: InternalsVisibleTo("Notation.Plugin.AzureKeyVault.Tests")]
namespace Notation.Plugin.AzureKeyVault.Client
{
public interface IKeyVaultClient
{
/// <summary>
/// Sign the payload with the specified algorithm.
/// </summary>
public Task<byte[]> SignAsync(SignatureAlgorithm algorithm, byte[] payload);
/// <summary>
/// Get the certificate from KeyVault.
/// </summary>
public Task<X509Certificate2> GetCertificateAsync();
/// <summary>
/// Get the certificate chain from KeyVault.
/// </summary>
public Task<X509Certificate2Collection> GetCertificateChainAsync();
}
public class KeyVaultClient : IKeyVaultClient
{
/// <summary>
/// A helper record to store KeyVault metadata.
/// </summary>
private record KeyVaultMetadata(string KeyVaultUrl, string Name, string? Version);
// Certificate client (lazy initialization)
// Protected for unit test
protected Lazy<CertificateClient> _certificateClient;
// Cryptography client (lazy initialization)
protected Lazy<CryptographyClient> _cryptoClient;
// Secret client (lazy initialization)
protected Lazy<SecretClient> _secretClient;
// Key name or certificate name
private string _name;
// Key version or certificate version
private string? _version;
// Key identifier (e.g. https://<vaultname>.vault.azure.net/keys/<name>/<version>)
private string _keyId;
// Internal getters for unit test
internal string Name => _name;
internal string? Version => _version;
internal string KeyId => _keyId;
/// <summary>
/// Constructor to create AzureKeyVault object from keyVaultUrl, name
/// and version.
/// </summary>
public KeyVaultClient(string keyVaultUrl, string name, string? version, TokenCredential credential)
{
if (string.IsNullOrEmpty(keyVaultUrl))
{
throw new ValidationException("Key vault URL must not be null or empty");
}
if (string.IsNullOrEmpty(name))
{
throw new ValidationException("Key name must not be null or empty");
}
if (version != null && version == string.Empty)
{
throw new ValidationException("Key version must not be empty");
}
_name = name;
_version = version;
_keyId = $"{keyVaultUrl}/keys/{name}";
if (version != null)
{
_keyId = $"{_keyId}/{version}";
}
// initialize credential and lazy clients
_certificateClient = new Lazy<CertificateClient>(() => new CertificateClient(new Uri(keyVaultUrl), credential));
_cryptoClient = new Lazy<CryptographyClient>(() => new CryptographyClient(new Uri(_keyId), credential));
_secretClient = new Lazy<SecretClient>(() => new SecretClient(new Uri(keyVaultUrl), credential));
}
/// <summary>
/// Constructor to create AzureKeyVault object from key identifier or
/// certificate identifier.
/// </summary>
/// <param name="id">
/// Key identifier or certificate identifier. (e.g. https://<vaultname>.vault.azure.net/keys/<name>/<version>)
/// </param>
/// <param name="credential">
/// TokenCredential object to authenticate with Azure Key Vault.
/// </param>
public KeyVaultClient(string id, TokenCredential credential) : this(ParseId(id), credential) { }
/// <summary>
/// A helper constructor to create KeyVaultClient from KeyVaultMetadata.
/// </summary>
private KeyVaultClient(KeyVaultMetadata metadata, TokenCredential credential)
: this(metadata.KeyVaultUrl, metadata.Name, metadata.Version, credential) { }
/// <summary>
/// A helper function to parse key identifier or certificate identifier
/// and return KeyVaultMetadata.
/// </summary>
private static KeyVaultMetadata ParseId(string id)
{
if (string.IsNullOrEmpty(id))
{
throw new ValidationException("Input passed to \"--id\" must not be empty");
}
var uri = new Uri(id.TrimEnd('/'));
// Validate uri
if (uri.Segments.Length < 3 || uri.Segments.Length > 4)
{
throw new ValidationException("Invalid input passed to \"--id\". Please follow this format to input the ID \"https://<vault-name>.vault.azure.net/certificates/<certificate-name>/[certificate-version]\"");
}
var type = uri.Segments[1].TrimEnd('/');
if (type != "keys" && type != "certificates")
{
throw new ValidationException($"Unsupported key vualt object type {type}.");
}
if (uri.Scheme != "https")
{
throw new ValidationException($"Unsupported scheme {uri.Scheme}. The scheme must be https.");
}
string? version = null;
if (uri.Segments.Length == 4)
{
version = uri.Segments[3].TrimEnd('/');
}
return new KeyVaultMetadata(
KeyVaultUrl: $"{uri.Scheme}://{uri.Host}",
Name: uri.Segments[2].TrimEnd('/'),
Version: version
);
}
/// <summary>
/// Sign the payload and return the signature.
/// </summary>
public async Task<byte[]> SignAsync(SignatureAlgorithm algorithm, byte[] payload)
{
var signResult = await _cryptoClient.Value.SignDataAsync(algorithm, payload);
if (!string.IsNullOrEmpty(_version) && signResult.KeyId != _keyId)
{
throw new PluginException($"Invalid key identifier. User required {_keyId} does not match {signResult.KeyId} in response. Please ensure the provided key identifier is correct.");
}
if (signResult.Algorithm != algorithm)
{
throw new PluginException($"Invalid signature algorithm. The user provides {algorithm} but the response contains {signResult.Algorithm} as the algorithm");
}
return signResult.Signature;
}
/// <summary>
/// Get the certificate from the key vault.
/// </summary>
public async Task<X509Certificate2> GetCertificateAsync()
{
KeyVaultCertificate cert;
if (string.IsNullOrEmpty(_version))
{
// If the version is not specified, get the latest version
cert = (await _certificateClient.Value.GetCertificateAsync(_name)).Value;
}
else
{
cert = (await _certificateClient.Value.GetCertificateVersionAsync(_name, _version)).Value;
// If the version is invalid, the cert will be fallback to
// the latest. So if the version is not the same as the
// requested version, it means the version is invalid.
if (cert.Properties.Version != _version)
{
throw new PluginException($"The version specified in the request is {_version} but the version retrieved from Azure Key Vault is {cert.Properties.Version}. Please ensure the version is correct.");
}
}
return new X509Certificate2(cert.Cer);
}
/// <summary>
/// Get the certificate chain from the key vault with GetSecret permission.
/// </summary>
public async Task<X509Certificate2Collection> GetCertificateChainAsync()
{
var secret = await _secretClient.Value.GetSecretAsync(_name, _version);
var chain = new X509Certificate2Collection();
var contentType = secret.Value.Properties.ContentType;
var secretValue = secret.Value.Value;
switch (contentType)
{
case "application/x-pkcs12":
// If the secret is a PKCS12 file, decode the base64 encoding
// Import will reverse the order of the certificates
// in the chain
if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
// macOS doesn't support non-encrypted MAC
// https://github.com/dotnet/runtime/issues/23635
chain.Import(
rawData: Pkcs12.ReEncode(Convert.FromBase64String(secretValue)),
password: null,
keyStorageFlags: X509KeyStorageFlags.DefaultKeySet);
}
else
{
chain.Import(
rawData: Convert.FromBase64String(secretValue),
password: null,
keyStorageFlags: X509KeyStorageFlags.EphemeralKeySet);
}
break;
case "application/x-pem-file":
// If the secret is a PEM file, parse the PEM content directly
chain.ImportFromPem(secretValue.ToCharArray());
break;
default:
throw new ValidationException($"Unsupported secret content type: {contentType}");
}
return chain;
}
}
}