iothub/device/src/Transport/Mqtt/MqttTransportHandler.cs (1,089 lines of code) (raw):
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.WebSockets;
using System.Runtime.ExceptionServices;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using System.Web;
using DotNetty.Buffers;
using DotNetty.Codecs.Mqtt;
using DotNetty.Codecs.Mqtt.Packets;
using DotNetty.Handlers.Logging;
using DotNetty.Handlers.Tls;
using DotNetty.Transport.Bootstrapping;
using DotNetty.Transport.Channels;
using DotNetty.Transport.Channels.Sockets;
using Microsoft.Azure.Devices.Client.Exceptions;
using Microsoft.Azure.Devices.Client.Extensions;
using Microsoft.Azure.Devices.Client.TransientFaultHandling;
using Microsoft.Azure.Devices.Shared;
using Newtonsoft.Json;
#if NET5_0_OR_GREATER
using TaskCompletionSource = System.Threading.Tasks.TaskCompletionSource;
#else
using TaskCompletionSource = Microsoft.Azure.Devices.Shared.TaskCompletionSource;
#endif
namespace Microsoft.Azure.Devices.Client.Transport.Mqtt
{
//
// Note on ConfigureAwait: dotNetty is using a custom TaskScheduler that binds Tasks to the corresponding
// EventLoop. To limit I/O to the EventLoopGroup and keep Netty semantics, we are going to ensure that the
// task continuations are executed by this scheduler using ConfigureAwait(true).
//
// All awaited calls that happen within dotnetty's pipeline should be ConfigureAwait(true).
//
internal sealed class MqttTransportHandler : TransportHandler, IMqttIotHubEventHandler
{
// Port 8883. NOTE(review): its use is not visible in this chunk; presumably the TCP channel factory connects on it.
private const int ProtocolGatewayPort = 8883;
// 256 * 1024 bytes. NOTE(review): usage not visible in this chunk.
private const int MaxMessageSize = 256 * 1024;
// NOTE(review): presumably the environment variable consulted by GetEventLoopGroup (defined elsewhere in this file).
private const string ProcessorThreadCountVariableName = "MqttEventsProcessorThreadCount";
// Topic names for receiving cloud-to-device messages.
private const string DeviceBoundMessagesTopicFilter = "devices/{0}/messages/devicebound/#";
private const string DeviceBoundMessagesTopicPrefix = "devices/{0}/messages/devicebound/";
// Topic names for retrieving a device's twin properties.
// The client first subscribes to "$iothub/twin/res/#", to receive the operation's responses.
// It then sends an empty message to the topic "$iothub/twin/GET/?$rid={request id}", with a populated value for request Id.
// The service then sends a response message containing the device twin data on topic "$iothub/twin/res/{status}/?$rid={request id}", using the same request Id as the request.
private const string TwinResponseTopicFilter = "$iothub/twin/res/#";
private const string TwinResponseTopicPrefix = "$iothub/twin/res/";
private const string TwinGetTopic = "$iothub/twin/GET/?$rid={0}";
// Captures the status code (group 1) and the query string starting at '?' (group 2) of a twin response topic.
private const string TwinResponseTopicPattern = @"\$iothub/twin/res/(\d+)/(\?.+)";
// Topic name for updating device twin's reported properties.
// The client first subscribes to "$iothub/twin/res/#", to receive the operation's responses.
// The client then sends a message containing the twin update to "$iothub/twin/PATCH/properties/reported/?$rid={request id}", with a populated value for request Id.
// The service then sends a response message containing the new ETag value for the reported properties collection on the topic "$iothub/twin/res/{status}/?$rid={request id}", using the same request Id as the request.
private const string TwinPatchTopic = "$iothub/twin/PATCH/properties/reported/?$rid={0}";
// Topic names for receiving twin desired property update notifications.
private const string TwinPatchTopicFilter = "$iothub/twin/PATCH/properties/desired/#";
private const string TwinPatchTopicPrefix = "$iothub/twin/PATCH/properties/desired/";
// Topic name for responding to direct methods.
// The client first subscribes to "$iothub/methods/POST/#".
// The service sends method requests to the topic "$iothub/methods/POST/{method name}/?$rid={request id}".
// The client responds to the direct method invocation by sending a message to the topic "$iothub/methods/res/{status}/?$rid={request id}", using the same request Id as the request.
private const string MethodPostTopicFilter = "$iothub/methods/POST/#";
private const string MethodPostTopicPrefix = "$iothub/methods/POST/";
private const string MethodResponseTopic = "$iothub/methods/res/{0}/?$rid={1}";
// Topic names for enabling events on Modules.
private const string ReceiveEventMessagePatternFilter = "devices/{0}/modules/{1}/#";
private const string ReceiveEventMessagePrefixPattern = "devices/{0}/modules/{1}/";
// Length of a GUID string; CompleteAsync uses it to skip the generation-id prefix of a lock token.
private static readonly int s_generationPrefixLength = Guid.NewGuid().ToString().Length;
// Shared DotNetty event loop group, created on first use by GetEventLoopGroup (defined elsewhere in this file).
private static readonly Lazy<IEventLoopGroup> s_eventLoopGroup = new Lazy<IEventLoopGroup>(GetEventLoopGroup);
// Match timeout for the regexes in this class. Despite the name, this is a TimeSpan (500 ms).
private static readonly TimeSpan s_regexTimeoutMilliseconds = TimeSpan.FromMilliseconds(500);
// Initial value of the TwinTimeout property.
private static readonly TimeSpan s_defaultTwinTimeout = TimeSpan.FromSeconds(60);
// Per-instance id prefixed to every lock token handed out, so CompleteAsync can reject tokens not issued by this instance.
private readonly string _generationId = Guid.NewGuid().ToString();
// Per-device/per-module topic filters and prefixes, formatted once in the constructor.
private readonly string _receiveEventMessageFilter;
private readonly string _receiveEventMessagePrefix;
private readonly string _deviceboundMessageFilter;
private readonly string _deviceboundMessagePrefix;
private readonly string _hostName;
private readonly Func<IPAddress[], int, Task<IChannel>> _channelFactory;
// Raw (unprefixed) lock tokens of received QoS 1 messages, in arrival order, awaiting CompleteAsync.
private readonly ConcurrentQueue<string> _completionQueue;
private readonly MqttIotHubAdapterFactory _mqttIotHubAdapterFactory;
private readonly QualityOfService _qosSendPacketToService;
private readonly QualityOfService _qosReceivePacketFromService;
// True when the session is not clean, i.e. C2D messages sent while disconnected are delivered on reconnect.
private readonly bool _retainMessagesAcrossSessions;
private readonly RetryPolicy _closeRetryPolicy;
// Received C2D messages waiting for ReceiveAsync or the C2D callback.
private readonly ConcurrentQueue<Message> _messageQueue;
// Completed by OnConnected; faulted by OnError if the failure happens while opening.
private readonly TaskCompletionSource _connectCompletion = new TaskCompletionSource();
private readonly TaskCompletionSource _subscribeCompletionSource = new TaskCompletionSource();
private readonly IWebProxy _webProxy;
// Serializes EnableReceiveMessageAsync / DisableReceiveMessageAsync. Not readonly: Dispose nulls it.
private SemaphoreSlim _deviceReceiveMessageSemaphore = new SemaphoreSlim(1, 1);
// Released once per C2D message enqueued for polling ReceiveAsync. Not readonly: Dispose nulls it.
private SemaphoreSlim _receivingSemaphore = new SemaphoreSlim(0);
// Canceled by OnError while Receiving to wake pending ReceiveAsync waiters. Not readonly: Dispose nulls it.
private CancellationTokenSource _disconnectAwaitersCancellationSource = new CancellationTokenSource();
private readonly Regex _twinResponseTopicRegex = new Regex(TwinResponseTopicPattern, RegexOptions.Compiled, s_regexTimeoutMilliseconds);
// User callbacks supplied through the constructor; any of them may be null.
private readonly Func<MethodRequestInternal, Task> _methodListener;
private readonly Action<TwinCollection> _onDesiredStatePatchListener;
private readonly Func<string, Message, Task> _moduleMessageReceivedListener;
private readonly Func<Message, Task> _deviceMessageReceivedListener;
// When true, C2D messages go to _deviceMessageReceivedListener and ReceiveAsync returns null.
private bool _isDeviceReceiveMessageCallbackSet;
private Func<Task> _cleanupFunc;
private IChannel _channel;
// Exception that moved the transport to Error; rethrown by CloseAsync.
private ExceptionDispatchInfo _fatalException;
// Resolved in OpenInternalAsync (empty when a proxy is configured).
private IPAddress[] _serverAddresses;
// TransportState stored as int so it can be updated with Interlocked/Volatile.
private int _state = (int)TransportState.NotInitialized;
// Invoked by OnMessageReceived for messages on the twin response topic; null when nothing is subscribed.
private Action<Message> _twinResponseEvent;
/// <summary>
/// Creates a handler that uses the default channel factory for the configured transport and dispatches
/// incoming traffic to the supplied callbacks (each of which may be null).
/// </summary>
internal MqttTransportHandler(
    PipelineContext context,
    IotHubConnectionString iotHubConnectionString,
    MqttTransportSettings settings,
    Func<MethodRequestInternal, Task> onMethodCallback = null,
    Action<TwinCollection> onDesiredStatePatchReceivedCallback = null,
    Func<string, Message, Task> onModuleMessageReceivedCallback = null,
    Func<Message, Task> onDeviceMessageReceivedCallback = null)
    : this(context, iotHubConnectionString, settings, channelFactory: null)
{
    // The chained constructor has already configured topics, QoS and the channel factory;
    // only the user-facing callbacks remain to be stored.
    _onDesiredStatePatchListener = onDesiredStatePatchReceivedCallback;
    _methodListener = onMethodCallback;
    _moduleMessageReceivedListener = onModuleMessageReceivedCallback;
    _deviceMessageReceivedListener = onDeviceMessageReceivedCallback;
}
/// <summary>
/// Core constructor: formats the per-device topic names, captures QoS/session/proxy settings and selects
/// the channel factory (the injected one, or a TCP/WebSocket factory chosen from the transport type).
/// </summary>
internal MqttTransportHandler(
    PipelineContext context,
    IotHubConnectionString iotHubConnectionString,
    MqttTransportSettings settings,
    Func<IPAddress[], int, Task<IChannel>> channelFactory)
    : base(context, settings)
{
    _mqttIotHubAdapterFactory = new MqttIotHubAdapterFactory(settings);
    _messageQueue = new ConcurrentQueue<Message>();
    _completionQueue = new ConcurrentQueue<string>();

    // Broker addresses are resolved asynchronously in OpenAsync, not here.
    _serverAddresses = null;
    _hostName = iotHubConnectionString.HostName;

    // Pre-format the topic filters/prefixes used for subscribing and for routing incoming messages.
    string deviceId = iotHubConnectionString.DeviceId;
    string moduleId = iotHubConnectionString.ModuleId;
    _receiveEventMessageFilter = string.Format(CultureInfo.InvariantCulture, ReceiveEventMessagePatternFilter, deviceId, moduleId);
    _receiveEventMessagePrefix = string.Format(CultureInfo.InvariantCulture, ReceiveEventMessagePrefixPattern, deviceId, moduleId);
    _deviceboundMessageFilter = string.Format(CultureInfo.InvariantCulture, DeviceBoundMessagesTopicFilter, deviceId);
    _deviceboundMessagePrefix = string.Format(CultureInfo.InvariantCulture, DeviceBoundMessagesTopicPrefix, deviceId);

    _qosSendPacketToService = settings.PublishToServerQoS;
    _qosReceivePacketFromService = settings.ReceivingQoS;

    // CleanSession == false: C2D messages sent while the device was disconnected are delivered after it reconnects.
    // CleanSession == true: only messages sent after the device subscribed to the C2D topic are delivered.
    _retainMessagesAcrossSessions = !settings.CleanSession;
    _webProxy = settings.Proxy;

    if (channelFactory == null)
    {
        // Nothing injected: build the factory that matches the configured transport.
        ClientOptions options = context.ClientOptions;
        _channelFactory = settings.GetTransportType() switch
        {
            TransportType.Mqtt_Tcp_Only => CreateChannelFactory(iotHubConnectionString, settings, context.ProductInfo, options),
            TransportType.Mqtt_WebSocket_Only => CreateWebSocketChannelFactory(iotHubConnectionString, settings, context.ProductInfo, options),
            _ => throw new InvalidOperationException("Unsupported Transport Setting {0}".FormatInvariant(settings.GetTransportType())),
        };
    }
    else
    {
        _channelFactory = channelFactory;
    }

    // NOTE(review): presumably 5 attempts with a fixed 1 s back-off; errors are filtered by TransientErrorIgnoreStrategy.
    _closeRetryPolicy = new RetryPolicy(new TransientErrorIgnoreStrategy(), 5, TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1));
}
/// <summary>Current transport state, read with a volatile load of the backing int.</summary>
public TransportState State
{
    get { return (TransportState)Volatile.Read(ref _state); }
}

/// <summary>The transport can still be used unless it has been closed or has faulted.</summary>
public override bool IsUsable => !(State == TransportState.Closed || State == TransportState.Error);

/// <summary>Timeout applied to twin operations; defaults to 60 seconds.</summary>
public TimeSpan TwinTimeout { get; set; } = s_defaultTwinTimeout;
#region Client operations
/// <summary>Opens the transport within the time left on <paramref name="timeoutHelper"/>.</summary>
public override async Task OpenAsync(TimeoutHelper timeoutHelper)
{
    // Turn the remaining time budget into a cancellation deadline and defer to the token-based overload.
    TimeSpan remaining = timeoutHelper.GetRemainingTime();
    using (var deadline = new CancellationTokenSource(remaining))
    {
        await OpenAsync(deadline.Token).ConfigureAwait(false);
    }
}
/// <summary>Resolves the broker address and opens the MQTT connection.</summary>
public override async Task OpenAsync(CancellationToken cancellationToken)
{
    try
    {
        if (Logging.IsEnabled)
        {
            Logging.Enter(this, cancellationToken, nameof(OpenAsync));
        }

        cancellationToken.ThrowIfCancellationRequested();

        // The transport is expected not to be open yet, so validate without requiring the Open state.
        EnsureValidState(throwIfNotOpen: false);

        await OpenInternalAsync(cancellationToken).ConfigureAwait(false);
    }
    finally
    {
        if (Logging.IsEnabled)
        {
            Logging.Exit(this, cancellationToken, nameof(OpenAsync));
        }
    }
}
/// <summary>Publishes a single telemetry (or pre-addressed) message over the MQTT channel.</summary>
public override async Task SendEventAsync(Message message, CancellationToken cancellationToken)
{
    try
    {
        if (Logging.IsEnabled)
        {
            Logging.Enter(this, cancellationToken, nameof(SendEventAsync));
        }

        cancellationToken.ThrowIfCancellationRequested();
        EnsureValidState();
        Debug.Assert(_channel != null);

        // ConfigureAwait(true): keep the continuation on DotNetty's event-loop scheduler.
        await _channel.WriteAndFlushAsync(message).ConfigureAwait(true);
    }
    finally
    {
        if (Logging.IsEnabled)
        {
            Logging.Exit(this, cancellationToken, nameof(SendEventAsync));
        }
    }
}
/// <summary>
/// Publishes the messages one at a time, in enumeration order; cancellation is checked before each send.
/// </summary>
public override async Task SendEventAsync(IEnumerable<Message> messages, CancellationToken cancellationToken)
{
    using (IEnumerator<Message> cursor = messages.GetEnumerator())
    {
        while (cursor.MoveNext())
        {
            cancellationToken.ThrowIfCancellationRequested();
            await SendEventAsync(cursor.Current, cancellationToken).ConfigureAwait(false);
        }
    }
}
/// <summary>
/// Waits for the next cloud-to-device message and returns it, subscribing to the C2D topic first if needed.
/// Returns null immediately when a C2D callback is registered, since messages are then delivered there.
/// </summary>
public override async Task<Message> ReceiveAsync(CancellationToken cancellationToken)
{
    if (_isDeviceReceiveMessageCallbackSet)
    {
        if (Logging.IsEnabled)
        {
            Logging.Error(this, "Callback handler set for receiving C2D messages; ReceiveAsync() will now always return null", nameof(ReceiveAsync));
        }

        return null;
    }

    try
    {
        if (Logging.IsEnabled)
        {
            Logging.Enter(
                this,
                cancellationToken,
                $"ReceiveAsync() called with cancellation requested state of: {cancellationToken.IsCancellationRequested}",
                $"{nameof(ReceiveAsync)}");
        }

        cancellationToken.ThrowIfCancellationRequested();
        EnsureValidState();

        // Subscribe to the device-bound topic unless the transport is already in the Receiving state.
        bool alreadyReceiving = State == TransportState.Receiving;
        if (!alreadyReceiving)
        {
            await SubscribeCloudToDeviceMessagesAsync().ConfigureAwait(false);
        }

        await WaitUntilC2dMessageArrivesAsync(cancellationToken).ConfigureAwait(false);
        return ProcessC2dMessage();
    }
    finally
    {
        if (Logging.IsEnabled)
        {
            Logging.Exit(
                this,
                cancellationToken,
                $"Exiting ReceiveAsync() with cancellation requested state of: {cancellationToken.IsCancellationRequested}",
                $"{nameof(ReceiveAsync)}");
        }
    }
}
/// <summary>
/// Time-budgeted variant of ReceiveAsync: waits for the next C2D message for at most the time left on
/// <paramref name="timeoutHelper"/>. Returns null immediately when a C2D callback is registered.
/// </summary>
public override async Task<Message> ReceiveAsync(TimeoutHelper timeoutHelper)
{
    if (_isDeviceReceiveMessageCallbackSet)
    {
        if (Logging.IsEnabled)
        {
            Logging.Error(this, "Callback handler set for receiving C2D messages; ReceiveAsync() will now always return null", nameof(ReceiveAsync));
        }

        return null;
    }

    try
    {
        if (Logging.IsEnabled)
        {
            Logging.Enter(this, timeoutHelper, $"Time remaining for ReceiveAsync(): {timeoutHelper.GetRemainingTime()}", $"{nameof(ReceiveAsync)}");
        }

        EnsureValidState();

        bool alreadyReceiving = State == TransportState.Receiving;
        if (!alreadyReceiving)
        {
            await SubscribeCloudToDeviceMessagesAsync().ConfigureAwait(false);
        }

        // The deadline is computed after subscribing, so subscription time counts against the budget.
        using (var deadline = new CancellationTokenSource(timeoutHelper.GetRemainingTime()))
        {
            await WaitUntilC2dMessageArrivesAsync(deadline.Token).ConfigureAwait(false);
            return ProcessC2dMessage();
        }
    }
    finally
    {
        if (Logging.IsEnabled)
        {
            Logging.Exit(this, timeoutHelper, $"Time remaining for ReceiveAsync(): {timeoutHelper.GetRemainingTime()}", $"{nameof(ReceiveAsync)}");
        }
    }
}
/// <summary>
/// Dequeues the next received C2D message, or returns null if the queue is empty.
/// For QoS 1 the raw lock token is remembered in arrival order, and the token handed to the caller is
/// prefixed with this instance's generation id.
/// </summary>
private Message ProcessC2dMessage()
{
    Message next = null;
    try
    {
        if (Logging.IsEnabled)
        {
            Logging.Enter(this, next, $"Will begin processing received C2D message, queue size={_messageQueue.Count}", nameof(ProcessC2dMessage));
        }

        if (!_messageQueue.TryDequeue(out next))
        {
            return next;
        }

        // QoS 1 PUBACKs must be sent in receive order; CompleteAsync checks tokens against this queue.
        if (_qosReceivePacketFromService == QualityOfService.AtLeastOnce)
        {
            _completionQueue.Enqueue(next.LockToken);
        }

        // Tag the token so CompleteAsync can detect tokens not issued by this handler instance.
        next.LockToken = _generationId + next.LockToken;
        return next;
    }
    finally
    {
        if (Logging.IsEnabled)
        {
            Logging.Exit(this, next, $"Processed received C2D message with Id={next?.MessageId}", nameof(ProcessC2dMessage));
        }
    }
}
/// <summary>
/// Waits until a C2D message is available for polling. The wait ends early, with an
/// OperationCanceledException, if the caller's token is canceled or OnError cancels the disconnect source.
/// </summary>
private async Task WaitUntilC2dMessageArrivesAsync(CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();

    CancellationToken disconnected = _disconnectAwaitersCancellationSource.Token;
    EnsureValidState();

    using (CancellationTokenSource callerOrDisconnect = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, disconnected))
    {
        // OnMessageReceived releases this semaphore once for every C2D message queued for polling.
        await _receivingSemaphore.WaitAsync(callerOrDisconnect.Token).ConfigureAwait(false);
    }
}
/// <summary>
/// Acknowledges (PUBACK) a QoS 1 message that was handed out with <paramref name="lockToken"/>.
/// </summary>
/// <remarks>
/// MQTT requires PUBACK packets to be sent in the order the PUBLISH packets were received [MQTT-4.6.0-2],
/// so the token must correspond to the oldest outstanding entry in the completion queue.
/// Tokens handed out by this handler have the form "{generation id}{raw token}".
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="lockToken"/> is null.</exception>
/// <exception cref="IotHubException">QoS 0 is in use, or the token is stale, unknown or out of order.</exception>
public override async Task CompleteAsync(string lockToken, CancellationToken cancellationToken)
{
    try
    {
        if (Logging.IsEnabled)
            Logging.Enter(this, $"Completing a message with lockToken: {lockToken}", nameof(CompleteAsync));

        cancellationToken.ThrowIfCancellationRequested();
        EnsureValidState();

        if (_qosReceivePacketFromService == QualityOfService.AtMostOnce)
        {
            throw new IotHubException("Complete is not allowed for QoS 0.", isTransient: false);
        }

        if (lockToken == null)
        {
            throw new ArgumentNullException(nameof(lockToken));
        }

        // The generation id is an identifier, not linguistic text: compare it ordinally.
        if (!lockToken.StartsWith(_generationId, StringComparison.Ordinal))
        {
            throw new IotHubException(
                "Lock token is stale or never existed. The message will be redelivered. Please discard this lock token and do not retry the operation.",
                isTransient: false);
        }

        // A single TryDequeue instead of IsEmpty followed by TryDequeue: a concurrent completion could drain the
        // queue between the two calls, leaving actualLockToken null and failing later with an ArgumentNullException.
        if (!_completionQueue.TryDequeue(out string actualLockToken))
        {
            throw new IotHubException("Unknown lock token.", isTransient: false);
        }

        // The caller's token must be exactly "{generation id}{oldest outstanding raw token}".
        if (lockToken.IndexOf(actualLockToken, s_generationPrefixLength, StringComparison.Ordinal) != s_generationPrefixLength ||
            lockToken.Length != actualLockToken.Length + s_generationPrefixLength)
        {
            throw new IotHubException(
                $"Client must send PUBACK packets in the order in which the corresponding PUBLISH packets were received (QoS 1 messages) per [MQTT-4.6.0-2]. Expected lock token to end with: '{actualLockToken}'; actual lock token: '{lockToken}'.",
                isTransient: false);
        }

        // ConfigureAwait(true): keep the continuation on DotNetty's event-loop scheduler.
        await _channel.WriteAndFlushAsync(actualLockToken).ConfigureAwait(true);
    }
    finally
    {
        if (Logging.IsEnabled)
            Logging.Exit(this, $"Completing a message with lockToken: {lockToken}", nameof(CompleteAsync));
    }
}
/// <summary>Not supported: MQTT has no way to abandon a delivered message.</summary>
public override Task AbandonAsync(string lockToken, CancellationToken cancellationToken)
    => ThrowOperationNotSupportedByMqtt(cancellationToken);

/// <summary>Not supported: MQTT has no way to reject a delivered message.</summary>
public override Task RejectAsync(string lockToken, CancellationToken cancellationToken)
    => ThrowOperationNotSupportedByMqtt(cancellationToken);

/// <summary>
/// Shared body of the unsupported settlement operations. A canceled token is reported first; otherwise
/// the call always fails synchronously with <see cref="NotSupportedException"/>.
/// </summary>
private static Task ThrowOperationNotSupportedByMqtt(CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();
    throw new NotSupportedException("MQTT protocol does not support this operation");
}
/// <summary>
/// Disposes the handler. Repeated calls are no-ops once <c>_isDisposed</c> is set. When
/// <paramref name="disposing"/> is true, the connection is cleaned up (if this call stops the transport)
/// and the semaphores, the disconnect cancellation source and the channel are disposed and nulled.
/// </summary>
/// <param name="disposing">Only when true are the managed resources listed above released.</param>
protected override void Dispose(bool disposing)
{
try
{
if (Logging.IsEnabled)
Logging.Enter(
this,
$"{nameof(DefaultDelegatingHandler)}.Disposed={_isDisposed}; disposing={disposing}",
$"{nameof(MqttTransportHandler)}.{nameof(Dispose)}");
if (!_isDisposed)
{
// The base class is disposed before this class releases its own resources.
base.Dispose(disposing);
if (disposing)
{
// Clean up the connection only if this call is the one that stopped the transport.
// NOTE(review): this blocks on an async cleanup (sync-over-async); it could deadlock if the cleanup
// needs the thread that is calling Dispose.
if (TryStop())
{
CleanUpAsync().GetAwaiter().GetResult();
}
// Log the task completion source tasks' exceptions and avoid unobserved exceptions.
if (_connectCompletion.Task.Exception != null)
{
if (Logging.IsEnabled)
Logging.Error(this, $"{_connectCompletion} has exception {_connectCompletion.Task.Exception}", nameof(Dispose));
}
if (_subscribeCompletionSource.Task.Exception != null)
{
if (Logging.IsEnabled)
Logging.Error(this, $"{_subscribeCompletionSource} has exception {_subscribeCompletionSource.Task.Exception}", nameof(Dispose));
}
// Dispose and null the synchronization primitives. NOTE(review): other members use these fields
// without null checks, so calling them after Dispose fails with NullReferenceException.
_disconnectAwaitersCancellationSource?.Dispose();
_disconnectAwaitersCancellationSource = null;
_receivingSemaphore?.Dispose();
_receivingSemaphore = null;
_deviceReceiveMessageSemaphore?.Dispose();
_deviceReceiveMessageSemaphore = null;
if (_channel is IDisposable disposableChannel)
{
disposableChannel.Dispose();
_channel = null;
}
}
// _isDisposed is inherited from DefaultDelegatingHandler and is not set in this method;
// NOTE(review): presumably base.Dispose sets it to true.
}
}
finally
{
if (Logging.IsEnabled)
Logging.Exit(
this,
$"{nameof(DefaultDelegatingHandler)}.Disposed={_isDisposed}; disposing={disposing}",
$"{nameof(MqttTransportHandler)}.{nameof(Dispose)}");
}
}
/// <summary>
/// Gracefully closes the transport and cleans up the connection with retries. If the transport cannot be
/// stopped because it has already faulted, the original fatal exception is rethrown.
/// </summary>
public override async Task CloseAsync(CancellationToken cancellationToken)
{
    try
    {
        if (Logging.IsEnabled)
        {
            Logging.Enter(this, "", $"{nameof(MqttTransportHandler)}.{nameof(CloseAsync)}");
        }

        cancellationToken.ThrowIfCancellationRequested();

        if (!TryStop())
        {
            // This call did not stop the transport; if it is in Error, surface the exception that caused it.
            if (State == TransportState.Error)
            {
                _fatalException.Throw();
            }

            return;
        }

        OnTransportClosedGracefully();
        await _closeRetryPolicy.RunWithRetryAsync(CleanUpImplAsync, cancellationToken).ConfigureAwait(true);
    }
    finally
    {
        if (Logging.IsEnabled)
        {
            Logging.Exit(this, "", $"{nameof(MqttTransportHandler)}.{nameof(CloseAsync)}");
        }
    }
}
#endregion Client operations
#region MQTT callbacks
/// <summary>
/// MQTT adapter callback: the CONNECT handshake succeeded. Completes the pending open only if this call
/// moves the transport from Opening to Open.
/// </summary>
public void OnConnected()
{
    bool movedToOpen = TryStateTransition(TransportState.Opening, TransportState.Open);
    if (!movedToOpen)
    {
        return;
    }

    _connectCompletion.TrySetResult();
}
/// <summary>
/// Handles a desired-properties patch: deserializes the body into a <see cref="TwinCollection"/>, passes it
/// to the registered listener (if any) on the thread pool, then acknowledges the message. The message is
/// always disposed.
/// </summary>
private async Task HandleIncomingTwinPatchAsync(Message message)
{
    try
    {
        Action<TwinCollection> listener = _onDesiredStatePatchListener;
        if (listener != null)
        {
            using var reader = new StreamReader(message.GetBodyStream(), System.Text.Encoding.UTF8);
            string patch = reader.ReadToEnd();
            TwinCollection props = JsonConvert.DeserializeObject<TwinCollection>(patch, JsonSerializerSettingsInitializer.GetJsonSerializerSettings());
            await Task.Run(() => listener(props)).ConfigureAwait(false);
        }

        // Messages with QoS = 1 need to be acknowledged, otherwise it results in mismatched Ack to IoT Hub,
        // causing the next message to be replayed and all subsequent messages to be queued.
        // The ack is sent even when no listener is registered; previously such a patch was never acknowledged.
        await CompleteIncomingMessageAsync(message).ConfigureAwait(false);
    }
    finally
    {
        message.Dispose();
    }
}
/// <summary>
/// Handles a direct-method request received on "$iothub/methods/POST/{method name}/?$rid={request id}".
/// Passes it to the registered method listener (if any) on the thread pool, then acknowledges the message.
/// The message is always disposed.
/// </summary>
private async Task HandleIncomingMethodPostAsync(Message message)
{
    try
    {
        // A plain character split gives the same segments as the previous Regex.Split on a literal "/",
        // without building a regex, and matches how HandleIncomingEventMessageAsync parses its topic.
        string[] tokens = message.MqttTopicName.Split('/');
        string methodName = tokens[3];
        string requestId = tokens[4].Substring(6); // drop the "?$rid=" prefix

        using var mr = new MethodRequestInternal(methodName, requestId, message.GetBodyStream(), CancellationToken.None);

        // Invoking a null listener used to throw NullReferenceException, which OnError turned into a transport failure.
        Func<MethodRequestInternal, Task> listener = _methodListener;
        if (listener != null)
        {
            await Task.Run(() => listener(mr)).ConfigureAwait(false);
        }

        // Method with QoS = 1 need to be acknowledged, otherwise it results in mismatched Ack to IoT Hub,
        // causing the next message to be replayed and all subsequent messages to be queued.
        await CompleteIncomingMessageAsync(message).ConfigureAwait(false);
    }
    finally
    {
        message.Dispose();
    }
}
/// <summary>
/// Dequeues the next C2D message and hands it to the user's C2D callback without awaiting it.
/// Does nothing if the queue has already been drained.
/// </summary>
[SuppressMessage(
    "Reliability",
    "CA2000:Dispose objects before losing scope",
    Justification = "The created message is handed to the user and the user application is in charge of disposing the message.")]
private Task HandleIncomingMessagesAsync()
{
    if (Logging.IsEnabled)
        Logging.Enter(this, "Process C2D message via callback", nameof(HandleIncomingMessagesAsync));

    Message message = ProcessC2dMessage();

    // ProcessC2dMessage returns null when a concurrent caller (e.g. EnsurePendingMessagesAreDeliveredAsync racing
    // OnMessageReceived) already drained the queue; never hand a null message to the user callback.
    if (message != null)
    {
        // We are intentionally not awaiting _deviceMessageReceivedListener callback.
        // This is a user-supplied callback that isn't required to be awaited by us. We can simply invoke it and continue.
        _ = _deviceMessageReceivedListener?.Invoke(message);
    }

    if (Logging.IsEnabled)
        Logging.Exit(this, "Process C2D message via callback", nameof(HandleIncomingMessagesAsync));

    return TaskHelpers.CompletedTask;
}
/// <summary>
/// MQTT adapter callback for every PUBLISH received from IoT Hub; routes the message by topic to the
/// twin-response, twin-patch, direct-method, module-input or C2D handling path.
/// </summary>
/// <remarks>
/// This is <c>async void</c>, so no exception may escape: failures are logged and reported through OnError.
/// Messages are only processed while the state carries the Open flag.
/// </remarks>
public async void OnMessageReceived(Message message)
{
    if (Logging.IsEnabled)
        Logging.Enter(this, message, nameof(OnMessageReceived));

    // An exception escaping an async void method is rethrown on the thread pool, so catch everything here.
    try
    {
        // NOTE(review): assumes TransportState is a flags enum in which the post-open states also carry the Open bit.
        if ((State & TransportState.Open) == TransportState.Open)
        {
            string topic = message.MqttTopicName;
            if (Logging.IsEnabled)
                Logging.Info(this, $"Received a message on topic: {topic}", nameof(OnMessageReceived));

            if (topic.StartsWith(TwinResponseTopicPrefix, StringComparison.OrdinalIgnoreCase))
            {
                // _twinResponseEvent is null when nothing is subscribed (e.g. a response arriving after its request
                // stopped waiting). Invoking it unconditionally threw NullReferenceException, which OnError escalated
                // into a fatal transport error. The message is still acknowledged either way.
                Action<Message> twinResponseHandler = _twinResponseEvent;
                if (twinResponseHandler != null)
                {
                    twinResponseHandler(message);
                }
                else if (Logging.IsEnabled)
                {
                    Logging.Info(this, $"Received a twin response with no pending twin request; ignoring it. Topic: {topic}", nameof(OnMessageReceived));
                }

                await CompleteIncomingMessageAsync(message).ConfigureAwait(false);
            }
            else if (topic.StartsWith(TwinPatchTopicPrefix, StringComparison.OrdinalIgnoreCase))
            {
                await HandleIncomingTwinPatchAsync(message).ConfigureAwait(false);
            }
            else if (topic.StartsWith(MethodPostTopicPrefix, StringComparison.OrdinalIgnoreCase))
            {
                await HandleIncomingMethodPostAsync(message).ConfigureAwait(false);
            }
            else if (topic.StartsWith(_receiveEventMessagePrefix, StringComparison.OrdinalIgnoreCase))
            {
                await HandleIncomingEventMessageAsync(message).ConfigureAwait(false);
            }
            else if (topic.StartsWith(_deviceboundMessagePrefix, StringComparison.OrdinalIgnoreCase))
            {
                _messageQueue.Enqueue(message);
                if (_isDeviceReceiveMessageCallbackSet)
                {
                    await HandleIncomingMessagesAsync().ConfigureAwait(false);
                }
                else
                {
                    // Wake one polling ReceiveAsync waiter.
                    _receivingSemaphore.Release();
                }
            }
            else
            {
                if (Logging.IsEnabled)
                    Logging.Error(this, $"Received MQTT message on an unrecognized topic, ignoring message. Topic: {topic}");
            }
        }
    }
    catch (Exception ex)
    {
        if (Logging.IsEnabled)
            Logging.Error(this, $"Received an exception while processing an MQTT message: {ex}", nameof(OnMessageReceived));

        OnError(ex);
    }
    finally
    {
        if (Logging.IsEnabled)
            Logging.Exit(this, message, nameof(OnMessageReceived));
    }
}
/// <summary>
/// Immediately acknowledges an internally handled message (twin response/patch, method request) when it was
/// delivered with QoS 1 and carries a lock token. Non-fatal failures are reported to OnError and rethrown.
/// </summary>
private async Task CompleteIncomingMessageAsync(Message message)
{
    try
    {
        bool needsAck = _qosReceivePacketFromService == QualityOfService.AtLeastOnce && message.LockToken != null;
        if (!needsAck)
        {
            return;
        }

        // Register the raw token, then complete it at once with the generation-prefixed form CompleteAsync expects.
        string rawToken = message.LockToken;
        _completionQueue.Enqueue(rawToken);
        await CompleteAsync(_generationId + message.LockToken, CancellationToken.None).ConfigureAwait(false);
    }
    catch (Exception ex) when (!ex.IsFatal())
    {
        OnError(ex);
        throw;
    }
}
/// <summary>
/// Handles a message routed to a module input: records the input name as a system property, registers the
/// lock token for QoS 1 completion, and awaits the module message callback (if any). The message is always
/// disposed afterwards.
/// </summary>
private async Task HandleIncomingEventMessageAsync(Message message)
{
    try
    {
        // Expected topic shape: devices/{deviceId}/modules/{moduleId}/inputs/{inputName}.
        // The input name is the sixth segment when present; otherwise it is null.
        string[] segments = message.MqttTopicName.Split('/');
        string inputName = null;
        if (segments.Length >= 6)
        {
            inputName = segments[5];
        }

        message.SystemProperties.Add(MessageSystemPropertyNames.InputName, inputName);

        if (_qosReceivePacketFromService == QualityOfService.AtLeastOnce)
        {
            _completionQueue.Enqueue(message.LockToken);
        }

        // The user completes the message later with this generation-prefixed token.
        message.LockToken = _generationId + message.LockToken;

        Task callback = _moduleMessageReceivedListener?.Invoke(inputName, message) ?? TaskHelpers.CompletedTask;
        await callback.ConfigureAwait(false);
    }
    finally
    {
        message.Dispose();
    }
}
/// <summary>
/// Failure callback, also invoked from OnMessageReceived and CompleteIncomingMessageAsync. It moves the
/// transport to Error unless it is already Closed. It then faults or cancels pending operations according to
/// the state it was in, and finally runs the connection cleanup through the close retry policy.
/// </summary>
/// <remarks>
/// async void: callers cannot observe exceptions, so non-fatal exceptions raised here are caught and logged.
/// </remarks>
public async void OnError(Exception exception)
{
if (Logging.IsEnabled)
Logging.Enter(this, exception, nameof(OnError));
try
{
// Returns the state observed before the move; the move is skipped if that state has the Closed flag.
// NOTE(review): _fatalException is captured below, after the state is already Error, so a concurrent
// CloseAsync could observe Error while _fatalException is still null.
TransportState previousState = MoveToStateIfPossible(TransportState.Error, TransportState.Closed);
switch (previousState)
{
// Already faulted or closed: nothing to do. The cleanup below is also skipped.
case TransportState.Error:
case TransportState.Closed:
return;
// Failed before or while connecting: fault the pending open and subscription.
case TransportState.NotInitialized:
case TransportState.Opening:
_fatalException = ExceptionDispatchInfo.Capture(exception);
_connectCompletion.TrySetException(exception);
_subscribeCompletionSource.TrySetException(exception);
break;
// Failed after connecting: fault any pending subscription and report the disconnect.
case TransportState.Open:
case TransportState.Subscribing:
_fatalException = ExceptionDispatchInfo.Capture(exception);
_subscribeCompletionSource.TrySetException(exception);
OnTransportDisconnected();
break;
// Polling ReceiveAsync callers may be waiting: wake them through the disconnect token.
case TransportState.Receiving:
_fatalException = ExceptionDispatchInfo.Capture(exception);
_disconnectAwaitersCancellationSource.Cancel();
OnTransportDisconnected();
break;
// Unexpected state. The exception thrown here is caught by the catch block below and only logged.
default:
string error = $"Unknown transport state: {previousState}";
Debug.Fail(error);
throw new InvalidOperationException(error);
}
await _closeRetryPolicy.RunWithRetryAsync(CleanUpImplAsync).ConfigureAwait(true);
}
catch (Exception ex) when (!ex.IsFatal())
{
if (Logging.IsEnabled)
Logging.Error(this, ex.ToString(), nameof(OnError));
}
finally
{
if (Logging.IsEnabled)
Logging.Exit(this, exception, nameof(OnError));
}
}
/// <summary>
/// Atomically sets the state to <paramref name="destination"/> unless the current state shares a flag with
/// <paramref name="illegalStates"/>. Returns the state observed immediately before the (attempted) change.
/// </summary>
private TransportState MoveToStateIfPossible(TransportState destination, TransportState illegalStates)
{
    TransportState expected = State;
    while (true)
    {
        // The current state forbids the move: report it without changing anything.
        if ((expected & illegalStates) > 0)
        {
            return expected;
        }

        // CompareExchange returns the value actually stored; equality means our swap won.
        var observed = (TransportState)Interlocked.CompareExchange(ref _state, (int)destination, (int)expected);
        if (observed == expected)
        {
            return observed;
        }

        // Another transition got there first: retry against the state we just saw.
        expected = observed;
    }
}
#endregion MQTT callbacks
/// <summary>
/// Switches C2D delivery to callback mode: subscribes to the C2D topic if not already receiving, and marks the
/// callback as set so that received C2D messages go to the callback instead of the polling ReceiveAsync().
/// </summary>
public override async Task EnableReceiveMessageAsync(CancellationToken cancellationToken)
{
    try
    {
        if (Logging.IsEnabled)
            Logging.Enter(this, cancellationToken, nameof(EnableReceiveMessageAsync));

        cancellationToken.ThrowIfCancellationRequested();
        EnsureValidState();

        // Acquire the semaphore BEFORE entering the try/finally that releases it. If WaitAsync is canceled, the
        // semaphore was never acquired. Releasing it anyway (as before) let two callers in at once and could
        // exceed the max count of 1 (SemaphoreFullException). This mirrors DisableReceiveMessageAsync.
        await _deviceReceiveMessageSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
        try
        {
            if (State != TransportState.Receiving)
            {
                await SubscribeCloudToDeviceMessagesAsync().ConfigureAwait(false);
            }

            // From now on, all received C2D messages are returned on the callback, not via ReceiveAsync().
            _isDeviceReceiveMessageCallbackSet = true;
        }
        finally
        {
            _deviceReceiveMessageSemaphore.Release();
        }
    }
    finally
    {
        if (Logging.IsEnabled)
            Logging.Exit(this, cancellationToken, nameof(EnableReceiveMessageAsync));
    }
}
/// <summary>
/// When the session is persistent (CleanSession = false), delivers C2D messages that were queued before the
/// C2D callback was registered, by draining the receive queue through the callback path.
/// </summary>
public override async Task EnsurePendingMessagesAreDeliveredAsync(CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();

    // With a clean session there is nothing retained from before the subscription.
    if (!_retainMessagesAcrossSessions)
    {
        return;
    }

    // Received C2D messages are enqueued into _messageQueue; hand each one to the callback.
    while (!_messageQueue.IsEmpty)
    {
        await HandleIncomingMessagesAsync().ConfigureAwait(false);
    }
}
/// <summary>
/// Leaves C2D callback mode. If the transport is Receiving, it moves back to Open and unsubscribes from the
/// C2D topic. Afterwards received C2D messages can be returned by the polling ReceiveAsync() again.
/// </summary>
public override async Task DisableReceiveMessageAsync(CancellationToken cancellationToken)
{
    if (Logging.IsEnabled)
    {
        Logging.Enter(this, cancellationToken, nameof(DisableReceiveMessageAsync));
    }

    cancellationToken.ThrowIfCancellationRequested();
    EnsureValidState();

    try
    {
        await _deviceReceiveMessageSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
        try
        {
            // The state is Receiving only while subscribed to _deviceboundMessageFilter, so only then is there
            // a subscription to remove; the transition puts the transport back into Open.
            bool leftReceivingState = State == TransportState.Receiving
                && TryStateTransition(TransportState.Receiving, TransportState.Open);

            if (leftReceivingState)
            {
                var unsubscribe = new UnsubscribePacket(0, _deviceboundMessageFilter);
                await _channel.WriteAsync(unsubscribe).ConfigureAwait(true);
            }

            _isDeviceReceiveMessageCallbackSet = false;
        }
        finally
        {
            _deviceReceiveMessageSemaphore.Release();
        }
    }
    finally
    {
        if (Logging.IsEnabled)
        {
            Logging.Exit(this, cancellationToken, nameof(DisableReceiveMessageAsync));
        }
    }
}
/// <summary>Subscribes to direct-method requests ("$iothub/methods/POST/#") at the receive QoS.</summary>
public override async Task EnableMethodsAsync(CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();
    EnsureValidState();

    var methodSubscription = new SubscriptionRequest(MethodPostTopicFilter, _qosReceivePacketFromService);
    await _channel.WriteAsync(new SubscribePacket(0, methodSubscription)).ConfigureAwait(true);
}
/// <summary>Unsubscribes from direct-method requests.</summary>
public override async Task DisableMethodsAsync(CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();
    EnsureValidState();

    var unsubscribe = new UnsubscribePacket(0, MethodPostTopicFilter);
    await _channel.WriteAsync(unsubscribe).ConfigureAwait(true);
}
/// <summary>
/// Subscribes to this module's input topic ("devices/{deviceId}/modules/{moduleId}/#") at the receive QoS.
/// <paramref name="isAnEdgeModule"/> is not used by the MQTT transport.
/// </summary>
public override async Task EnableEventReceiveAsync(bool isAnEdgeModule, CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();
    EnsureValidState();

    var moduleInputSubscription = new SubscriptionRequest(_receiveEventMessageFilter, _qosReceivePacketFromService);
    await _channel.WriteAsync(new SubscribePacket(0, moduleInputSubscription)).ConfigureAwait(true);
}
/// <summary>
/// Unsubscribes from this module's input topic. <paramref name="isAnEdgeModule"/> is not used by the MQTT transport.
/// </summary>
public override async Task DisableEventReceiveAsync(bool isAnEdgeModule, CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();
    EnsureValidState();

    var unsubscribe = new UnsubscribePacket(0, _receiveEventMessageFilter);
    await _channel.WriteAsync(unsubscribe).ConfigureAwait(true);
}
/// <summary>
/// Sends a direct-method response to "$iothub/methods/res/{status}/?$rid={request id}", using the response body stream.
/// </summary>
public override async Task SendMethodResponseAsync(MethodResponseInternal methodResponse, CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();
    EnsureValidState();

    using (var responseMessage = new Message(methodResponse.BodyStream))
    {
        responseMessage.MqttTopicName = MethodResponseTopic.FormatInvariant(methodResponse.Status, methodResponse.RequestId);
        await SendEventAsync(responseMessage, cancellationToken).ConfigureAwait(false);
    }
}
/// <summary>Subscribes to desired-property patch notifications at the receive QoS.</summary>
public override async Task EnableTwinPatchAsync(CancellationToken cancellationToken)
{
    if (Logging.IsEnabled)
    {
        Logging.Enter(this, cancellationToken, nameof(EnableTwinPatchAsync));
    }

    cancellationToken.ThrowIfCancellationRequested();
    EnsureValidState();

    var desiredPatchSubscription = new SubscriptionRequest(TwinPatchTopicFilter, _qosReceivePacketFromService);
    await _channel.WriteAsync(new SubscribePacket(0, desiredPatchSubscription)).ConfigureAwait(true);

    if (Logging.IsEnabled)
    {
        Logging.Exit(this, cancellationToken, nameof(EnableTwinPatchAsync));
    }
}
/// <summary>Unsubscribes from desired-property patch notifications.</summary>
public override async Task DisableTwinPatchAsync(CancellationToken cancellationToken)
{
    if (Logging.IsEnabled)
    {
        Logging.Enter(this, cancellationToken, nameof(DisableTwinPatchAsync));
    }

    cancellationToken.ThrowIfCancellationRequested();
    EnsureValidState();

    var unsubscribe = new UnsubscribePacket(0, TwinPatchTopicFilter);
    await _channel.WriteAsync(unsubscribe).ConfigureAwait(true);

    if (Logging.IsEnabled)
    {
        Logging.Exit(this, cancellationToken, nameof(DisableTwinPatchAsync));
    }
}
/// <summary>
/// Requests the full device twin (empty message on "$iothub/twin/GET/?$rid={request id}") and deserializes the
/// response body into the twin's properties. A JSON parse failure is logged (when logging is enabled) and rethrown.
/// </summary>
public override async Task<Twin> SendTwinGetAsync(CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();
    EnsureValidState();

    using var getRequest = new Message();
    string requestId = Guid.NewGuid().ToString();
    getRequest.MqttTopicName = TwinGetTopic.FormatInvariant(requestId);

    using Message twinResponse = await SendTwinRequestAsync(getRequest, requestId, cancellationToken).ConfigureAwait(false);

    string json;
    using (var bodyReader = new StreamReader(twinResponse.GetBodyStream(), System.Text.Encoding.UTF8))
    {
        json = bodyReader.ReadToEnd();
    }

    try
    {
        var twin = new Twin();
        twin.Properties = JsonConvert.DeserializeObject<TwinProperties>(json, JsonSerializerSettingsInitializer.GetJsonSerializerSettings());
        return twin;
    }
    catch (JsonReaderException ex) when (Logging.IsEnabled)
    {
        Logging.Error(this, $"Failed to parse Twin JSON: {ex}. Message body: '{json}'");
        throw;
    }
}
public override async Task SendTwinPatchAsync(TwinCollection reportedProperties, CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();
    EnsureValidState();

    // Serialize the reported properties to UTF-8 JSON and wrap them in the request body.
    byte[] payload = System.Text.Encoding.UTF8.GetBytes(
        JsonConvert.SerializeObject(reportedProperties, JsonSerializerSettingsInitializer.GetJsonSerializerSettings()));
    using var bodyStream = new MemoryStream(payload);
    using var patchRequest = new Message(bodyStream);

    string requestId = Guid.NewGuid().ToString();
    patchRequest.MqttTopicName = TwinPatchTopic.FormatInvariant(requestId);

    // Only successful completion matters here; the returned response message is not used.
    await SendTwinRequestAsync(patchRequest, requestId, cancellationToken).ConfigureAwait(false);
}
// Opens the MQTT connection.
// 1. Fills _serverAddresses: an empty array when a proxy is configured, otherwise a DNS lookup of _hostName.
//    Every caller performs this step, including callers that will not end up creating the channel.
// 2. Only the caller that wins the NotInitialized -> Opening transition disposes any existing channel, creates
//    a new one through _channelFactory, and schedules its teardown. A factory failure is reported through
//    OnError and rethrown.
// 3. All callers then wait on _connectCompletion and send the twin-response subscription.
private async Task OpenInternalAsync(CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
if (IsProxyConfigured())
{
// No need to do a DNS lookup since we have the proxy address already
#if NET451
_serverAddresses = new IPAddress[0];
#else
_serverAddresses = Array.Empty<IPAddress>();
#endif
}
else
{
// Only the NET6_0_OR_GREATER branch passes the cancellation token to the lookup; NET451 resolves synchronously.
#if NET451
_serverAddresses = Dns.GetHostEntry(_hostName).AddressList;
#elif NET6_0_OR_GREATER
_serverAddresses = await Dns.GetHostAddressesAsync(_hostName, cancellationToken).ConfigureAwait(false);
#else
_serverAddresses = await Dns.GetHostAddressesAsync(_hostName).ConfigureAwait(false);
#endif
}
// Atomic transition: only one caller creates the channel; the others fall through to the wait below.
if (TryStateTransition(TransportState.NotInitialized, TransportState.Opening))
{
try
{
// Presumably a channel left over from an earlier attempt; dispose it before replacing it.
if (_channel != null
&& _channel is IDisposable disposableChannel)
{
if (Logging.IsEnabled)
Logging.Info(this, $"_channel is disposable; disposing", nameof(OpenInternalAsync));
disposableChannel.Dispose();
_channel = null;
}
_channel = await _channelFactory(_serverAddresses, ProtocolGatewayPort).ConfigureAwait(true);
}
catch (Exception ex) when (!ex.IsFatal())
{
OnError(ex);
throw;
}
// Teardown step run by the cleanup chain (ScheduleCleanup / CleanUpImplAsync): cancel disconnect waiters,
// send DISCONNECT while the channel is active, then close the channel if it is still open.
ScheduleCleanup(async () =>
{
_disconnectAwaitersCancellationSource?.Cancel();
if (_channel == null)
{
return;
}
if (_channel.Active)
{
await _channel.WriteAsync(DisconnectPacket.Instance).ConfigureAwait(true);
}
if (_channel.Open)
{
await _channel.CloseAsync().ConfigureAwait(true);
}
});
}
// Completed elsewhere (not visible here); TryStop cancels it if the transport is stopped while opening.
await _connectCompletion.Task.ConfigureAwait(false);
await SubscribeTwinResponsesAsync().ConfigureAwait(true);
}
/// <summary>
/// Requests the move to a stopped state via MoveToStateIfPossible(Closed, Error) and releases whatever
/// is waiting on the state that was left.
/// </summary>
/// <returns><c>false</c> when the previous state was already Closed or Error; otherwise <c>true</c>.</returns>
/// <exception cref="InvalidOperationException">The previous state is not a known transport state.</exception>
private bool TryStop()
{
    TransportState stateBeforeStop = MoveToStateIfPossible(TransportState.Closed, TransportState.Error);

    switch (stateBeforeStop)
    {
        case TransportState.Closed:
        case TransportState.Error:
            // Already stopped: nothing to release.
            return false;

        case TransportState.NotInitialized:
        case TransportState.Opening:
            // Unblock callers awaiting the connection.
            _connectCompletion.TrySetCanceled();
            return true;

        case TransportState.Open:
        case TransportState.Subscribing:
            // Unblock callers awaiting the cloud-to-device subscription.
            _subscribeCompletionSource.TrySetCanceled();
            return true;

        case TransportState.Receiving:
            _disconnectAwaitersCancellationSource.Cancel();
            return true;

        default:
            Debug.Fail($"Unknown transport state: {stateBeforeStop}");
            throw new InvalidOperationException();
    }
}
// Ensures the cloud-to-device subscription is sent once.
// Only the caller that wins Open -> Subscribing writes the SUBSCRIBE packet; if it then also wins
// Subscribing -> Receiving and completes _subscribeCompletionSource, it returns immediately.
// Every other caller (and the winner when either of those last steps fails) waits on
// _subscribeCompletionSource, which TryStop cancels when stopping from Open or Subscribing.
private async Task SubscribeCloudToDeviceMessagesAsync()
{
if (TryStateTransition(TransportState.Open, TransportState.Subscribing))
{
await _channel
.WriteAsync(new SubscribePacket(0, new SubscriptionRequest(_deviceboundMessageFilter, _qosReceivePacketFromService)))
.ConfigureAwait(true);
// Short-circuit: TrySetResult is attempted only if the state was still Subscribing.
if (TryStateTransition(TransportState.Subscribing, TransportState.Receiving)
&& _subscribeCompletionSource.TrySetResult())
{
return;
}
}
await _subscribeCompletionSource.Task.ConfigureAwait(false);
}
// Subscribes to the twin response topic filter at the configured receive QoS; returns the write task.
private Task SubscribeTwinResponsesAsync() =>
    _channel.WriteAsync(new SubscribePacket(0, new SubscriptionRequest(TwinResponseTopicFilter, _qosReceivePacketFromService)));
/// <summary>
/// Extracts the status code (regex group 1) and the "$rid" query value (from regex group 2) of a twin response topic.
/// </summary>
/// <returns><c>true</c> if the topic matched; otherwise <c>false</c> with an empty rid and status 500.</returns>
private bool ParseResponseTopic(string topicName, out string rid, out int status)
{
    Match topicMatch = _twinResponseTopicRegex.Match(topicName);
    if (!topicMatch.Success)
    {
        rid = "";
        status = 500;
        return false;
    }

    status = Convert.ToInt32(topicMatch.Groups[1].Value, CultureInfo.InvariantCulture);
    rid = HttpUtility.ParseQueryString(topicMatch.Groups[2].Value).Get("$rid");
    return true;
}
/// <summary>
/// Publishes a twin request and waits for the response whose topic carries the same request id.
/// </summary>
/// <param name="request">The request message; its topic is expected to already contain <paramref name="rid"/>.</param>
/// <param name="rid">Request id used to correlate the response.</param>
/// <param name="cancellationToken">Cancels the send and the wait.</param>
/// <returns>The matching response message (status below 300).</returns>
/// <exception cref="IotHubThrottledException">The matching response had status 429.</exception>
/// <exception cref="IotHubException">The matching response had any other status of 300 or more.</exception>
/// <exception cref="TimeoutException">No matching response arrived within TwinTimeout.</exception>
private async Task<Message> SendTwinRequestAsync(Message request, string rid, CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested();

    using var responseReceived = new SemaphoreSlim(0);
    Message response = null;
    ExceptionDispatchInfo responseException = null;

    // Wakes the waiter. If the event raiser invoked a delegate snapshot, a response can still arrive
    // after this method has unsubscribed and disposed the semaphore. Nobody is waiting any more in
    // that case, so the ObjectDisposedException is ignored rather than thrown into the code that
    // raised the event.
    void SignalWaiter()
    {
        try
        {
            responseReceived.Release();
        }
        catch (ObjectDisposedException)
        {
            // The request already completed, timed out or was cancelled; there is nothing to wake.
        }
    }

    void OnTwinResponse(Message possibleResponse)
    {
        try
        {
            if (ParseResponseTopic(possibleResponse.MqttTopicName, out string receivedRid, out int status))
            {
                if (rid == receivedRid)
                {
                    if (status >= 300)
                    {
                        // The Hub team is refactoring the retriable status codes without breaking changes to the existing ones.
                        // It can be expected that we may bring more retriable codes here in the future.
                        // Retry for Http status code 429 (too many requests)
                        if (status == 429)
                        {
                            throw new IotHubThrottledException($"Request {rid} was throttled by the server");
                        }
                        else
                        {
                            throw new IotHubException($"Request {rid} returned status {status}", isTransient: false);
                        }
                    }
                    else
                    {
                        response = possibleResponse;
                        SignalWaiter();
                    }
                }
            }
        }
        catch (Exception e)
        {
            // Hand the failure to the waiting request so it is rethrown there with its original stack.
            responseException = ExceptionDispatchInfo.Capture(e);
            SignalWaiter();
        }
    }

    try
    {
        _twinResponseEvent += OnTwinResponse;
        await SendEventAsync(request, cancellationToken).ConfigureAwait(false);

        // A false return (timeout) is detected below through the still-null response.
        await responseReceived.WaitAsync(TwinTimeout, cancellationToken).ConfigureAwait(false);

        if (responseException != null)
        {
            responseException.Throw();
        }
        else if (response == null)
        {
            throw new TimeoutException($"Response for message {rid} not received");
        }

        return response;
    }
    finally
    {
        _twinResponseEvent -= OnTwinResponse;
    }
}
private Func<IPAddress[], int, Task<IChannel>> CreateChannelFactory(IotHubConnectionString iotHubConnectionString, MqttTransportSettings settings, ProductInfo productInfo, ClientOptions options)
{
    // Returns a factory that builds a TLS-over-TCP MQTT pipeline and tries each address in order,
    // returning the first channel that connects.
    return async (addresses, port) =>
    {
        IChannel channel = null;

        // Most recent ConnectException, kept so the final failure can report why the addresses failed.
        Exception lastConnectException = null;

        SslStream StreamFactory(Stream stream) => new SslStream(stream, true, settings.RemoteCertificateValidationCallback);

        List<X509Certificate> certs = settings.ClientCertificate == null
            ? new List<X509Certificate>(0)
            : new List<X509Certificate> { settings.ClientCertificate };

        SslProtocols protocols = TlsVersions.Instance.Preferred;
#if NET451
        // Requires hardcoding in NET451 otherwise yields error:
        // Microsoft.Azure.Devices.Client.Exceptions.IotHubCommunicationException: Transient network error occurred, please retry.
        // DotNetty.Transport.Channels.ClosedChannelException: I/O error occurred.
        if (settings.GetTransportType() == TransportType.Mqtt_Tcp_Only
            && protocols == SslProtocols.None)
        {
            protocols = TlsVersions.Instance.MinimumTlsVersions;
        }
#endif
        var clientTlsSettings = new ClientTlsSettings(
            protocols,
            settings.CertificateRevocationCheck,
            certs,
            iotHubConnectionString.HostName);

        Bootstrap bootstrap = new Bootstrap()
            .Group(s_eventLoopGroup.Value)
            .Channel<TcpSocketChannel>()
            .Option(ChannelOption.TcpNodelay, true)
            .Option(ChannelOption.Allocator, UnpooledByteBufferAllocator.Default)
            .Handler(new ActionChannelInitializer<ISocketChannel>(ch =>
            {
                var tlsHandler = new TlsHandler(StreamFactory, clientTlsSettings);
                ch.Pipeline.AddLast(
                    tlsHandler,
                    MqttEncoder.Instance,
                    new MqttDecoder(false, MaxMessageSize),
                    new LoggingHandler(LogLevel.DEBUG),
                    _mqttIotHubAdapterFactory.Create(this, iotHubConnectionString, settings, productInfo, options));
            }));

        foreach (IPAddress address in addresses)
        {
            try
            {
                if (Logging.IsEnabled)
                    Logging.Info(this, $"Connecting to {address}", nameof(CreateChannelFactory));

                channel = await bootstrap.ConnectAsync(address, port).ConfigureAwait(true);
                break;
            }
            catch (AggregateException ae)
            {
                ae.Handle((ex) =>
                {
                    if (ex is ConnectException) // We will handle DotNetty.Transport.Channels.ConnectException
                    {
                        lastConnectException = ex;
                        if (Logging.IsEnabled)
                            Logging.Error(this, $"ConnectException trying to connect to {address}: {ex}", nameof(CreateChannelFactory));

                        return true;
                    }

                    return false; // Let anything else stop the application.
                });
            }
            catch (ConnectException ex)
            {
                // same as above, we will handle DotNetty.Transport.Channels.ConnectException
                lastConnectException = ex;
                if (Logging.IsEnabled)
                    Logging.Error(this, $"ConnectException trying to connect to {address}: {ex}", nameof(CreateChannelFactory));
            }
        }

        // Carry the last connection failure (null if there were no addresses) as the inner exception.
        return channel ?? throw new IotHubCommunicationException("MQTT channel open failed.", lastConnectException);
    };
}
private Func<IPAddress[], int, Task<IChannel>> CreateWebSocketChannelFactory(
    IotHubConnectionString iotHubConnectionString,
    MqttTransportSettings settings,
    ProductInfo productInfo,
    ClientOptions options)
{
    // Returns a factory that connects a ClientWebSocket to the hub's secure WebSocket endpoint and wraps it
    // in a DotNetty channel running the MQTT pipeline. The address/port arguments are not used.
    return async (address, port) =>
    {
        string additionalQueryParams = "";
        var websocketUri = new Uri($"{WebSocketConstants.Scheme}{iotHubConnectionString.HostName}:{WebSocketConstants.SecurePort}{WebSocketConstants.UriSuffix}{additionalQueryParams}");

        var websocket = new ClientWebSocket();
        try
        {
            websocket.Options.AddSubProtocol(WebSocketConstants.SubProtocols.Mqtt);
            try
            {
                if (IsProxyConfigured())
                {
                    // Configure proxy server
                    websocket.Options.Proxy = _webProxy;
                    if (Logging.IsEnabled)
                        Logging.Info(this, $"{nameof(CreateWebSocketChannelFactory)} Set ClientWebSocket.Options.Proxy to {_webProxy}");
                }
            }
            catch (PlatformNotSupportedException)
            {
                // .NET Core 2.0 doesn't support proxy. Ignore this setting.
                if (Logging.IsEnabled)
                    Logging.Error(this, $"{nameof(CreateWebSocketChannelFactory)} PlatformNotSupportedException thrown as .NET Core 2.0 doesn't support proxy");
            }

            if (settings.WebSocketKeepAlive.HasValue)
            {
                websocket.Options.KeepAliveInterval = settings.WebSocketKeepAlive.Value;
                if (Logging.IsEnabled)
                    Logging.Info(this, $"{nameof(CreateWebSocketChannelFactory)} Set websocket keep-alive to {settings.WebSocketKeepAlive}");
            }

            if (settings.ClientCertificate != null)
            {
                websocket.Options.ClientCertificates.Add(settings.ClientCertificate);
            }

            // Support for RemoteCertificateValidationCallback for ClientWebSocket is introduced in .NET Standard 2.1
#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER
            if (settings.RemoteCertificateValidationCallback != null)
            {
                websocket.Options.RemoteCertificateValidationCallback = settings.RemoteCertificateValidationCallback;
                if (Logging.IsEnabled)
                    Logging.Info(this, $"{nameof(CreateWebSocketChannelFactory)} Setting RemoteCertificateValidationCallback");
            }
#endif
            // The connection attempt is bounded to one minute.
            using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1));
            await websocket.ConnectAsync(websocketUri, cts.Token).ConfigureAwait(false);
        }
        catch
        {
            // The socket never reached the channel that would own it; release it instead of leaking it.
            websocket.Dispose();
            throw;
        }

        // From here on the channel owns the connected websocket.
        var clientWebSocketChannel = new ClientWebSocketChannel(null, websocket);
        clientWebSocketChannel
            .Option(ChannelOption.Allocator, UnpooledByteBufferAllocator.Default)
            .Option(ChannelOption.AutoRead, false)
            .Option(ChannelOption.RcvbufAllocator, new AdaptiveRecvByteBufAllocator())
            .Option(ChannelOption.MessageSizeEstimator, DefaultMessageSizeEstimator.Default)
            .Pipeline.AddLast(
                MqttEncoder.Instance,
                new MqttDecoder(false, MaxMessageSize),
                new LoggingHandler(LogLevel.DEBUG),
                _mqttIotHubAdapterFactory.Create(this, iotHubConnectionString, settings, productInfo, options));

        await s_eventLoopGroup.Value.RegisterAsync(clientWebSocketChannel).ConfigureAwait(true);
        return clientWebSocketChannel;
    };
}
/// <summary>
/// Adds <paramref name="cleanupTask"/> to the front of the cleanup chain: when the chain runs, this step
/// executes first and any previously scheduled steps run after it.
/// </summary>
private void ScheduleCleanup(Func<Task> cleanupTask)
{
    Func<Task> earlierSteps = _cleanupFunc;

    async Task RunChainAsync()
    {
        await cleanupTask().ConfigureAwait(true);

        if (earlierSteps == null)
        {
            return;
        }

        await earlierSteps().ConfigureAwait(true);
    }

    _cleanupFunc = RunChainAsync;
}
/// <summary>
/// Runs the scheduled cleanup chain under the close retry policy. Cleanup is best-effort: non-fatal
/// failures are logged and not rethrown, so shutdown can always proceed.
/// </summary>
private async Task CleanUpAsync()
{
    try
    {
        await _closeRetryPolicy.RunWithRetryAsync(CleanUpImplAsync).ConfigureAwait(true);
    }
    catch (Exception ex) when (!ex.IsFatal())
    {
        // Deliberately swallowed, but recorded so failed teardown does not vanish silently.
        if (Logging.IsEnabled)
            Logging.Error(this, $"Cleanup failed: {ex}", nameof(CleanUpAsync));
    }
}
// Runs the scheduled cleanup chain, or returns an already-completed task when nothing was scheduled.
private Task CleanUpImplAsync()
{
    if (_cleanupFunc != null)
    {
        return _cleanupFunc();
    }

    return TaskHelpers.CompletedTask;
}
// Atomically swaps _state from fromState to toState; returns true only if _state held fromState at the time.
private bool TryStateTransition(TransportState fromState, TransportState toState)
{
    int stateSeen = Interlocked.CompareExchange(ref _state, (int)toState, (int)fromState);
    return stateSeen == (int)fromState;
}
/// <summary>
/// Throws if the transport cannot be used: rethrows the stored fatal exception in the Error state,
/// throws <see cref="InvalidOperationException"/> once Closed, and (when <paramref name="throwIfNotOpen"/>
/// is set) throws <see cref="IotHubCommunicationException"/> if the Open flag is not set.
/// </summary>
private void EnsureValidState(bool throwIfNotOpen = true)
{
    if (State == TransportState.Error)
    {
        _fatalException.Throw();
    }

    if (State == TransportState.Closed)
    {
        // A closed transport is never reopened; reaching here indicates a caller bug.
        Debug.Fail($"{nameof(MqttTransportHandler)}.{nameof(EnsureValidState)}: Attempting to reuse transport after it was closed.");
        throw new InvalidOperationException($"Invalid transport state: {State}");
    }

    if (throwIfNotOpen)
    {
        bool openFlagMissing = (State & TransportState.Open) == 0;
        if (openFlagMissing)
        {
            throw new IotHubCommunicationException("MQTT connection is not established. Please retry later.");
        }
    }
}
/// <summary>
/// Creates the event loop group, sized by the ProcessorThreadCountVariableName environment variable when it
/// holds an integer: a value of 0 or less gives a default-sized group, 1 gives a single event loop, and any
/// other value gives a group of that many threads. Otherwise, or if reading the variable throws, a
/// default-sized group is returned.
/// </summary>
private static IEventLoopGroup GetEventLoopGroup()
{
    try
    {
        string envValue = Environment.GetEnvironmentVariable(ProcessorThreadCountVariableName);
        if (!string.IsNullOrWhiteSpace(envValue))
        {
            string processorEventCountValue = Environment.ExpandEnvironmentVariables(envValue);

            // Machine-readable configuration: parse independently of the current culture.
            if (int.TryParse(processorEventCountValue, NumberStyles.Integer, CultureInfo.InvariantCulture, out int processorThreadCount))
            {
                if (Logging.IsEnabled)
                    Logging.Info(null, $"EventLoopGroup threads count {processorThreadCount}.");

                return processorThreadCount <= 0 ? new MultithreadEventLoopGroup() :
                    processorThreadCount == 1 ? (IEventLoopGroup)new SingleThreadEventLoop() :
                    new MultithreadEventLoopGroup(processorThreadCount);
            }
        }
    }
    catch (Exception ex)
    {
        if (Logging.IsEnabled)
            Logging.Info(null, $"Could not read EventLoopGroup threads count {ex}");

        return new MultithreadEventLoopGroup();
    }

    if (Logging.IsEnabled)
        Logging.Info(null, "EventLoopGroup threads count was not set.");

    return new MultithreadEventLoopGroup();
}
// A proxy counts as configured only when one was supplied and it is not the default-settings placeholder.
private bool IsProxyConfigured()
{
    bool noExplicitProxy = _webProxy == null || _webProxy == DefaultWebProxySettings.Instance;
    return !noExplicitProxy;
}
}
}