src/Microsoft.Azure.SignalR.Protocols/MessagePackUtils.cs (214 lines of code) (raw):
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Security.Claims;
using MessagePack;
using Microsoft.Extensions.Primitives;
namespace Microsoft.Azure.SignalR.Protocol;
internal static class MessagePackUtils
{
// Shared empty payload map; ReadPayloads returns this instance when the wire map has no entries.
// NOTE(review): this is a regular mutable Dictionary exposed through IDictionary, so a caller that
// adds to it would change what later ReadPayloads calls return — confirm callers treat it as read-only.
internal static readonly IDictionary<string, ReadOnlyMemory<byte>> EmptyReadOnlyMemoryDictionary = new Dictionary<string, ReadOnlyMemory<byte>>();
// Shared empty header map with case-insensitive (ordinal) keys; ReadHeaders returns this instance
// when the wire map has no entries.
// NOTE(review): also a mutable Dictionary shared across calls — confirm callers never mutate it.
internal static readonly IDictionary<string, StringValues> EmptyStringValuesDictionaryIgnoreCase = new Dictionary<string, StringValues>(StringComparer.OrdinalIgnoreCase);
// Protocol version number. Not read anywhere in this class; presumably consumed by the
// protocol implementation elsewhere — verify against callers.
internal static readonly int ProtocolVersion = 1;
/// <summary>
/// Reads a MessagePack map of claim type/value string pairs and materializes it as an array of
/// <see cref="Claim"/>.
/// </summary>
/// <param name="reader">Reader positioned at the map header of the claims section.</param>
/// <returns>One <see cref="Claim"/> per map entry, or an empty array when the map has no entries.</returns>
/// <exception cref="InvalidDataException">
/// The map header or a claim string cannot be read, or a claim type or value is null.
/// </exception>
internal static Claim[] ReadClaims(ref MessagePackReader reader)
{
    var claimCount = ReadMapLength(ref reader, "claims");
    if (claimCount <= 0)
    {
        return [];
    }

    var claims = new Claim[claimCount];
    for (var i = 0; i < claims.Length; i++)
    {
        // The format-string overload of ReadString only builds the field name when the read fails.
        // A null type or value would make the Claim constructor throw ArgumentNullException, which
        // hides the offending field; report it as malformed data, like ReadStringNotNull does.
        var type = ReadString(ref reader, "claims[{0}].Type", i)
            ?? throw new InvalidDataException($"Reading 'claims[{i}].Type' as Not-Null String failed.");
        var value = ReadString(ref reader, "claims[{0}].Value", i)
            ?? throw new InvalidDataException($"Reading 'claims[{i}].Value' as Not-Null String failed.");
        claims[i] = new Claim(type, value);
    }

    return claims;
}
/// <summary>
/// Reads a MessagePack map whose keys are non-null strings and whose values are binary blobs.
/// </summary>
/// <param name="reader">Reader positioned at the map header of the payloads section.</param>
/// <returns>
/// A dictionary created with an ordinal case-insensitive comparer holding one entry per map item,
/// or the shared <see cref="EmptyReadOnlyMemoryDictionary"/> when the map has no entries.
/// </returns>
/// <exception cref="InvalidDataException">The map header, a key, or a value cannot be read, or a key is null.</exception>
internal static IDictionary<string, ReadOnlyMemory<byte>> ReadPayloads(ref MessagePackReader reader)
{
    var entryCount = ReadMapLength(ref reader, "payloads");
    if (entryCount <= 0)
    {
        return EmptyReadOnlyMemoryDictionary;
    }

    var result = new ArrayDictionary<string, ReadOnlyMemory<byte>>((int)entryCount, StringComparer.OrdinalIgnoreCase);
    for (var index = 0; index < entryCount; index++)
    {
        // Key first, then value: this is the order in which the pair is read off the wire.
        var name = ReadStringNotNull(ref reader, $"payloads[{index}].key");
        result.Add(name, ReadBytes(ref reader, "payloads[{0}].value", index));
    }

    return result;
}
/// <summary>
/// Reads a MessagePack map whose keys are non-null header names and whose values are arrays of
/// (possibly null) strings.
/// </summary>
/// <param name="reader">Reader positioned at the map header of the headers section.</param>
/// <returns>
/// A dictionary with ordinal case-insensitive keys holding one entry per header, or the shared
/// <see cref="EmptyStringValuesDictionaryIgnoreCase"/> when the map has no entries.
/// </returns>
/// <exception cref="InvalidDataException">
/// The map header, a key, an array header, or a value cannot be read, or a key is null.
/// </exception>
/// <remarks>
/// Entries are inserted with <c>Dictionary.Add</c>, so two keys that differ only by case are
/// rejected by the dictionary itself rather than by this method.
/// </remarks>
internal static IDictionary<string, StringValues> ReadHeaders(ref MessagePackReader reader)
{
    var entryCount = ReadMapLength(ref reader, "headers");
    if (entryCount <= 0)
    {
        return EmptyStringValuesDictionaryIgnoreCase;
    }

    var result = new Dictionary<string, StringValues>((int)entryCount, StringComparer.OrdinalIgnoreCase);
    for (var h = 0; h < entryCount; h++)
    {
        var name = ReadStringNotNull(ref reader, $"headers[{h}].key");

        // Each header value is an array of strings; size the buffer from the array header.
        var values = new string?[ReadArrayLength(ref reader, $"headers[{h}].value.length")];
        for (var v = 0; v < values.Length; v++)
        {
            values[v] = ReadString(ref reader, $"headers[{h}].value[{v}]");
        }

        result.Add(name, values);
    }

    return result;
}
/// <summary>
/// Reads the next value from <paramref name="reader"/> as a Boolean.
/// </summary>
/// <param name="reader">Reader to consume from.</param>
/// <param name="field">Field name used in the error message when the read fails.</param>
/// <returns>The Boolean value that was read.</returns>
/// <exception cref="InvalidDataException">The read threw; the original exception is attached as the inner exception.</exception>
internal static bool ReadBoolean(ref MessagePackReader reader, string field)
{
    bool parsed;
    try
    {
        parsed = reader.ReadBoolean();
    }
    catch (Exception inner)
    {
        throw new InvalidDataException($"Reading '{field}' as Boolean failed.", inner);
    }

    return parsed;
}
/// <summary>
/// Reads the next value from <paramref name="reader"/> as a 32-bit integer.
/// </summary>
/// <param name="reader">Reader to consume from.</param>
/// <param name="field">Field name used in the error message when the read fails.</param>
/// <returns>The integer that was read.</returns>
/// <exception cref="InvalidDataException">The read threw; the original exception is attached as the inner exception.</exception>
internal static int ReadInt32(ref MessagePackReader reader, string field)
{
    int parsed;
    try
    {
        parsed = reader.ReadInt32();
    }
    catch (Exception inner)
    {
        throw new InvalidDataException($"Reading '{field}' as Int32 failed.", inner);
    }

    return parsed;
}
/// <summary>
/// Reads the next value from <paramref name="reader"/> as a string.
/// </summary>
/// <param name="reader">Reader to consume from.</param>
/// <param name="field">Field name used in the error message when the read fails.</param>
/// <returns>The string that was read; may be <c>null</c>.</returns>
/// <exception cref="InvalidDataException">The read threw; the original exception is attached as the inner exception.</exception>
internal static string? ReadString(ref MessagePackReader reader, string field)
{
    string? text;
    try
    {
        text = reader.ReadString();
    }
    catch (Exception inner)
    {
        throw new InvalidDataException($"Reading '{field}' as String failed.", inner);
    }

    return text;
}
/// <summary>
/// Reads the next value from <paramref name="reader"/> as a string and rejects <c>null</c>.
/// </summary>
/// <param name="reader">Reader to consume from.</param>
/// <param name="field">Field name used in the error message when the read fails or yields <c>null</c>.</param>
/// <returns>The non-null string that was read.</returns>
/// <exception cref="InvalidDataException">The read threw, or the value read was <c>null</c>.</exception>
internal static string ReadStringNotNull(ref MessagePackReader reader, string field)
{
    // The two-argument ReadString already wraps read failures in InvalidDataException with the
    // same "as String failed" message; only the null check is added here.
    return ReadString(ref reader, field)
        ?? throw new InvalidDataException($"Reading '{field}' as Not-Null String failed.");
}
/// <summary>
/// Reads the next value from <paramref name="reader"/> as a string, building the field name for
/// the error message only if the read fails.
/// </summary>
/// <param name="reader">Reader to consume from.</param>
/// <param name="formatField">Composite format string for the field name, e.g. <c>"claims[{0}].Type"</c>.</param>
/// <param name="param">Value substituted for <c>{0}</c> in <paramref name="formatField"/>.</param>
/// <returns>The string that was read; may be <c>null</c>.</returns>
/// <exception cref="InvalidDataException">The read threw; the original exception is attached as the inner exception.</exception>
internal static string? ReadString(ref MessagePackReader reader, string formatField, int param)
{
    try
    {
        return reader.ReadString();
    }
    catch (Exception inner)
    {
        // Formatting is deferred to the failure path so successful reads do not allocate a name.
        var fieldName = string.Format(formatField, param);
        throw new InvalidDataException($"Reading '{fieldName}' as String failed.", inner);
    }
}
/// <summary>
/// Reads a MessagePack array of strings and returns its non-null elements in their original order.
/// </summary>
/// <param name="reader">Reader positioned at the array header.</param>
/// <param name="field">Field name used in error messages; element errors use <c>field[index]</c>.</param>
/// <returns>
/// An array containing every non-null element that was read; empty when the array has no
/// elements or all of them are <c>null</c>.
/// </returns>
/// <exception cref="InvalidDataException">The array header or an element cannot be read.</exception>
internal static string[] ReadStringArrayExcludeNull(ref MessagePackReader reader, string field)
{
    var declaredLength = ReadArrayLength(ref reader, field);
    if (declaredLength <= 0)
    {
        return [];
    }

    var kept = new string[declaredLength];
    var keptCount = 0;
    for (var index = 0; index < declaredLength; index++)
    {
        if (ReadString(ref reader, $"{field}[{index}]") is { } item)
        {
            kept[keptCount++] = item;
        }
    }

    // Null elements leave unused slots at the tail; shrink to the populated prefix.
    if (keptCount != declaredLength)
    {
        Array.Resize(ref kept, keptCount);
    }

    return kept;
}
/// <summary>
/// Reads the next value from <paramref name="reader"/> as binary data and copies it into a new array.
/// </summary>
/// <param name="reader">Reader to consume from.</param>
/// <param name="field">Field name used in the error message when the read fails.</param>
/// <returns>A copy of the bytes read, or an empty array when the reader yields no sequence.</returns>
/// <exception cref="InvalidDataException">The read threw; the original exception is attached as the inner exception.</exception>
internal static byte[] ReadBytes(ref MessagePackReader reader, string field)
{
    try
    {
        return reader.ReadBytes() is ReadOnlySequence<byte> sequence
            ? sequence.ToArray()
            : Array.Empty<byte>();
    }
    catch (Exception inner)
    {
        throw new InvalidDataException($"Reading '{field}' as Byte[] failed.", inner);
    }
}
/// <summary>
/// Reads the next value from <paramref name="reader"/> as binary data and copies it into a new
/// array, building the field name for the error message only if the read fails.
/// </summary>
/// <param name="reader">Reader to consume from.</param>
/// <param name="formatField">Composite format string for the field name, e.g. <c>"payloads[{0}].value"</c>.</param>
/// <param name="param">Value substituted for <c>{0}</c> in <paramref name="formatField"/>.</param>
/// <returns>A copy of the bytes read, or an empty array when the reader yields no sequence.</returns>
/// <exception cref="InvalidDataException">The read threw; the original exception is attached as the inner exception.</exception>
internal static byte[] ReadBytes(ref MessagePackReader reader, string formatField, int param)
{
    try
    {
        if (reader.ReadBytes() is ReadOnlySequence<byte> sequence)
        {
            return sequence.ToArray();
        }

        return Array.Empty<byte>();
    }
    catch (Exception inner)
    {
        // Formatting is deferred to the failure path so successful reads do not allocate a name.
        var fieldName = string.Format(formatField, param);
        throw new InvalidDataException($"Reading '{fieldName}' as Byte[] failed.", inner);
    }
}
/// <summary>
/// Reads the next value from <paramref name="reader"/> as binary data without copying it into an array.
/// </summary>
/// <param name="reader">Reader to consume from.</param>
/// <param name="field">Field name used in the error message when the read fails.</param>
/// <returns>The sequence returned by the reader, or <c>null</c> when the reader yields none.</returns>
/// <exception cref="InvalidDataException">The read threw; the original exception is attached as the inner exception.</exception>
internal static ReadOnlySequence<byte>? ReadByteSequence(ref MessagePackReader reader, string field)
{
    ReadOnlySequence<byte>? sequence;
    try
    {
        sequence = reader.ReadBytes();
    }
    catch (Exception inner)
    {
        throw new InvalidDataException($"Reading binary sequence for '{field}' failed.", inner);
    }

    return sequence;
}
/// <summary>
/// Reads a MessagePack map header from <paramref name="reader"/> and returns the entry count.
/// </summary>
/// <param name="reader">Reader positioned at a map header.</param>
/// <param name="field">Field name used in the error message when the read fails.</param>
/// <returns>The number of entries declared by the map header.</returns>
/// <exception cref="InvalidDataException">The read threw; the original exception is attached as the inner exception.</exception>
internal static long ReadMapLength(ref MessagePackReader reader, string field)
{
    long entryCount;
    try
    {
        entryCount = reader.ReadMapHeader();
    }
    catch (Exception inner)
    {
        throw new InvalidDataException($"Reading map length for '{field}' failed.", inner);
    }

    return entryCount;
}
/// <summary>
/// Reads a MessagePack array header from <paramref name="reader"/> and returns the element count.
/// </summary>
/// <param name="reader">Reader positioned at an array header.</param>
/// <param name="field">Field name used in the error message when the read fails.</param>
/// <returns>The number of elements declared by the array header.</returns>
/// <exception cref="InvalidDataException">The read threw; the original exception is attached as the inner exception.</exception>
internal static int ReadArrayLength(ref MessagePackReader reader, string field)
{
    int elementCount;
    try
    {
        elementCount = reader.ReadArrayHeader();
    }
    catch (Exception inner)
    {
        throw new InvalidDataException($"Reading array length for '{field}' failed.", inner);
    }

    return elementCount;
}
}