tools/test-proxy/Azure.Sdk.Tools.TestProxy/Vendored/LightweightPkcs8Decoder.cs (263 lines of code) (raw):

using System; using System.Diagnostics.CodeAnalysis; using System.IO; using System.Linq; using System.Security.Cryptography; using System.Text; namespace Azure.Sdk.Tools.TestProxy.Vendored { /// This code was ripped directly from https://github.com/Azure/azure-sdk-for-net/blob/873d4dc419512f42b9c70d104bdcc1983badfd1b/sdk/core/Azure.Core/src/Shared/LightweightPkcs8Decoder.cs /// <summary> /// This is a very targeted PKCS#8 decoder for use when reading a PKCS# encoded RSA private key from an /// DER encoded ASN.1 blob. In an ideal world, we would be able to call AsymmetricAlgorithm.ImportPkcs8PrivateKey /// off an RSA object to import the private key from a byte array, which we got from the PEM file. There /// are a few issues with this however: /// /// 1. ImportPkcs8PrivateKey does not exist in the Desktop .NET Framework as of today. /// 2. ImportPkcs8PrivateKey was added to .NET Core in 3.0, and we'd love to be able to support this /// on older versions of .NET Core. /// /// This code is able to decode RSA keys (without any attributes) from well formed PKCS#8 blobs. /// </summary> [ExcludeFromCodeCoverage] internal static partial class LightweightPkcs8Decoder { private static readonly byte[] s_derIntegerZero = { 0x02, 0x01, 0x00 }; private static readonly byte[] s_rsaAlgorithmId = { 0x30, 0x0D, 0x06, 0x09, 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01, 0x05, 0x00, }; internal static byte[] ReadBitString(byte[] data, ref int offset) { // Adapted from https://github.com/dotnet/runtime/blob/be74b4bd/src/libraries/System.Formats.Asn1/src/System/Formats/Asn1/AsnDecoder.BitString.cs#L156 if (data[offset++] != 0x03) { throw new InvalidDataException("Invalid PKCS#8 Data"); } int length = ReadLength(data, ref offset); if (length == 0) { throw new InvalidDataException("Invalid PKCS#8 Data"); } int unusedBitCount = data[offset++]; if (unusedBitCount > 7) { throw new InvalidDataException("Invalid PKCS#8 Data"); } Span<byte> span = data.AsSpan(offset, length - 1); // Build a mask for the bits that are used so the normalized value can be computed // // If 3 bits are "unused" then build a mask for them to check for 0. // -1 << 3 => 0b1111_1111 << 3 => 0b1111_1000 int mask = -1 << unusedBitCount; byte lastByte = span[span.Length - 1]; byte maskedByte = (byte)(lastByte & mask); byte[] ret = new byte[span.Length]; Buffer.BlockCopy(data, offset, ret, 0, span.Length); ret[span.Length - 1] = maskedByte; offset += span.Length; return ret; } internal static string ReadObjectIdentifier(byte[] data, ref int offset) { // Adapted from https://github.com/dotnet/runtime/blob/be74b4bd/src/libraries/System.Formats.Asn1/src/System/Formats/Asn1/AsnDecoder.Oid.cs#L175 if (data[offset++] != 0x06) { throw new InvalidDataException("Invalid PKCS#8 Data"); } int length = ReadLength(data, ref offset); StringBuilder ret = new StringBuilder(); for (int i = offset; i < offset + length; i++) { byte val = data[i]; if (i == offset) { byte first; if (val < 40) { first = 0; } else if (val < 80) { first = 1; val -= 40; } else { throw new InvalidDataException("Unsupported PKCS#8 Data"); } ret.Append(first).Append('.').Append(val); } else { if (val < 128) { ret.Append('.').Append(val); } else { ret.Append('.'); if (val == 0x80) { throw new InvalidDataException("Invalid PKCS#8 Data"); } // See how long the segment is. int end = -1; int idx; for (idx = i; idx < offset + length; idx++) { if ((data[idx] & 0x80) == 0) { end = idx; break; } } if (end < 0) { throw new InvalidDataException("Invalid PKCS#8 Data"); } // 4 or fewer bytes fits into a signed integer. int max = end + 1; if (max <= i + 4) { // cspell:ignore accum int accum = 0; for (idx = i; idx < max; idx++) { val = data[idx]; accum <<= 7; accum |= (byte)(val & 0x7f); } ret.Append(accum); i = end; } else { throw new InvalidDataException("Unsupported PKCS#8 Data"); } } } } offset += length; return ret.ToString(); } internal static byte[] ReadOctetString(byte[] data, ref int offset) { if (data[offset++] != 0x04) { throw new InvalidDataException("Invalid PKCS#8 Data"); } int length = ReadLength(data, ref offset); byte[] ret = new byte[length]; Buffer.BlockCopy(data, offset, ret, 0, length); offset += length; return ret; } private static int ReadLength(byte[] data, ref int offset) { byte lengthOrLengthLength = data[offset++]; if (lengthOrLengthLength < 0x80) { return lengthOrLengthLength; } int lengthLength = lengthOrLengthLength & 0x7F; int length = 0; for (int i = 0; i < lengthLength; i++) { length <<= 8; length |= data[offset++]; if (length > ushort.MaxValue) { throw new InvalidDataException("Invalid PKCS#8 Data"); } } return length; } private static byte[] ReadUnsignedInteger(byte[] data, ref int offset, int targetSize = 0) { if (data[offset++] != 0x02) { throw new InvalidDataException("Invalid PKCS#8 Data"); } int length = ReadLength(data, ref offset); // Encoding rules say 0 is encoded as the one byte value 0x00. // Since we expect unsigned, throw if the high bit is set. if (length < 1 || data[offset] >= 0x80) { throw new InvalidDataException("Invalid PKCS#8 Data"); } byte[] ret; if (length == 1) { ret = new byte[length]; ret[0] = data[offset++]; return ret; } if (data[offset] == 0) { offset++; length--; } if (targetSize != 0) { if (length > targetSize) { throw new InvalidDataException("Invalid PKCS#8 Data"); } ret = new byte[targetSize]; } else { ret = new byte[length]; } Buffer.BlockCopy(data, offset, ret, ret.Length - length, length); offset += length; return ret; } private static int ReadPayloadTagLength(byte[] data, ref int offset, byte tagValue) { if (data[offset++] != tagValue) { throw new InvalidDataException("Invalid PKCS#8 Data"); } return ReadLength(data, ref offset); } private static void ConsumeFullPayloadTag(byte[] data, ref int offset, byte tagValue) { if (data[offset++] != tagValue) { throw new InvalidDataException("Invalid PKCS#8 Data"); } int length = ReadLength(data, ref offset); if (data.Length - offset != length) { throw new InvalidDataException("Invalid PKCS#8 Data"); } } private static void ConsumeMatch(byte[] data, ref int offset, byte[] toMatch) { if (data.Length - offset > toMatch.Length) { if (data.Skip(offset).Take(toMatch.Length).SequenceEqual(toMatch)) { offset += toMatch.Length; return; } } throw new InvalidDataException("Invalid PKCS#8 Data"); } public static RSA DecodeRSAPkcs8(byte[] pkcs8Bytes) { int offset = 0; // PrivateKeyInfo SEQUENCE ConsumeFullPayloadTag(pkcs8Bytes, ref offset, 0x30); // PKCS#8 PrivateKeyInfo.version == 0 ConsumeMatch(pkcs8Bytes, ref offset, s_derIntegerZero); // rsaEncryption AlgorithmIdentifier value ConsumeMatch(pkcs8Bytes, ref offset, s_rsaAlgorithmId); // PrivateKeyInfo.privateKey OCTET STRING ConsumeFullPayloadTag(pkcs8Bytes, ref offset, 0x04); // RSAPrivateKey SEQUENCE ConsumeFullPayloadTag(pkcs8Bytes, ref offset, 0x30); // RSAPrivateKey.version == 0 ConsumeMatch(pkcs8Bytes, ref offset, s_derIntegerZero); RSAParameters rsaParameters = new RSAParameters(); rsaParameters.Modulus = ReadUnsignedInteger(pkcs8Bytes, ref offset); rsaParameters.Exponent = ReadUnsignedInteger(pkcs8Bytes, ref offset); rsaParameters.D = ReadUnsignedInteger(pkcs8Bytes, ref offset, rsaParameters.Modulus.Length); int halfModulus = (rsaParameters.Modulus.Length + 1) / 2; rsaParameters.P = ReadUnsignedInteger(pkcs8Bytes, ref offset, halfModulus); rsaParameters.Q = ReadUnsignedInteger(pkcs8Bytes, ref offset, halfModulus); rsaParameters.DP = ReadUnsignedInteger(pkcs8Bytes, ref offset, halfModulus); rsaParameters.DQ = ReadUnsignedInteger(pkcs8Bytes, ref offset, halfModulus); rsaParameters.InverseQ = ReadUnsignedInteger(pkcs8Bytes, ref offset, halfModulus); if (offset != pkcs8Bytes.Length) { throw new InvalidDataException("Invalid PKCS#8 Data"); } RSA rsa = RSA.Create(); rsa.ImportParameters(rsaParameters); return rsa; } public static string DecodePrivateKeyOid(byte[] pkcs8Bytes) { int offset = 0; // PrivateKeyInfo SEQUENCE ConsumeFullPayloadTag(pkcs8Bytes, ref offset, 0x30); // PKCS#8 PrivateKeyInfo.version == 0 ConsumeMatch(pkcs8Bytes, ref offset, s_derIntegerZero); // PKCS#8 PrivateKeyInfo.sequence ReadPayloadTagLength(pkcs8Bytes, ref offset, 0x30); // Return the AlgorithmIdentifier value return ReadObjectIdentifier(pkcs8Bytes, ref offset); } } }