src/Microsoft.Azure.SignalR.Common/Utilities/ConnectionStringParser.cs (188 lines of code) (raw):
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Text.RegularExpressions;
using Azure.Core;
using Azure.Identity;
namespace Microsoft.Azure.SignalR;
#nullable enable
internal static class ConnectionStringParser
{
// Names of the connection-string properties this parser reads. Lookups go through the
// dictionary built by ToDictionary, which uses StringComparer.OrdinalIgnoreCase, so the
// casing of these names does not affect matching.
private const string AccessKeyProperty = "accesskey";
private const string AuthTypeProperty = "authtype";
private const string ClientCertProperty = "clientCert";
private const string ClientIdProperty = "clientId";
private const string ClientSecretProperty = "clientSecret";
private const string EndpointProperty = "endpoint";
private const string ClientEndpointProperty = "clientEndpoint";
private const string ServerEndpointProperty = "serverEndpoint";
// Composite format string; {0} receives the rejected version value.
private const string InvalidVersionValueFormat = "Version {0} is not supported.";
private const string PortProperty = "port";
// For SDK 1.x, only support Azure SignalR Service 1.x
private const string SupportedVersion = "1";
private const string TenantIdProperty = "tenantId";
// Values of the authtype property, compared against the lower-cased property value in Parse.
// "aad" is obsolete but still accepted; Parse routes it to BuildAzureTokenCredential.
[Obsolete]
private const string TypeAzureAD = "aad";
private const string TypeAzureApp = "azure.app";
private const string TypeAzureMsi = "azure.msi";
// Accepts "<SupportedVersion>.<digits>", optionally followed by a run of word characters, '-' or '.'.
private const string ValidVersionRegex = "^" + SupportedVersion + @"\.\d+(?:[\w-.]+)?$";
private const string VersionProperty = "version";
// Error messages used for the ArgumentExceptions thrown while parsing.
private const string InvalidEndpointProperty = $"Invalid value for {EndpointProperty} property, it must be a valid URI.";
private const string InvalidClientEndpointProperty = $"Invalid value for {ClientEndpointProperty} property, it must be a valid URI.";
private const string InvalidServerEndpointProperty = $"Invalid value for {ServerEndpointProperty} property, it must be a valid URI.";
// The check in Parse accepts ports 1 through 65535 inclusive, i.e. the open interval (0, 65536).
// NOTE(review): the message reads "an positive integer"; the text is left as-is because
// callers or tests may match on it - confirm before correcting the grammar.
private const string InvalidPortValue = $"Invalid value for {PortProperty} property, it must be an positive integer between (0, 65536).";
// Separates a property name from its value; ToDictionary splits on the first occurrence only.
private static readonly char[] KeyValueSeparator = { '=' };
private const string MissingClientIdProperty =
$"Connection string missing required properties {ClientIdProperty}.";
// Used when neither clientSecret nor clientCert is supplied for an application credential.
private const string MissingClientSecretProperty =
$"Connection string missing required properties {ClientSecretProperty} or {ClientCertProperty}.";
private const string MissingEndpointProperty =
$"Connection string missing required properties {EndpointProperty}.";
private const string MissingTenantIdProperty =
$"Connection string missing required properties {TenantIdProperty}.";
// Separates one "key=value" property from the next.
private static readonly char[] PropertySeparator = { ';' };
/// <summary>
/// Parses a connection string into a <see cref="ParsedConnectionString"/>.
/// </summary>
/// <remarks>
/// Properties read here (names are matched case-insensitively): endpoint (required), version,
/// port, clientEndpoint, serverEndpoint, authtype and accesskey. The credential builders
/// additionally read clientId, tenantId, clientSecret and clientCert.
/// </remarks>
/// <param name="connectionString">Semicolon-separated list of <c>key=value</c> properties.</param>
/// <returns>
/// The parsed result carrying the endpoint URI (with the port override applied, if any), the
/// token credential selected by authtype, and the optional access key and client/server endpoints.
/// </returns>
/// <exception cref="ArgumentException">
/// The endpoint is missing or is not an absolute http/https URI; the version does not match the
/// supported pattern; the port is not an integer in 1..65535; clientEndpoint or serverEndpoint is
/// not an absolute http/https URI; a property is duplicated; or the selected auth type is missing
/// a property it requires.
/// </exception>
internal static ParsedConnectionString Parse(string connectionString)
{
    var dict = ToDictionary(connectionString);

    // parse and validate endpoint.
    if (!dict.TryGetValue(EndpointProperty, out var endpoint))
    {
        throw new ArgumentException(MissingEndpointProperty, nameof(connectionString));
    }

    // Trailing slashes are dropped before the value is validated.
    endpoint = endpoint.TrimEnd('/');
    if (!TryCreateEndpointUri(endpoint, out var endpointUri))
    {
        throw new ArgumentException(InvalidEndpointProperty, nameof(connectionString));
    }
    var builder = new UriBuilder(endpointUri!);

    // parse and validate version.
    if (dict.TryGetValue(VersionProperty, out var version))
    {
        if (!Regex.IsMatch(version, ValidVersionRegex))
        {
            throw new ArgumentException(string.Format(CultureInfo.InvariantCulture, InvalidVersionValueFormat, version), nameof(connectionString));
        }
    }

    // parse and validate port. An explicit port property overrides the port of the endpoint URI.
    if (dict.TryGetValue(PortProperty, out var s))
    {
        // The connection string is machine-readable, so parse with the invariant culture
        // instead of whatever culture the current thread happens to run under.
        builder.Port = int.TryParse(s, NumberStyles.Integer, CultureInfo.InvariantCulture, out var port) && port > 0 && port <= 0xFFFF
            ? port
            : throw new ArgumentException(InvalidPortValue, nameof(connectionString));
    }

    Uri? clientEndpointUri = null;
    Uri? serverEndpointUri = null;

    // parse and validate clientEndpoint.
    if (dict.TryGetValue(ClientEndpointProperty, out var clientEndpoint))
    {
        if (!TryCreateEndpointUri(clientEndpoint, out clientEndpointUri))
        {
            throw new ArgumentException(InvalidClientEndpointProperty, nameof(connectionString));
        }
    }

    // parse and validate serverEndpoint.
    if (dict.TryGetValue(ServerEndpointProperty, out var serverEndpoint))
    {
        if (!TryCreateEndpointUri(serverEndpoint, out serverEndpointUri))
        {
            throw new ArgumentException(InvalidServerEndpointProperty, nameof(connectionString));
        }
    }

    // Select the token credential from the auth type. The value is lower-cased with the
    // invariant culture: a culture-sensitive ToLower() maps 'I' to a dotless 'ı' under Turkish
    // cultures, so "Azure.MSI" would no longer equal "azure.msi" and would silently fall
    // through to the default credential.
    dict.TryGetValue(AuthTypeProperty, out var type);
    var tokenCredential = type?.ToLowerInvariant() switch
    {
        TypeAzureApp => BuildApplicationCredential(dict),
        TypeAzureMsi => BuildManagedIdentityCredential(dict),
#pragma warning disable CS0612 // Type or member is obsolete
        TypeAzureAD => BuildAzureTokenCredential(dict),
#pragma warning restore CS0612 // Type or member is obsolete
        _ => new DefaultAzureCredential(),
    };

    // The access key is optional and is returned alongside the token credential.
    dict.TryGetValue(AccessKeyProperty, out var accessKey);

    return new ParsedConnectionString(builder.Uri, tokenCredential)
    {
        AccessKey = accessKey,
        ClientEndpoint = clientEndpointUri,
        ServerEndpoint = serverEndpointUri
    };
}
/// <summary>
/// Tries to interpret <paramref name="endpoint"/> as an absolute URI whose scheme is http or https.
/// </summary>
/// <param name="endpoint">The text to parse.</param>
/// <param name="uriResult">
/// Receives whatever <see cref="Uri.TryCreate(string, UriKind, out Uri)"/> produced; it can be
/// non-null even when this method returns <c>false</c> because the scheme was rejected.
/// </param>
/// <returns><c>true</c> only for an absolute URI with the http or https scheme.</returns>
private static bool TryCreateEndpointUri(string endpoint, out Uri? uriResult)
{
    if (!Uri.TryCreate(endpoint, UriKind.Absolute, out uriResult))
    {
        return false;
    }

    var scheme = uriResult.Scheme;
    var isHttp = scheme == Uri.UriSchemeHttp;
    return isHttp || scheme == Uri.UriSchemeHttps;
}
/// <summary>
/// Builds the credential for the obsolete "aad" auth type from whichever properties are present.
/// </summary>
/// <param name="keyValuePairs">The parsed connection-string properties.</param>
/// <returns>
/// A <see cref="ManagedIdentityCredential"/> when clientId is absent (no client id) or when
/// tenantId is absent (using the given client id); otherwise a <see cref="ClientSecretCredential"/>
/// if clientSecret is present, or a <see cref="ClientCertificateCredential"/> if clientCert is present.
/// </returns>
/// <exception cref="ArgumentException">
/// clientId and tenantId are both present but neither clientSecret nor clientCert is.
/// </exception>
[Obsolete]
private static TokenCredential BuildAzureTokenCredential(Dictionary<string, string> keyValuePairs)
{
    // Without a client id there is nothing to identify an application.
    if (!keyValuePairs.TryGetValue(ClientIdProperty, out var appId))
    {
        return new ManagedIdentityCredential();
    }

    // A client id without a tenant id is passed to the managed identity credential.
    if (!keyValuePairs.TryGetValue(TenantIdProperty, out var tenant))
    {
        return new ManagedIdentityCredential(appId);
    }

    // The secret is checked first, so it wins when both a secret and a certificate are supplied.
    if (keyValuePairs.TryGetValue(ClientSecretProperty, out var secret))
    {
        return new ClientSecretCredential(tenant, appId, secret);
    }

    return keyValuePairs.TryGetValue(ClientCertProperty, out var certPath)
        ? new ClientCertificateCredential(tenant, appId, certPath)
        : throw new ArgumentException(MissingClientSecretProperty, nameof(keyValuePairs));
}
/// <summary>
/// Builds the credential for the "azure.app" auth type, which requires clientId, tenantId and
/// either clientSecret or clientCert.
/// </summary>
/// <param name="keyValuePairs">The parsed connection-string properties.</param>
/// <returns>
/// A <see cref="ClientSecretCredential"/> when clientSecret is present, otherwise a
/// <see cref="ClientCertificateCredential"/> built from clientCert.
/// </returns>
/// <exception cref="ArgumentException">
/// clientId is missing, tenantId is missing, or neither clientSecret nor clientCert is present
/// (checked in that order).
/// </exception>
private static TokenCredential BuildApplicationCredential(Dictionary<string, string> keyValuePairs)
{
    var hasClientId = keyValuePairs.TryGetValue(ClientIdProperty, out var appId);
    if (!hasClientId)
    {
        throw new ArgumentException(MissingClientIdProperty, nameof(keyValuePairs));
    }

    var hasTenantId = keyValuePairs.TryGetValue(TenantIdProperty, out var tenant);
    if (!hasTenantId)
    {
        throw new ArgumentException(MissingTenantIdProperty, nameof(keyValuePairs));
    }

    // The secret is checked first, so it wins when both a secret and a certificate are supplied.
    TokenCredential? credential = null;
    if (keyValuePairs.TryGetValue(ClientSecretProperty, out var secret))
    {
        credential = new ClientSecretCredential(tenant, appId, secret);
    }
    else if (keyValuePairs.TryGetValue(ClientCertProperty, out var certPath))
    {
        credential = new ClientCertificateCredential(tenant, appId, certPath);
    }

    return credential ?? throw new ArgumentException(MissingClientSecretProperty, nameof(keyValuePairs));
}
/// <summary>
/// Builds the credential for the "azure.msi" auth type.
/// </summary>
/// <param name="dict">The parsed connection-string properties.</param>
/// <returns>
/// A <see cref="ManagedIdentityCredential"/> constructed with the clientId property when it is
/// present, and without any argument otherwise.
/// </returns>
private static TokenCredential BuildManagedIdentityCredential(Dictionary<string, string> dict)
{
    if (dict.TryGetValue(ClientIdProperty, out var configuredClientId))
    {
        return new ManagedIdentityCredential(configuredClientId);
    }

    return new ManagedIdentityCredential();
}
/// <summary>
/// Splits a connection string into a case-insensitive map of property name to value.
/// </summary>
/// <param name="connectionString">Semicolon-separated list of <c>key=value</c> properties.</param>
/// <returns>
/// A dictionary keyed with <see cref="StringComparer.OrdinalIgnoreCase"/>; names and values are
/// trimmed. Segments that contain no '=' are ignored.
/// </returns>
/// <exception cref="ArgumentException">
/// Fewer than two non-empty segments are present, or the same property name occurs more than once.
/// </exception>
private static Dictionary<string, string> ToDictionary(string connectionString)
{
    var segments = connectionString.Split(PropertySeparator, StringSplitOptions.RemoveEmptyEntries);

    // Fewer than two segments is reported with the missing-endpoint message.
    if (segments.Length < 2)
    {
        throw new ArgumentException(MissingEndpointProperty, nameof(connectionString));
    }

    var result = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
    foreach (var segment in segments)
    {
        // Split on the first '=' only, so values may themselves contain '='.
        var parts = segment.Split(KeyValueSeparator, 2);
        if (parts.Length == 2)
        {
            var name = parts[0].Trim();
            var value = parts[1].Trim();
            if (result.ContainsKey(name))
            {
                throw new ArgumentException($"Duplicate properties found in connection string: {name}.");
            }

            result.Add(name, value);
        }
    }

    return result;
}
}