DeviceBridge/Services/ConnectionManager.cs (685 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using DeviceBridge.Common.Exceptions;
using DeviceBridge.Models;
using DeviceBridge.Providers;
using Microsoft.Azure.Devices.Client;
using Microsoft.Azure.Devices.Client.Exceptions;
using Microsoft.Azure.Devices.Provisioning.Client;
using Microsoft.Azure.Devices.Provisioning.Client.Transport;
using Microsoft.Azure.Devices.Shared;
using NLog;
namespace DeviceBridge.Services
{
/// <summary>
/// Manages SDK device connections. A connection can have two modes: permanent or temporary.
///
/// A permanent connection is one that should be kept open indefinitely, until the user explicitly decides to close it.
/// This connection type is used for any type of persistent subscription that needs an always-on connection, such as desired property changes.
///
/// A temporary connection is used for point-in-time operations, such as sending telemetry and getting the current device twin.
/// This type of connection lives for a few minutes (currently 9-10 mins) and is automatically closed. It's used to increase the chances
/// of a connection being already open when a point-in-time operation happens but also to make sure that connections don't stay
/// open for too long for silent devices.
///
/// Temporary connections are rewed whenever a new operation happens. Deleting a permanent connection falls back to a temporary connection if one exists.
/// </summary>
public class ConnectionManager : IDisposable, IConnectionManager
{
public const uint DeafultMaxPoolSize = 50; // Up to 50K device connections
public const int TemporaryConnectionMinDurationSeconds = 9 * 60; // 9 minutes
public const int TemporaryConnectionMaxDurationSeconds = TemporaryConnectionMinDurationSeconds + 120; // 11 minutes
public const int ExpiredConnectionCleanupIntervalMs = 10 * 1000; // Every 10 seconds
private const string GlobalDeviceEndpoint = "global.azure-devices-provisioning.net";
// Device client retry options - 36 retries over ~5 minutes for transient errors.
private const int ClientRetryCount = 36;
private const double ClientRetryMinBackoffMs = 100;
private const double ClientRetryMaxBackoffSec = 10;
private const double ClientRetryDeltaBackoffMs = 100;
private readonly string _idScope;
private readonly string _sasKey;
private readonly uint _maxPoolSize;
private readonly Logger _logger;
private readonly IStorageProvider _storageProvider;
private ConcurrentDictionary<string, DeviceClient> _clients = new ConcurrentDictionary<string, DeviceClient>();
private ConcurrentDictionary<string, SemaphoreSlim> _clientSemaphores = new ConcurrentDictionary<string, SemaphoreSlim>();
private ConcurrentDictionary<string, (ConnectionStatus status, ConnectionStatusChangeReason reason)> _clientStatuses = new ConcurrentDictionary<string, (ConnectionStatus status, ConnectionStatusChangeReason reason)>();
private ConcurrentDictionary<string, DateTime> _lastConnectionAttempt = new ConcurrentDictionary<string, DateTime>();
private ConcurrentDictionary<string, (string id, DesiredPropertyUpdateCallback callback)> _desiredPropertyUpdateCallbacks = new ConcurrentDictionary<string, (string id, DesiredPropertyUpdateCallback callback)>();
private ConcurrentDictionary<string, (string id, MethodCallback callback)> _methodCallbacks = new ConcurrentDictionary<string, (string id, MethodCallback callback)>();
private ConcurrentDictionary<string, (string id, ReceiveMessageCallback callback)> _messageCallbacks = new ConcurrentDictionary<string, (string id, ReceiveMessageCallback callback)>();
private ConcurrentDictionary<string, Func<ConnectionStatus, ConnectionStatusChangeReason, Task>> _connectionStatusCallbacks = new ConcurrentDictionary<string, Func<ConnectionStatus, ConnectionStatusChangeReason, Task>>();
private ConcurrentDictionary<string, string> _deviceHubs = new ConcurrentDictionary<string, string>();
private ConcurrentDictionary<string, long> _hasTemporaryConnectionUntil = new ConcurrentDictionary<string, long>(); // Timestamp representing until when a temporary device connection should be kept alive
private ConcurrentDictionary<string, bool> _hasPermanentConnection = new ConcurrentDictionary<string, bool>(); // Indicates if a permanent connection is open for a device
private Func<string, ConnectionStatus, ConnectionStatusChangeReason, Task> _globalConnectionStatusChangeHandler;
public ConnectionManager(Logger logger, string idScope, string sasKey, uint maxPoolSize, IStorageProvider storageProvider)
{
_logger = logger;
_idScope = idScope;
_sasKey = sasKey;
_maxPoolSize = maxPoolSize;
_storageProvider = storageProvider;
// Initialize the in-memory Hub cache and the list of all known hubs with DB data before the service starts.
var dbHubCacheEntries = storageProvider.ListHubCacheEntries(_logger).Result;
_deviceHubs = new ConcurrentDictionary<string, string>(dbHubCacheEntries.Select(e => new KeyValuePair<string, string>(e.DeviceId, e.Hub)));
}
/// <summary>
/// Attempts to cleanup expired temporary connections every 10 seconds.
/// </summary>
public async Task StartExpiredConnectionCleanupAsync()
{
_logger.Info("Started expired connection cleanup task");
while (true)
{
try
{
var currentTime = DateTimeOffset.UtcNow.ToUnixTimeSeconds();
foreach (var entry in _hasTemporaryConnectionUntil)
{
if (currentTime > entry.Value)
{
var _ = AssertDeviceConnectionClosedAsync(entry.Key, true /* temporary */).ContinueWith(t => _logger.Error(t.Exception, "Failed to close temporary connection for device {deviceId}", entry.Key), TaskContinuationOptions.OnlyOnFaulted);
}
}
}
catch (Exception e)
{
_logger.Error(e, "Failed to cleanup expired connections");
}
// Trigger the next execution
await Task.Delay(ExpiredConnectionCleanupIntervalMs);
}
}
/// <summary>
/// See <see href="https://docs.microsoft.com/en-us/dotnet/api/microsoft.azure.devices.client.connectionstatus?view=azure-dotnet">ConnectionStatus documentation</see>
/// for a detailed description of each status and reason.
/// </summary>
/// <param name="deviceId">Id of the device to get the status for.</param>
/// <returns>The last known connection status of the device or null if the device has never connected.</returns>
public (ConnectionStatus status, ConnectionStatusChangeReason reason)? GetDeviceStatus(string deviceId)
{
if (!_clientStatuses.TryGetValue(deviceId, out (ConnectionStatus status, ConnectionStatusChangeReason reason) status))
{
return null;
}
return status;
}
/// <summary>
/// Gets the list of all devices that attempted to connect since a given timestamp.
/// </summary>
/// <param name="threshold">Timestamp to filter by.</param>
/// <returns>The list of device Ids that attempted to connect since the given timestamp.</returns>
public List<string> GetDevicesThatConnectedSince(DateTime threshold)
{
var deviceIds = new List<string>();
foreach (var lastConnection in _lastConnectionAttempt)
{
if (lastConnection.Value >= threshold)
{
deviceIds.Add(lastConnection.Key);
}
}
return deviceIds;
}
/// <summary>
/// Asserts that a permanent or temporary connection for this device is open.
/// A temporary connection is guaranteed to live for only a few minutes (currently 9-11 minutes).
/// </summary>
/// <param name="deviceId">Id of the device to open a connection for.</param>
/// <param name="temporary">Whether the requested connection is temporary or permanent.</param>
/// <param name="cancellationToken">Optional cancellation token.</param>
public async Task AssertDeviceConnectionOpenAsync(string deviceId, bool temporary = false, CancellationToken? cancellationToken = null)
{
_logger.Info("Attempting to initialize {connectionType} connection for device {deviceId}", temporary ? "Temporary" : "Permanent", deviceId);
_lastConnectionAttempt.AddOrUpdate(deviceId, DateTime.Now, (key, oldValue) => DateTime.Now);
var mutex = _clientSemaphores.GetOrAdd(deviceId, new SemaphoreSlim(1, 1));
await mutex.WaitAsync();
try
{
_logger.Info("Acquired connection lock for device {deviceId}", deviceId);
if (temporary)
{
// Always renew the connection duration, as the user wants to assert that this connection will live for a few minutes (even if a previous temporary connection exists).
// We use a random factor to spread out when temporary connections expire.
var shouldLiveUntil = DateTimeOffset.UtcNow.ToUnixTimeSeconds() + new Random().Next(TemporaryConnectionMinDurationSeconds, TemporaryConnectionMaxDurationSeconds);
_hasTemporaryConnectionUntil.AddOrUpdate(deviceId, shouldLiveUntil, (key, oldValue) => shouldLiveUntil);
_logger.Info("Temporary connection for device {deviceId} set to live at least until {shouldLiveUntil}", deviceId, DateTimeOffset.FromUnixTimeSeconds(shouldLiveUntil).UtcDateTime);
}
else
{
_hasPermanentConnection.AddOrUpdate(deviceId, true, (key, oldValue) => true);
}
// Dispose the current client if it is in a permanent failure state.
if (_clientStatuses.TryGetValue(deviceId, out (ConnectionStatus status, ConnectionStatusChangeReason reason) currentStatus))
{
// Permanent failure state, taken from https://github.com/Azure-Samples/azure-iot-samples-csharp/tree/master/iot-hub/Samples/device/DeviceReconnectionSample
bool isFailed = currentStatus.status == ConnectionStatus.Disconnected;
if (isFailed && _clients.TryRemove(deviceId, out DeviceClient existingClient))
{
_logger.Info("Disposing existing failed client for device {deviceId}", deviceId);
await existingClient.CloseAsync();
existingClient.Dispose();
existingClient.SetConnectionStatusChangesHandler(null);
}
}
if (_clients.TryGetValue(deviceId, out DeviceClient _))
{
_logger.Info("Connection for device {deviceId} already exists", deviceId);
return;
}
var deviceKey = ComputeDerivedSymmetricKey(Convert.FromBase64String(_sasKey), deviceId);
// If we already know this device's hub, attempt to connect to it first.
if (_deviceHubs.TryGetValue(deviceId, out string knownDeviceHub))
{
try
{
var client = await BuildAndOpenClient(_logger, knownDeviceHub, deviceKey, cancellationToken);
_clients.AddOrUpdate(deviceId, client, (key, oldValue) => client);
return;
}
catch (Exception e)
{
_logger.Error(e, "Failed to connect device {deviceId} to it's old hub ({knownDeviceHub}). Will try DPS registration again.", deviceId, knownDeviceHub);
}
}
// If connecting to the cached Hub failed, try DPS registration.
{
var deviceHub = await DpsRegisterInternalAsync(_logger, deviceId, deviceKey, null, cancellationToken);
_deviceHubs.AddOrUpdate(deviceId, deviceHub, (key, oldValue) => deviceHub);
try
{
await _storageProvider.AddOrUpdateHubCacheEntry(_logger, deviceId, deviceHub);
}
catch (Exception e)
{
// Storing the hub is a best-effort operation.
_logger.Error(e, "Failed to update Hub cache for device {deviceId}", deviceId);
}
var client = await BuildAndOpenClient(_logger, deviceHub, deviceKey, cancellationToken);
_clients.AddOrUpdate(deviceId, client, (key, oldValue) => client);
}
}
finally
{
mutex.Release();
}
async Task<DeviceClient> BuildAndOpenClient(Logger logger, string candidateHub, string deviceKey, CancellationToken? cancellationToken = null)
{
_logger.Info("Attempting to connect device {deviceId} to hub {candidateHub}", deviceId, candidateHub);
DeviceClient client = null;
try
{
var settings = new ITransportSettings[]
{
new AmqpTransportSettings(TransportType.Amqp_Tcp_Only)
{
AmqpConnectionPoolSettings = new AmqpConnectionPoolSettings()
{
Pooling = true,
MaxPoolSize = _maxPoolSize,
},
},
};
var connString = GetDeviceConnectionString(deviceId, candidateHub, deviceKey);
client = DeviceClient.CreateFromConnectionString(connString, settings);
client.SetConnectionStatusChangesHandler(BuildConnectionStatusChangeHandler(deviceId));
client.SetRetryPolicy(new ExponentialBackoff(ClientRetryCount, TimeSpan.FromMilliseconds(ClientRetryMinBackoffMs), TimeSpan.FromSeconds(ClientRetryMaxBackoffSec), TimeSpan.FromMilliseconds(ClientRetryDeltaBackoffMs)));
// If a desired property callback exists, register it.
if (_desiredPropertyUpdateCallbacks.TryGetValue(deviceId, out var desiredPropertyUpdateCallback))
{
await client.SetDesiredPropertyUpdateCallbackAsync(desiredPropertyUpdateCallback.callback, null);
}
// If a method callback exists, register it.
if (_methodCallbacks.TryGetValue(deviceId, out var methodCallback))
{
await client.SetMethodDefaultHandlerAsync(methodCallback.callback, null);
}
// If a C2DMessage callback exists, register it.
if (_messageCallbacks.TryGetValue(deviceId, out var messageCallback))
{
await client.SetReceiveMessageHandlerAsync(messageCallback.callback, null);
}
if (cancellationToken.HasValue)
{
await client.OpenAsync(cancellationToken.Value);
}
else
{
await client.OpenAsync();
}
_logger.Info("Device {deviceId} connected to hub {candidateHub}", deviceId, candidateHub);
return client;
}
catch (Exception e)
{
// Dispose of the failed client to make sure it doesn't retry internally.
client?.Dispose();
client?.SetConnectionStatusChangesHandler(null);
throw e;
}
}
}
/// <summary>
/// Asserts that the permanent or temporary connection for this device is closed. The temporary connection is only closed
/// if it has expired. The underlying connection is not actually closed if we're trying to delete a permanent connection and
/// a temporary one exists or vice-versa.
/// </summary>
/// <param name="deviceId">Id of the decide for which the connection should be closed.</param>
/// <param name="temporary">Whether the temporary or permanent connection should be closed.</param>
public async Task AssertDeviceConnectionClosedAsync(string deviceId, bool temporary = false)
{
_logger.Info("Attempting to tear down {connectionType} connection for device {deviceId}", temporary ? "Temporary" : "Permanent", deviceId);
var mutex = _clientSemaphores.GetOrAdd(deviceId, new SemaphoreSlim(1, 1));
await mutex.WaitAsync();
try
{
_logger.Info("Acquired connection lock for device {deviceId}", deviceId);
// Attempt to remove the permanent or temporary connection from the list
if (temporary)
{
var currentTime = DateTimeOffset.UtcNow.ToUnixTimeSeconds();
if (_hasTemporaryConnectionUntil.TryGetValue(deviceId, out long shouldLiveUntil))
{
if (currentTime > shouldLiveUntil)
{
_hasTemporaryConnectionUntil.TryRemove(deviceId, out _);
}
else
{
_logger.Info("Attempted to tear down temporary connection for device {deviceId}, but connection has not yet expired", deviceId);
return;
}
}
else
{
_logger.Info("Attempted to tear down temporary connection for device {deviceId}, but a temporary connection wasn't found", deviceId);
return;
}
}
else
{
if (!_hasPermanentConnection.TryRemove(deviceId, out _))
{
_logger.Info("Attempted to tear down permanent connection for device {deviceId}, but a permanent connection wasn't found", deviceId);
return;
}
}
// Do not actually close client if a permanent or temporary connection still exists.
if (_hasPermanentConnection.TryGetValue(deviceId, out _) || _hasTemporaryConnectionUntil.TryGetValue(deviceId, out _))
{
_logger.Info("Attempted to tear down connection for device {deviceId}, but a permanent or temporary connection for this device still exists.", deviceId);
return;
}
if (!_clients.TryRemove(deviceId, out DeviceClient client))
{
_logger.Info("Connection for device {deviceId} doesn't exist", deviceId);
return;
}
await client.CloseAsync();
client.Dispose();
client.SetConnectionStatusChangesHandler(null);
_logger.Info("Closed connection for device {deviceId}", deviceId);
}
finally
{
mutex.Release();
}
}
/// <summary>
/// Performs a standalone DPS registration (not part of a device connection). The registration data is cached for future connections.
/// </summary>
/// <param name="logger">Logger instance to use.</param>
/// <param name="deviceId">Id of the device to register.</param>
/// <param name="modelId">Optional model Id to assign the device to.</param>
/// <param name="cancellationToken">Optional cancellation token.</param>
public async Task StandaloneDpsRegistrationAsync(Logger logger, string deviceId, string modelId = null, CancellationToken? cancellationToken = null)
{
var deviceKey = ComputeDerivedSymmetricKey(Convert.FromBase64String(_sasKey), deviceId);
var deviceHub = await DpsRegisterInternalAsync(logger, deviceId, deviceKey, modelId, cancellationToken);
// Cache this hub for later use.
_deviceHubs.AddOrUpdate(deviceId, deviceHub, (key, oldValue) => deviceHub);
try
{
await _storageProvider.AddOrUpdateHubCacheEntry(_logger, deviceId, deviceHub);
}
catch (Exception e)
{
// Storing the hub is a best-effort operation.
logger.Error(e, "Failed to update Hub cache for device {deviceId}", deviceId);
}
}
public async Task SendEventAsync(Logger logger, string deviceId, IDictionary<string, object> payload, CancellationToken cancellationToken, IDictionary<string, string> properties = null, string componentName = null, DateTime? creationTimeUtc = null)
{
logger.Info("Sending event for device {deviceId}", deviceId);
// This method expects a connection to have been previously established
if (!_clients.TryGetValue(deviceId, out DeviceClient client))
{
var e = new DeviceConnectionNotFoundException(deviceId);
logger.Error(e, "Tried to send event for device {deviceId} but an active connection was not found", deviceId);
throw e;
}
var data = JsonSerializer.Serialize(payload);
var eventMessage = new Message(Encoding.UTF8.GetBytes(data))
{
ContentEncoding = Encoding.UTF8.WebName,
ContentType = "application/json",
};
if (componentName != null)
{
eventMessage.ComponentName = componentName;
}
if (properties != null)
{
foreach (var property in properties)
{
eventMessage.Properties.Add(property.Key, property.Value);
}
}
if (creationTimeUtc.HasValue)
{
eventMessage.CreationTimeUtc = creationTimeUtc.Value;
}
try
{
await client.SendEventAsync(eventMessage, cancellationToken);
}
catch (Exception e)
{
throw TranslateSdkException(e, deviceId);
}
logger.Info("Event for device {deviceId} sent successfully", deviceId);
}
public async Task<Microsoft.Azure.Devices.Shared.Twin> GetTwinAsync(Logger logger, string deviceId, CancellationToken cancellationToken)
{
logger.Info("Getting twin for device {deviceId}", deviceId);
// This method expects a connection to have been previously established
if (!_clients.TryGetValue(deviceId, out DeviceClient client))
{
var e = new DeviceConnectionNotFoundException(deviceId);
logger.Error(e, "Tried to get twin for device {deviceId} but an active connection was not found", deviceId);
throw e;
}
Microsoft.Azure.Devices.Shared.Twin twin;
try
{
twin = await client.GetTwinAsync(cancellationToken);
}
catch (Exception e)
{
throw TranslateSdkException(e, deviceId);
}
logger.Info("Successfully got twin for device {deviceId}", deviceId);
return twin;
}
public async Task UpdateReportedPropertiesAsync(Logger logger, string deviceId, IDictionary<string, object> patch, CancellationToken cancellationToken)
{
logger.Info("Updating reported properties for device {deviceId}", deviceId);
// This method expects a connection to have been previously established
if (!_clients.TryGetValue(deviceId, out DeviceClient client))
{
var e = new DeviceConnectionNotFoundException(deviceId);
logger.Error(e, "Tried to update reported properties for device {deviceId} but an active connection was not found", deviceId);
throw e;
}
TwinCollection reportedProperties = new TwinCollection(JsonSerializer.Serialize(patch));
try
{
await client.UpdateReportedPropertiesAsync(reportedProperties, cancellationToken);
}
catch (Exception e)
{
throw TranslateSdkException(e, deviceId);
}
logger.Info("Successfully updated reported properties for device {deviceId}", deviceId);
}
/// <summary>
/// Sets the desired property change callback. The callback is not tied to a connection lifetime and will be active whenever the device
/// status is marked as online.
/// </summary>
/// <param name="deviceId">Id to the device to set the callback for.</param>
/// <param name="id">string identifying the callback, for tracking purposes.</param>
/// <param name="callback">The callback to be called when a desired property update is received.</param>
public async Task SetDesiredPropertyUpdateCallbackAsync(string deviceId, string id, DesiredPropertyUpdateCallback callback)
{
_logger.Info("Attempting to set desired property change handler for device {deviceId}", deviceId);
if (callback == null)
{
throw new ArgumentNullException(nameof(callback));
}
// We need to synchronize this with client creation/close so a race condition doesn't cause us to miss the
// callback registration on a client that is being currently created.
var mutex = _clientSemaphores.GetOrAdd(deviceId, new SemaphoreSlim(1, 1));
await mutex.WaitAsync();
try
{
_logger.Info("Acquired connection lock for device {deviceId}", deviceId);
// Save the callback so it can be registered whenever a client is created
_desiredPropertyUpdateCallbacks.AddOrUpdate(deviceId, (id, callback), (key, oldValue) => (id, callback));
// If a client currently exists, register the callback
if (!_clients.TryGetValue(deviceId, out DeviceClient client))
{
_logger.Info("Connection for device {deviceId} not found while trying to set desired property change callback. Callback will be registered whenever a new client is created", deviceId);
return;
}
await client.SetDesiredPropertyUpdateCallbackAsync(callback, null);
}
finally
{
mutex.Release();
}
}
public string GetCurrentDesiredPropertyUpdateCallbackId(string deviceId)
{
if (!_desiredPropertyUpdateCallbacks.TryGetValue(deviceId, out var desiredPropertyUpdateCallback))
{
return null;
}
return desiredPropertyUpdateCallback.id;
}
public async Task RemoveDesiredPropertyUpdateCallbackAsync(string deviceId)
{
_logger.Info("Attempting remove desired property change handler for device {deviceId}", deviceId);
// We need to synchronize this with client creation/close so a race condition doesn't cause us to add the
// callback to a client that is being currently created but not yet in the clients list.
var mutex = _clientSemaphores.GetOrAdd(deviceId, new SemaphoreSlim(1, 1));
await mutex.WaitAsync();
try
{
_logger.Info("Acquired connection lock for device {deviceId}", deviceId);
// Remove the callback so it's not registered in new clients
if (_desiredPropertyUpdateCallbacks.TryRemove(deviceId, out _))
{
if (!_clients.TryGetValue(deviceId, out DeviceClient client))
{
_logger.Info("Connection for device {deviceId} not found while trying to remove desired property change callback", deviceId);
return;
}
// The device SDK does not accept removing a property change callback (or passing null), so we just register an empty one.
await client.SetDesiredPropertyUpdateCallbackAsync((_, __) => Task.CompletedTask, null);
}
else
{
_logger.Info("Tried to remove desired property change handler for device {deviceId}, but a handler was not registered", deviceId);
}
}
finally
{
mutex.Release();
}
}
/// <summary>
/// Sets the direct method callback for a device. The callback is not tied to a connection lifetime and will be active whenever the device
/// status is marked as online.
/// </summary>
/// <param name="deviceId">Id to the device to set the callback for.</param>
/// <param name="id">string identifying the callback, for tracking purposes.</param>
/// <param name="callback">The callback to be called when a method invocation is received.</param>
public async Task SetMethodCallbackAsync(string deviceId, string id, MethodCallback callback)
{
_logger.Info("Attempting to set method handler for device {deviceId}", deviceId);
if (callback == null)
{
throw new ArgumentNullException(nameof(callback));
}
// We need to synchronize this with client creation/close so a race condition doesn't cause us to miss the
// callback registration on a client that is being currently created.
var mutex = _clientSemaphores.GetOrAdd(deviceId, new SemaphoreSlim(1, 1));
await mutex.WaitAsync();
try
{
_logger.Info("Acquired connection lock for device {deviceId}", deviceId);
// Save the callback so it can be registered whenever a client is created
_methodCallbacks.AddOrUpdate(deviceId, (id, callback), (key, oldValue) => (id, callback));
// If a client currently exists, register the callback
if (!_clients.TryGetValue(deviceId, out DeviceClient client))
{
_logger.Info("Connection for device {deviceId} not found while trying to set method callback. Callback will be registered whenever a new client is created", deviceId);
return;
}
await client.SetMethodDefaultHandlerAsync(callback, null);
}
finally
{
mutex.Release();
}
}
public string GetCurrentMethodCallbackId(string deviceId)
{
if (!_methodCallbacks.TryGetValue(deviceId, out var methodCallback))
{
return null;
}
return methodCallback.id;
}
public async Task RemoveMethodCallbackAsync(string deviceId)
{
_logger.Info("Attempting remove method handler for device {deviceId}", deviceId);
// We need to synchronize this with client creation/close so a race condition doesn't cause us to add the
// callback to a client that is being currently created but not yet in the clients list.
var mutex = _clientSemaphores.GetOrAdd(deviceId, new SemaphoreSlim(1, 1));
await mutex.WaitAsync();
try
{
_logger.Info("Acquired connection lock for device {deviceId}", deviceId);
// Remove the callback so it's not registered in new clients
if (_methodCallbacks.TryRemove(deviceId, out _))
{
if (!_clients.TryGetValue(deviceId, out DeviceClient client))
{
_logger.Info("Connection for device {deviceId} not found while trying to remove method callback", deviceId);
return;
}
await client.SetMethodDefaultHandlerAsync(null, null);
}
else
{
_logger.Info("Tried to remove method handler for device {deviceId}, but a handler was not registered", deviceId);
}
}
finally
{
mutex.Release();
}
}
/// <summary>
/// Sets the direct message callback for a device. The callback is not tied to a connection lifetime and will be active whenever the device
/// status is marked as online.
/// </summary>
/// <param name="deviceId">Id to the device to set the callback for.</param>
/// <param name="id">string identifying the callback, for tracking purposes.</param>
/// <param name="callback">The callback to be called when a C2D message is received.</param>
public async Task SetMessageCallbackAsync(string deviceId, string id, Func<Message, Task<ReceiveMessageCallbackStatus>> callback)
{
_logger.Info("Attempting to set C2DMessage handler for device {deviceId}", deviceId);
if (callback == null)
{
throw new ArgumentNullException(nameof(callback));
}
// We need to synchronize this with client creation/close so a race condition doesn't cause us to miss the
// callback registration on a client that is being currently created.
var mutex = _clientSemaphores.GetOrAdd(deviceId, new SemaphoreSlim(1, 1));
await mutex.WaitAsync();
try
{
_logger.Info("Acquired connection lock for device {deviceId}", deviceId);
async Task OnC2DMessageReceived(Message receivedMessage, object userContext)
{
if (!_clients.TryGetValue(deviceId, out DeviceClient tmpClient))
{
_logger.Info("Unable to find client for device {deviceId}, message will not be completed, rejected or abandoned.", deviceId);
return;
}
try
{
var status = await callback(receivedMessage);
if (status == ReceiveMessageCallbackStatus.Accept)
{
await tmpClient.CompleteAsync(receivedMessage);
}
else if (status == ReceiveMessageCallbackStatus.Abandon)
{
await tmpClient.AbandonAsync(receivedMessage);
}
else
{
await tmpClient.RejectAsync(receivedMessage);
}
}
catch
{
await tmpClient.AbandonAsync(receivedMessage);
}
}
_messageCallbacks.AddOrUpdate(deviceId, (id, OnC2DMessageReceived), (key, oldValue) => (id, OnC2DMessageReceived));
// If a client currently exists, register the callback
if (!_clients.TryGetValue(deviceId, out DeviceClient client))
{
_logger.Info("Connection for device {deviceId} not found while trying to set C2DMessage callback. Callback will be registered whenever a new client is created", deviceId);
return;
}
await client.SetReceiveMessageHandlerAsync(OnC2DMessageReceived, client);
}
finally
{
mutex.Release();
}
}
public string GetCurrentMessageCallbackId(string deviceId)
{
if (!_messageCallbacks.TryGetValue(deviceId, out var messageCallback))
{
return null;
}
return messageCallback.id;
}
public async Task RemoveMessageCallbackAsync(string deviceId)
{
_logger.Info("Attempting remove C2DMessage handler for device {deviceId}", deviceId);
// We need to synchronize this with client creation/close so a race condition doesn't cause us to add the
// callback to a client that is being currently created but not yet in the clients list.
var mutex = _clientSemaphores.GetOrAdd(deviceId, new SemaphoreSlim(1, 1));
await mutex.WaitAsync();
try
{
_logger.Info("Acquired connection lock for device {deviceId}", deviceId);
// Remove the callback so it's not registered in new clients
if (_messageCallbacks.TryRemove(deviceId, out _))
{
if (!_clients.TryGetValue(deviceId, out DeviceClient client))
{
_logger.Info("Connection for device {deviceId} not found while trying to remove C2DMessage callback", deviceId);
return;
}
await client.SetReceiveMessageHandlerAsync(null, null);
}
else
{
_logger.Info("Tried to remove C2DMessage handler for device {deviceId}, but a handler was not registered", deviceId);
}
}
finally
{
mutex.Release();
}
}
/// <summary>
/// Sets the global connection status change handler.
/// </summary>
/// <param name="callback">Callback to be called when the status of a device connection changes.</param>
public void SetGlobalConnectionStatusCallback(Func<string, ConnectionStatus, ConnectionStatusChangeReason, Task> callback)
{
_logger.Info("Setting global connection status handler");
_globalConnectionStatusChangeHandler = callback;
}
/// <summary>
/// Sets the connection status change handler for a device.
/// </summary>
/// <param name="deviceId">Id of the device to set the callback for.</param>
/// <param name="callback">Callback to be called when the status of the device connection changes.</param>
public void SetConnectionStatusCallback(string deviceId, Func<ConnectionStatus, ConnectionStatusChangeReason, Task> callback)
{
_logger.Info("Attempting to set connection status handler for device {deviceId}", deviceId);
if (callback == null)
{
throw new ArgumentNullException(nameof(callback));
}
_connectionStatusCallbacks.AddOrUpdate(deviceId, callback, (key, oldValue) => callback);
}
public void RemoveConnectionStatusCallback(string deviceId)
{
_logger.Info("Attempting remove connection status handler for device {deviceId}", deviceId);
if (!_connectionStatusCallbacks.TryRemove(deviceId, out _))
{
_logger.Info("Tried to remove connection status handler for device {deviceId}, but a handler was not registered", deviceId);
}
}
/// <summary>
/// Attempts to gracefully shutdown all SDK connections.
/// </summary>
public void Dispose()
{
_logger.Info("Disposing all SDK clients");
foreach (var entry in _clients)
{
entry.Value.Dispose();
}
}
private static Exception TranslateSdkException(Exception e, string deviceId)
{
if (e is IotHubCommunicationException && e.InnerException is TimeoutException)
{
throw new DeviceSdkTimeoutException(deviceId);
}
else
{
throw e; // If we don't know this particular exception type, throw as is.
}
}
/// <summary>
/// Internal wrapper for DPS registration.
/// </summary>
/// <exception cref="DpsRegistrationFailedWithUnknownStatusException">If the final registration status is not "assigned".</exception>
/// <param name="logger">Logger instance to use.</param>
/// <param name="deviceId">Id of the device to register.</param>
/// <param name="deviceKey">Key for the device.</param>
/// <param name="modelId">Optional model Id to be passed to DPS.</param>
/// <param name="cancellationToken">Optional cancellation token.</param>
/// <returns>The assigned hub for this device.</returns>
private async Task<string> DpsRegisterInternalAsync(Logger logger, string deviceId, string deviceKey, string modelId = null, CancellationToken? cancellationToken = null)
{
using (var security = new SecurityProviderSymmetricKey(deviceId, deviceKey, null))
using (var transport = new ProvisioningTransportHandlerHttp())
{
logger.Info("Attempting DPS registration for device {deviceId}, model Id {modelId}", deviceId, modelId);
ProvisioningDeviceClient provisioningClient = ProvisioningDeviceClient.Create(GlobalDeviceEndpoint, _idScope, security, transport);
DeviceRegistrationResult result;
// If a model Id was provided, pass it along to DPS.
if (modelId != null)
{
var pnpPayload = new ProvisioningRegistrationAdditionalData
{
JsonData = $"{{\"modelId\":\"{modelId}\"}}",
};
result = cancellationToken.HasValue ? await provisioningClient.RegisterAsync(pnpPayload, cancellationToken.Value) : await provisioningClient.RegisterAsync(pnpPayload);
}
else
{
result = cancellationToken.HasValue ? await provisioningClient.RegisterAsync(cancellationToken.Value) : await provisioningClient.RegisterAsync();
}
if (result.Status == ProvisioningRegistrationStatusType.Assigned)
{
logger.Info("DPS registration successful for device {deviceId}. Assigned to hub {deviceHub}", deviceId, result.AssignedHub);
return result.AssignedHub;
}
else
{
var e = new DpsRegistrationFailedWithUnknownStatusException(deviceId, result.Status.ToString(), result.Substatus.ToString(), result.ErrorCode, result.ErrorMessage);
logger.Error(e);
throw e;
}
}
}
/// <summary>
/// Builds a connection change handler for a specific deviceId, which optionally calls a custom callback.
/// </summary>
/// <param name="deviceId">Id of the device for which to build the callback.</param>
/// <returns>The connection status change handler.</returns>
private ConnectionStatusChangesHandler BuildConnectionStatusChangeHandler(string deviceId)
{
return (ConnectionStatus status, ConnectionStatusChangeReason reason) =>
{
_logger.Info("Connection status of device {deviceId} changed: status = {status}, reason = {reason}", deviceId, status, reason);
_clientStatuses.AddOrUpdate(deviceId, (status, reason), (key, oldValue) => (status, reason));
// If a custom callback exists, call it asynchronously.
if (_connectionStatusCallbacks.TryGetValue(deviceId, out var statusCallback))
{
var _ = statusCallback(status, reason).ContinueWith(t => _logger.Error(t.Exception, "Failed to execute custom connection status callback for device {deviceId}", deviceId), TaskContinuationOptions.OnlyOnFaulted);
}
// Execute the global status change handler if one was defined.
if (_globalConnectionStatusChangeHandler != null)
{
var _ = _globalConnectionStatusChangeHandler(deviceId, status, reason).ContinueWith(t => _logger.Error(t.Exception, "Failed to execute global connection status callback for device {deviceId}", deviceId), TaskContinuationOptions.OnlyOnFaulted);
}
};
}
private string ComputeDerivedSymmetricKey(byte[] masterKey, string registrationId)
{
using (var hmac = new HMACSHA256(masterKey))
{
return Convert.ToBase64String(hmac.ComputeHash(Encoding.UTF8.GetBytes(registrationId)));
}
}
private string GetDeviceConnectionString(string deviceId, string deviceHub, string derivedKey)
{
return string.Format("HostName={0};DeviceId={1};SharedAccessKey={2}", deviceHub, deviceId, derivedKey);
}
}
}