provisioning/transport/mqtt/src/ProvisioningChannelHandlerAdapter.cs (487 lines of code) (raw):
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using DotNetty.Buffers;
using DotNetty.Codecs.Mqtt.Packets;
using DotNetty.Transport.Channels;
using Microsoft.Azure.Devices.Provisioning.Client.Transport.Models;
using Microsoft.Azure.Devices.Shared;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Net;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Azure.Devices.Provisioning.Client.Transport
{
//
// Note on ConfigureAwait: DotNetty uses a custom TaskScheduler that binds Tasks to the corresponding
// EventLoop. To limit I/O to the EventLoopGroup and preserve Netty semantics, awaits in this handler
// should use ConfigureAwait(true) so that task continuations are executed by that scheduler.
//
internal class ProvisioningChannelHandlerAdapter : ChannelHandlerAdapter
{
    private const string ExceptionPrefix = "MQTT Protocol Exception:";
    private const QualityOfService Qos = QualityOfService.AtLeastOnce;
    private const string UsernameFormat = "{0}/registrations/{1}/api-version={2}&ClientVersion={3}";
    private const string SubscribeFilter = "$dps/registrations/res/#";
    private const string RegisterTopic = "$dps/registrations/PUT/iotdps-register/?$rid={0}";
    private const string GetOperationsTopic = "$dps/registrations/GET/iotdps-get-operationstatus/?$rid={0}&operationId={1}";

    // Matches response topics; capture group 1 is parsed as the HTTP status code of the response.
    private static readonly Regex s_registrationStatusTopicRegex = new Regex("^\\$dps/registrations/res/(.*?)/\\?\\$rid=(.*?)$", RegexOptions.Compiled);
    private static readonly TimeSpan s_defaultOperationPoolingInterval = TimeSpan.FromSeconds(2);
    private const string Registration = "registration";

    private readonly ProvisioningTransportRegisterMessage _message;
    private readonly TaskCompletionSource<RegistrationOperationStatus> _taskCompletionSource;
    private readonly CancellationToken _cancellationToken;

    // Current State, stored as an int so it can be used with Volatile/Interlocked.
    private int _state;

    // Last MQTT packet id handed out by GetNextPacketId.
    private int _packetId;

    internal enum State
    {
        Start,
        Failed,
        WaitForConnack,
        WaitForSuback,
        WaitForPubAck,
        WaitForStatus,
        Done,
    }

    /// <summary>
    /// Creates a handler that drives one DPS registration over MQTT and reports the outcome
    /// (result, exception or cancellation) through <paramref name="taskCompletionSource"/>.
    /// </summary>
    public ProvisioningChannelHandlerAdapter(
        ProvisioningTransportRegisterMessage message,
        TaskCompletionSource<RegistrationOperationStatus> taskCompletionSource,
        CancellationToken cancellationToken)
    {
        _message = message;
        _taskCompletionSource = taskCompletionSource;
        _cancellationToken = cancellationToken;
        ForceState(State.Start);
    }

    #region DotNetty.ChannelHandlerAdapter overrides

    /// <summary>
    /// Sends the MQTT CONNECT packet once the channel becomes active, unless the operation
    /// was already cancelled or failed.
    /// </summary>
    public override async void ChannelActive(IChannelHandlerContext context)
    {
        if (Logging.IsEnabled)
            Logging.Enter(this, context.Name, nameof(ChannelActive));

        // async void: every exception must be handled here or it escapes unobserved.
        try
        {
            if (!await VerifyCancellationAsync(context).ConfigureAwait(true)
                && ChangeState(State.Start, State.WaitForConnack))
            {
                await ConnectAsync(context).ConfigureAwait(true);
            }
        }
        catch (Exception ex)
        {
            await FailWithExceptionAsync(context, UnwrapAggregate(ex)).ConfigureAwait(true);
        }

        base.ChannelActive(context);

        if (Logging.IsEnabled)
            Logging.Exit(this, context.Name, nameof(ChannelActive));
    }

    /// <summary>
    /// Fails the operation when the channel closes. This is a no-op on the task if it has already completed.
    /// </summary>
    public override async void ChannelInactive(IChannelHandlerContext context)
    {
        if (Logging.IsEnabled)
            Logging.Enter(this, context.Name, nameof(ChannelInactive));

        base.ChannelInactive(context);

        await FailWithExceptionAsync(
            context,
            new ProvisioningTransportException($"{ExceptionPrefix} Channel closed."))
            .ConfigureAwait(true);

        if (Logging.IsEnabled)
            Logging.Exit(this, context.Name, nameof(ChannelInactive));
    }

    /// <summary>
    /// Feeds each inbound MQTT packet into the registration state machine, then forwards it
    /// down the pipeline.
    /// </summary>
    public override async void ChannelRead(IChannelHandlerContext context, object message)
    {
        if (Logging.IsEnabled)
            Logging.Enter(this, context.Name, nameof(ChannelRead));

        // async void: every exception must be handled here or it escapes unobserved.
        try
        {
            // Once cancelled, the state is Failed and the channel is closing: do not process further packets.
            if (!await VerifyCancellationAsync(context).ConfigureAwait(true))
            {
                if (message is Packet packet)
                {
                    await ProcessMessageAsync(context, packet).ConfigureAwait(true);
                }
                else
                {
                    await FailWithExceptionAsync(
                        context,
                        new ProvisioningTransportException(
                            $"{ExceptionPrefix} Received unexpected message of type {message?.GetType().Name ?? "null"}"))
                        .ConfigureAwait(true);
                }
            }
        }
        catch (Exception ex)
        {
            await FailWithExceptionAsync(context, UnwrapAggregate(ex)).ConfigureAwait(true);
        }

        base.ChannelRead(context, message);

        if (Logging.IsEnabled)
            Logging.Exit(this, context.Name, nameof(ChannelRead));
    }

    public override async void ChannelReadComplete(IChannelHandlerContext context)
    {
        if (Logging.IsEnabled)
            Logging.Enter(this, context.Name, nameof(ChannelReadComplete));

        await VerifyCancellationAsync(context).ConfigureAwait(true);
        base.ChannelReadComplete(context);

        if (Logging.IsEnabled)
            Logging.Exit(this, context.Name, nameof(ChannelReadComplete));
    }

    public override async void ExceptionCaught(IChannelHandlerContext context, Exception exception)
    {
        if (Logging.IsEnabled)
            Logging.Enter(this, context.Name, nameof(ExceptionCaught));

        base.ExceptionCaught(context, exception);
        await FailWithExceptionAsync(context, exception).ConfigureAwait(true);

        if (Logging.IsEnabled)
            Logging.Exit(this, context.Name, nameof(ExceptionCaught));
    }

    #endregion DotNetty.ChannelHandlerAdapter overrides

    /// <summary>
    /// Sends the MQTT CONNECT packet. Symmetric-key security providers get a SAS token as password.
    /// </summary>
    private Task ConnectAsync(IChannelHandlerContext context)
    {
        string registrationId = _message.Security.GetRegistrationID();
        string userAgent = _message.ProductInfo;

        string password = null;
        if (_message.Security is SecurityProviderSymmetricKey key1)
        {
            string key = key1.GetPrimaryKey();
            password = ProvisioningSasBuilder.BuildSasSignature(Registration, key, string.Concat(_message.IdScope, '/', "registrations", '/', registrationId), TimeSpan.FromDays(1));
        }

        var message = new ConnectPacket
        {
            CleanSession = true,
            ClientId = registrationId,
            HasWill = false,
            HasUsername = true,
            Username = string.Format(
                CultureInfo.InvariantCulture,
                UsernameFormat,
                _message.IdScope,
                registrationId,
                ClientApiVersionHelper.ApiVersion,
                Uri.EscapeDataString(userAgent)),
            HasPassword = password != null,
            Password = password,
        };

        return context.WriteAndFlushAsync(message);
    }

    /// <summary>
    /// Dispatches an inbound packet according to the current state. A packet of the wrong type
    /// fails the operation and is never cast to the expected packet class.
    /// </summary>
    private async Task ProcessMessageAsync(IChannelHandlerContext context, Packet message)
    {
        var currentState = (State)Volatile.Read(ref _state);
        switch (currentState)
        {
            case State.Start:
                Debug.Fail($"{nameof(ProvisioningChannelHandlerAdapter)}: Invalid state: {nameof(State.Start)}");
                break;

            case State.Done:
            case State.Failed:
                // The channel is being closed; packets already in flight can still arrive and are ignored.
                if (Logging.IsEnabled)
                    Logging.Info(this, $"Ignoring {message.PacketType} received in state {currentState}.", nameof(ProcessMessageAsync));
                break;

            case State.WaitForConnack:
                if (await VerifyExpectedPacketTypeAsync(context, PacketType.CONNACK, message).ConfigureAwait(true))
                {
                    await ProcessConnAckAsync(context, (ConnAckPacket)message).ConfigureAwait(true);
                }
                break;

            case State.WaitForSuback:
                if (await VerifyExpectedPacketTypeAsync(context, PacketType.SUBACK, message).ConfigureAwait(true))
                {
                    await ProcessSubAckAsync(context, (SubAckPacket)message).ConfigureAwait(true);
                }
                break;

            case State.WaitForPubAck:
                // Verify before advancing so that an unexpected packet does not move the state machine forward.
                if (await VerifyExpectedPacketTypeAsync(context, PacketType.PUBACK, message).ConfigureAwait(true))
                {
                    ChangeState(State.WaitForPubAck, State.WaitForStatus);
                }
                break;

            case State.WaitForStatus:
                if (await VerifyExpectedPacketTypeAsync(context, PacketType.PUBLISH, message).ConfigureAwait(true))
                {
                    await ProcessRegistrationStatusAsync(context, (PublishPacket)message).ConfigureAwait(true);
                }
                break;

            default:
                await FailWithExceptionAsync(
                    context,
                    new ProvisioningTransportException(
                        $"{ExceptionPrefix} Invalid state: {currentState}")).ConfigureAwait(true);
                break;
        }
    }

    /// <summary>
    /// Handles CONNACK. On Accepted, subscribes to the response topic; otherwise fails the operation.
    /// RefusedServerUnavailable is reported as transient.
    /// </summary>
    private async Task ProcessConnAckAsync(IChannelHandlerContext context, ConnAckPacket packet)
    {
        if (packet.SessionPresent)
        {
            await FailWithExceptionAsync(
                context,
                new ProvisioningTransportException(
                    $"{ExceptionPrefix} Unexpected CONNACK with SessionPresent."))
                .ConfigureAwait(true);

            // The operation has failed; do not go on to subscribe.
            return;
        }

        switch (packet.ReturnCode)
        {
            case ConnectReturnCode.Accepted:
                try
                {
                    if (ChangeState(State.WaitForConnack, State.WaitForSuback))
                    {
                        await SubscribeAsync(context).ConfigureAwait(true);
                    }
                }
                catch (Exception ex)
                {
                    await FailWithExceptionAsync(context, UnwrapAggregate(ex)).ConfigureAwait(true);
                }
                break;

            case ConnectReturnCode.RefusedUnacceptableProtocolVersion:
            case ConnectReturnCode.RefusedIdentifierRejected:
            case ConnectReturnCode.RefusedBadUsernameOrPassword:
            case ConnectReturnCode.RefusedNotAuthorized:
                await FailWithExceptionAsync(
                    context,
                    new ProvisioningTransportException(
                        $"{ExceptionPrefix} CONNACK failed with {packet.ReturnCode}"))
                    .ConfigureAwait(true);
                break;

            case ConnectReturnCode.RefusedServerUnavailable:
                await FailWithExceptionAsync(
                    context,
                    new ProvisioningTransportException(
                        $"{ExceptionPrefix} CONNACK failed with {packet.ReturnCode}. Try again later.",
                        null,
                        true))
                    .ConfigureAwait(true);
                break;

            default:
                await FailWithExceptionAsync(
                    context,
                    new ProvisioningTransportException(
                        $"{ExceptionPrefix} CONNACK failed unknown return code: {packet.ReturnCode}"))
                    .ConfigureAwait(true);
                break;
        }
    }

    /// <summary>
    /// Subscribes to the DPS response topic filter with a fresh packet id.
    /// </summary>
    private Task SubscribeAsync(IChannelHandlerContext context)
    {
        var message = new SubscribePacket(GetNextPacketId(), new SubscriptionRequest(SubscribeFilter, Qos));
        return context.WriteAndFlushAsync(message);
    }

    /// <summary>
    /// Handles SUBACK for the subscription sent by SubscribeAsync, then publishes the register request.
    /// A mismatched packet id or a rejected subscription fails the operation instead of leaving it
    /// waiting for responses that will never arrive.
    /// </summary>
    private async Task ProcessSubAckAsync(IChannelHandlerContext context, SubAckPacket packet)
    {
        ushort expectedPacketId = GetCurrentPacketId();
        if (packet.PacketId != expectedPacketId)
        {
            await FailWithExceptionAsync(
                context,
                new ProvisioningTransportException(
                    $"{ExceptionPrefix} Unexpected SUBACK packet id {packet.PacketId}; expecting {expectedPacketId}"))
                .ConfigureAwait(true);
            return;
        }

        if (packet.ReturnCodes != null)
        {
            foreach (QualityOfService returnCode in packet.ReturnCodes)
            {
                if (returnCode == QualityOfService.Failure)
                {
                    await FailWithExceptionAsync(
                        context,
                        new ProvisioningTransportException(
                            $"{ExceptionPrefix} SUBACK rejected the subscription to {SubscribeFilter}"))
                        .ConfigureAwait(true);
                    return;
                }
            }
        }

        try
        {
            if (ChangeState(State.WaitForSuback, State.WaitForPubAck))
            {
                await PublishRegisterAsync(context).ConfigureAwait(true);
            }
        }
        catch (Exception ex)
        {
            await FailWithExceptionAsync(context, UnwrapAggregate(ex)).ConfigureAwait(true);
        }
    }

    /// <summary>
    /// Publishes the register request. When a custom payload is set, it is wrapped in a
    /// DeviceRegistration and sent as UTF-8 JSON.
    /// </summary>
    private async Task PublishRegisterAsync(IChannelHandlerContext context)
    {
        IByteBuffer packagePayload = Unpooled.Empty;

        if (_message.Payload != null && _message.Payload.Length > 0)
        {
            var deviceRegistration = new DeviceRegistration { Payload = new JRaw(_message.Payload) };
            byte[] payloadBytes = Encoding.UTF8.GetBytes(
                JsonConvert.SerializeObject(deviceRegistration, JsonSerializerSettingsInitializer.GetJsonSerializerSettings()));

            packagePayload = context.Channel.Allocator.Buffer(payloadBytes.Length, payloadBytes.Length);
            packagePayload.WriteBytes(payloadBytes);
        }

        int packetId = GetNextPacketId();
        var message = new PublishPacket(Qos, false, false)
        {
            TopicName = string.Format(CultureInfo.InvariantCulture, RegisterTopic, packetId),
            PacketId = packetId,
            Payload = packagePayload,
        };

        // ConfigureAwait(true): keep the continuation on the DotNetty event loop scheduler.
        await context.WriteAndFlushAsync(message).ConfigureAwait(true);
    }

    /// <summary>
    /// Inspects the status code embedded in a response topic.
    /// </summary>
    /// <returns>
    /// true when processing may continue. false when the topic carried an error status (>= 400)
    /// and the operation has been failed; 5xx and 429 are reported as transient, with a retry-after.
    /// </returns>
    private async Task<bool> VerifyPublishPacketTopicAsync(IChannelHandlerContext context, string topicName, string jsonData)
    {
        Match match = s_registrationStatusTopicRegex.Match(topicName);

        // Groups.Count is fixed by the pattern, so Success is the correct "did it match" check.
        if (!match.Success
            || !Enum.TryParse(match.Groups[1].Value, out HttpStatusCode statusCode)
            || statusCode < HttpStatusCode.BadRequest)
        {
            return true;
        }

        ProvisioningErrorDetailsMqtt errorDetails = JsonConvert.DeserializeObject<ProvisioningErrorDetailsMqtt>(jsonData, JsonSerializerSettingsInitializer.GetJsonSerializerSettings());
        bool isTransient = statusCode >= HttpStatusCode.InternalServerError || (int)statusCode == 429;
        if (isTransient)
        {
            errorDetails.RetryAfter = ProvisioningErrorDetailsMqtt.GetRetryAfterFromTopic(topicName, s_defaultOperationPoolingInterval);
        }

        // ConfigureAwait(true): keep the continuation on the DotNetty event loop scheduler.
        await FailWithExceptionAsync(
            context,
            new ProvisioningTransportException(
                jsonData,
                null,
                isTransient,
                errorDetails))
            .ConfigureAwait(true);

        return false;
    }

    /// <summary>
    /// Handles a registration status PUBLISH. The packet is acknowledged first. While the status is
    /// assigning or unassigned, the handler waits and polls again; any other status completes the
    /// operation and disconnects.
    /// </summary>
    private async Task ProcessRegistrationStatusAsync(IChannelHandlerContext context, PublishPacket packet)
    {
        try // TODO : extract generic method for exception handling.
        {
            await PubAckAsync(context, packet.PacketId).ConfigureAwait(true);

            // Fetch the readable region once rather than once per argument.
            ArraySegment<byte> payload = packet.Payload.GetIoBuffer();
            string jsonData = Encoding.UTF8.GetString(payload.Array, payload.Offset, payload.Count);

            if (!await VerifyPublishPacketTopicAsync(context, packet.TopicName, jsonData).ConfigureAwait(true))
            {
                // The service reported an error and the operation has already been failed.
                return;
            }

            RegistrationOperationStatus operation = JsonConvert.DeserializeObject<RegistrationOperationStatus>(jsonData, JsonSerializerSettingsInitializer.GetJsonSerializerSettings());
            string operationId = operation.OperationId;
            operation.RetryAfter = ProvisioningErrorDetailsMqtt.GetRetryAfterFromTopic(packet.TopicName, s_defaultOperationPoolingInterval);

            if (string.CompareOrdinal(operation.Status, RegistrationOperationStatus.OperationStatusAssigning) == 0 ||
                string.CompareOrdinal(operation.Status, RegistrationOperationStatus.OperationStatusUnassigned) == 0)
            {
                var pollingDelay = operation.RetryAfter ?? RetryJitter.GenerateDelayWithJitterForRetry(s_defaultOperationPoolingInterval);
                try
                {
                    // Honor cancellation while waiting to poll, instead of only after the delay elapses.
                    await Task.Delay(pollingDelay, _cancellationToken).ConfigureAwait(true);
                }
                catch (OperationCanceledException) when (_cancellationToken.IsCancellationRequested)
                {
                    await VerifyCancellationAsync(context).ConfigureAwait(true);
                    return;
                }

                if (ChangeState(State.WaitForStatus, State.WaitForPubAck))
                {
                    await PublishGetOperationAsync(context, operationId).ConfigureAwait(true);
                }
            }
            else
            {
                if (ChangeState(State.WaitForStatus, State.Done))
                {
                    _taskCompletionSource.TrySetResult(operation);
                }

                // Disconnect in either case; on a failed transition ChangeState has already faulted the task.
                await DoneAsync(context).ConfigureAwait(true);
            }
        }
        catch (ProvisioningTransportException te)
        {
            await FailWithExceptionAsync(context, te).ConfigureAwait(true);
        }
        catch (Exception ex)
        {
            var wrapperEx = new ProvisioningTransportException(
                $"{ExceptionPrefix} Error while processing RegistrationStatus.",
                ex,
                false);

            await FailWithExceptionAsync(context, wrapperEx).ConfigureAwait(true);
        }
    }

    private static Task PubAckAsync(IChannelHandlerContext context, int packetId)
    {
        var message = new PubAckPacket
        {
            PacketId = packetId,
        };

        return context.WriteAndFlushAsync(message);
    }

    /// <summary>
    /// Publishes an operation-status query for <paramref name="operationId"/> with an empty payload.
    /// </summary>
    private Task PublishGetOperationAsync(IChannelHandlerContext context, string operationId)
    {
        int packetId = GetNextPacketId();
        var message = new PublishPacket(Qos, false, false)
        {
            TopicName = string.Format(CultureInfo.InvariantCulture, GetOperationsTopic, packetId, operationId),
            PacketId = packetId,
            Payload = Unpooled.Empty,
        };

        return context.WriteAndFlushAsync(message);
    }

    /// <summary>
    /// Checks the packet type against the one expected in the current state.
    /// </summary>
    /// <returns>true if it matches; false after failing the operation.</returns>
    private async Task<bool> VerifyExpectedPacketTypeAsync(IChannelHandlerContext context, PacketType expectedPacketType, Packet message)
    {
        if (message.PacketType == expectedPacketType)
        {
            return true;
        }

        await FailWithExceptionAsync(
            context,
            new ProvisioningTransportException(
                $"{ExceptionPrefix} Received unexpected packet type {message.PacketType} in state {(State)Volatile.Read(ref _state)}"))
            .ConfigureAwait(true);

        return false;
    }

    /// <summary>
    /// Moves to Failed, faults the task with <paramref name="ex"/> and closes the channel.
    /// Ignored (logged only) if the state is already Failed.
    /// </summary>
    private async Task FailWithExceptionAsync(IChannelHandlerContext context, Exception ex)
    {
        if (Volatile.Read(ref _state) != (int)State.Failed)
        {
            if (Logging.IsEnabled)
                Logging.Error(this, $"Failing with Exception: {ex}", nameof(FailWithExceptionAsync));

            ForceState(State.Failed);
            _taskCompletionSource.TrySetException(ex);
            await context.CloseAsync().ConfigureAwait(true);
        }
        else
        {
            if (Logging.IsEnabled)
                Logging.Error(this, $"Ignoring Exception: {ex}", nameof(FailWithExceptionAsync));
        }
    }

    /// <summary>
    /// If cancellation was requested, moves to Failed, cancels the task and closes the channel
    /// (the first time only).
    /// </summary>
    /// <returns>true if cancellation has been requested, so the caller should stop processing.</returns>
    private async Task<bool> VerifyCancellationAsync(IChannelHandlerContext context)
    {
        if (!_cancellationToken.IsCancellationRequested)
        {
            return false;
        }

        if (Volatile.Read(ref _state) != (int)State.Failed)
        {
            if (Logging.IsEnabled)
                Logging.Error(this, "CancellationRequested", nameof(VerifyCancellationAsync));

            ForceState(State.Failed);
            _taskCompletionSource.TrySetCanceled(_cancellationToken);
            await context.CloseAsync().ConfigureAwait(true);
        }

        return true;
    }

    /// <summary>
    /// Atomically transitions from <paramref name="expectedCurrentState"/> to <paramref name="newState"/>.
    /// If the current state differs, forces Failed and faults the task instead.
    /// </summary>
    /// <returns>true if the transition happened; callers must not continue on false.</returns>
    private bool ChangeState(State expectedCurrentState, State newState)
    {
        if (Logging.IsEnabled)
            Logging.Info(this, $"{expectedCurrentState} -> {newState}", nameof(ChangeState));

        int currentState = Interlocked.CompareExchange(ref _state, (int)newState, (int)expectedCurrentState);
        if (currentState == (int)expectedCurrentState)
        {
            return true;
        }

        string newStateString = Enum.GetName(typeof(State), newState);
        string currentStateString = Enum.GetName(typeof(State), currentState);
        string expectedStateString = Enum.GetName(typeof(State), expectedCurrentState);

        var exception = new ProvisioningTransportException(
            $"{ExceptionPrefix} Unexpected state transition to {newStateString} from {currentStateString}. " +
            $"Expecting {expectedStateString}");

        ForceState(State.Failed);
        _taskCompletionSource.TrySetException(exception);
        return false;
    }

    private void ForceState(State newState)
    {
        if (Logging.IsEnabled)
        {
            Logging.Info(this, $"{(State)_state} -> {newState}", nameof(ForceState));
        }

        Volatile.Write(ref _state, (int)newState);
    }

    /// <summary>
    /// Sends DISCONNECT, waits briefly, then closes the channel. Failures in either step are logged
    /// and routed through FailWithExceptionAsync.
    /// </summary>
    private async Task DoneAsync(IChannelHandlerContext context)
    {
        if (Logging.IsEnabled)
            Logging.Enter(this, context.Name, nameof(DoneAsync));

        try
        {
            await context.Channel.WriteAndFlushAsync(DisconnectPacket.Instance).ConfigureAwait(true);
        }
        catch (Exception e)
        {
            if (Logging.IsEnabled)
                Logging.Info(this, $"Exception trying to send disconnect packet: {e}", nameof(DoneAsync));

            await FailWithExceptionAsync(context, e).ConfigureAwait(true);
        }

        // This delay is required to work-around a .NET Framework CloseAsync bug.
        if (Logging.IsEnabled)
            Logging.Info(this, "Applying close channel delay.", nameof(DoneAsync));

        await Task.Delay(TimeSpan.FromMilliseconds(400)).ConfigureAwait(true);

        if (Logging.IsEnabled)
            Logging.Info(this, "Closing channel.", nameof(DoneAsync));

        try
        {
            await context.Channel.CloseAsync().ConfigureAwait(true);
        }
        catch (Exception e)
        {
            if (Logging.IsEnabled)
                Logging.Info(this, $"Exception trying to close channel: {e}", nameof(DoneAsync));

            await FailWithExceptionAsync(context, e).ConfigureAwait(true);
        }

        if (Logging.IsEnabled)
            Logging.Exit(this, context.Name, nameof(DoneAsync));
    }

    /// <summary>
    /// Returns the next non-zero 16-bit packet id. The counter wraps, and zero is skipped
    /// because 0 is not a valid MQTT packet id.
    /// </summary>
    private ushort GetNextPacketId()
    {
        unchecked
        {
            ushort newIdShort;
            do
            {
                newIdShort = (ushort)Interlocked.Increment(ref _packetId);
            }
            while (newIdShort == 0);

            return newIdShort;
        }
    }

    private ushort GetCurrentPacketId()
    {
        unchecked
        {
            return (ushort)Volatile.Read(ref _packetId);
        }
    }

    /// <summary>
    /// Returns the inner exception of an AggregateException (when present), otherwise the exception itself.
    /// </summary>
    private static Exception UnwrapAggregate(Exception ex)
    {
        return ex is AggregateException aggregate && aggregate.InnerException != null
            ? aggregate.InnerException
            : ex;
    }
}
}