DeviceBridge/Services/EncryptionService.cs (96 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
using System;
using System.Collections.Generic;
using System.IO;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using DeviceBridge.Common.Exceptions;
using DeviceBridge.Providers;
using Microsoft.Azure.KeyVault.Models;
using NLog;
namespace DeviceBridge.Services
{
public class EncryptionService : IEncryptionService
{
private readonly ISecretsProvider _secretsProvider;
private IDictionary<string, SecretBundle> _encryptionKeys;
private string _latestKnownEncryptionKeyVersionId = null;
public EncryptionService(Logger logger, ISecretsProvider secretsProvider)
{
_secretsProvider = secretsProvider;
// Initialize secret cache
_encryptionKeys = _secretsProvider.GetEncryptionKeyVersions(logger).Result;
}
public async Task<string> Encrypt(Logger logger, string unencryptedString)
{
var keySecret = await GetEncryptionKey(logger);
var encryptionKey = keySecret.Value;
return $"{keySecret.SecretIdentifier.Version}-{EncryptString(unencryptedString, encryptionKey)}";
}
public async Task<string> Decrypt(Logger logger, string encryptedStringWithVersion)
{
var encryptedStringParts = encryptedStringWithVersion.Split('-');
if (encryptedStringParts.Length < 2)
{
throw new EncryptionException();
}
var keyVersion = encryptedStringParts[0];
var encryptedString = encryptedStringParts[1];
var keySecret = await GetEncryptionKey(logger, keyVersion);
var encryptionKey = keySecret.Value;
return DecryptString(encryptedString, encryptionKey);
}
private static string EncryptString(string plainText, string stringKey)
{
var key = Encoding.ASCII.GetBytes(stringKey);
using var aes = Aes.Create();
aes.Key = key;
aes.GenerateIV();
var encryptor = aes.CreateEncryptor(aes.Key, aes.IV);
using var memoryStream = new MemoryStream();
using var cryptoStream = new CryptoStream(memoryStream, encryptor, CryptoStreamMode.Write);
using var streamWriter = new StreamWriter(cryptoStream);
streamWriter.Write(plainText);
streamWriter.Dispose();
return $"{Convert.ToBase64String(aes.IV)}:{Convert.ToBase64String(memoryStream.ToArray())}";
}
private static string DecryptString(string encryptedStringWithIv, string stringKey)
{
var key = Encoding.ASCII.GetBytes(stringKey);
using var aes = Aes.Create();
aes.Key = key;
var encryptedStringWithIvParts = encryptedStringWithIv.Split(':');
var iv = System.Convert.FromBase64String(encryptedStringWithIvParts[0]);
aes.IV = iv;
ICryptoTransform decryptor = aes.CreateDecryptor(aes.Key, aes.IV);
using var memoryStream = new MemoryStream(Convert.FromBase64String(encryptedStringWithIvParts[1]));
using var cryptoStream = new CryptoStream(memoryStream, decryptor, CryptoStreamMode.Read);
using var streamReader = new StreamReader(cryptoStream);
return streamReader.ReadToEnd();
}
private async Task<SecretBundle> GetEncryptionKey(Logger logger, string version = null)
{
if (version == null && _latestKnownEncryptionKeyVersionId != null && _encryptionKeys.ContainsKey(_latestKnownEncryptionKeyVersionId))
{
// Use latest cached version
SecretBundle cachedValue;
_encryptionKeys.TryGetValue(_latestKnownEncryptionKeyVersionId, out cachedValue);
return cachedValue;
}
if (version != null && _encryptionKeys.ContainsKey(version))
{
// Used cached key
SecretBundle cachedValue;
_encryptionKeys.TryGetValue(version, out cachedValue);
return cachedValue;
}
// Get latest version from KV and cache
var foundKey = await _secretsProvider.GetEncryptionKey(logger, version);
if (!_encryptionKeys.ContainsKey(foundKey.SecretIdentifier.Version))
{
_encryptionKeys.Add(foundKey.SecretIdentifier.Version, foundKey);
}
if (version == null)
{
_latestKnownEncryptionKeyVersionId = foundKey.SecretIdentifier.Version;
}
return foundKey;
}
}
}