DeviceBridge/Providers/StorageProvider.cs (341 lines of code) (raw):

// Copyright (c) Microsoft Corporation. All rights reserved.

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;
using System.Threading;
using System.Threading.Tasks;
using DeviceBridge.Common.Exceptions;
using DeviceBridge.Models;
using DeviceBridge.Services;
using NLog;

namespace DeviceBridge.Providers
{
    /// <summary>
    /// SQL Server-backed implementation of <see cref="IStorageProvider"/> that stores device subscriptions
    /// (callback URLs are encrypted on write and decrypted on read) and the device-to-Hub cache.
    /// </summary>
    /// <remarks>
    /// Every public method wraps failures via <see cref="TranslateSqlException"/>: a missing table or stored
    /// procedure surfaces as <see cref="StorageSetupIncompleteException"/>, anything else as
    /// <see cref="UnknownStorageException"/>.
    /// </remarks>
    public class StorageProvider : IStorageProvider
    {
        /// <summary>
        /// SQL Server error numbers. Taken from https://docs.microsoft.com/en-us/sql/relational-databases/errors-events/database-engine-events-and-errors.
        /// </summary>
        private const int TableNotFoundErrorNumber = 208;
        private const int StoredProcedureNotFoundErrorNumber = 2812;

        /// <summary>Number of rows requested per call to the paged stored procedures.</summary>
        private const int DefaultPageSize = 1000;

        /// <summary>Seconds allowed for the entire bulk copy operation (SqlBulkCopy.BulkCopyTimeout).</summary>
        private const int BulkCopyTimeoutSeconds = 60;

        /// <summary>Number of rows sent to the server per bulk copy batch.</summary>
        private const int BulkCopyBatchSize = 1000;

        /// <summary>Seconds allowed for the Hub cache renewal UPDATE; the operation should take no longer than 5 minutes.</summary>
        private const int RenewHubCacheCommandTimeoutSeconds = 300;

        private readonly string _connectionString;
        private readonly IEncryptionService _encryptionService;

        /// <summary>
        /// Creates a storage provider.
        /// </summary>
        /// <param name="connectionString">Connection string used to open a new SqlConnection for every operation.</param>
        /// <param name="encryptionService">Service used to encrypt callback URLs before storing them and decrypt them after reading.</param>
        public StorageProvider(string connectionString, IEncryptionService encryptionService)
        {
            _connectionString = connectionString;
            _encryptionService = encryptionService;
        }

        /// <summary>
        /// Lists all active subscriptions of all types ordered by device Id.
        /// </summary>
        /// <remarks>
        /// Pages through the getDeviceSubscriptionsPaged stored procedure until it returns an empty page. This method
        /// does not sort the results itself; the ordering comes from the stored procedure.
        /// </remarks>
        /// <param name="logger">Logger to be used.</param>
        /// <returns>List of all subscriptions of all types ordered by device Id.</returns>
        public async Task<List<DeviceSubscription>> ListAllSubscriptionsOrderedByDeviceId(Logger logger)
        {
            try
            {
                logger.Info("Getting all subscriptions in the DB");
                using SqlConnection connection = new SqlConnection(_connectionString);
                await connection.OpenAsync();

                var subscriptions = new List<DeviceSubscription>();
                int lastPageSize;
                int pageIndex = 0;

                do
                {
                    logger.Info("Fetching page {pageIndex} of device subscriptions", pageIndex);
                    using SqlCommand command = new SqlCommand("getDeviceSubscriptionsPaged", connection)
                    {
                        CommandType = CommandType.StoredProcedure,
                    };
                    command.Parameters.Add(new SqlParameter("@PageIndex", pageIndex++));
                    command.Parameters.Add(new SqlParameter("@RowsPerPage", DefaultPageSize));

                    // Declared inside the loop body so the reader is disposed before the next page's command
                    // runs on the same connection.
                    using SqlDataReader reader = await command.ExecuteReaderAsync();
                    var itemCountBeforePage = subscriptions.Count;

                    while (await reader.ReadAsync())
                    {
                        subscriptions.Add(await ReadDeviceSubscription(logger, reader));
                    }

                    lastPageSize = subscriptions.Count - itemCountBeforePage;
                }
                while (lastPageSize > 0);

                logger.Info("Found {subscriptionCount} subscriptions", subscriptions.Count);
                return subscriptions;
            }
            catch (Exception e)
            {
                throw TranslateSqlException(e);
            }
        }

        /// <summary>
        /// Lists all active subscriptions of all types for a device.
        /// </summary>
        /// <param name="logger">Logger to be used.</param>
        /// <param name="deviceId">Id of the device to get the subscriptions for.</param>
        /// <returns>List of subscriptions for the given device (empty if none).</returns>
        public async Task<List<DeviceSubscription>> ListDeviceSubscriptions(Logger logger, string deviceId)
        {
            try
            {
                logger.Info("Getting all subscriptions for device {deviceId}", deviceId);
                var sql = "SELECT * FROM DeviceSubscriptions WHERE DeviceId = @DeviceId";
                using SqlConnection connection = new SqlConnection(_connectionString);
                using SqlCommand command = new SqlCommand(sql, connection);
                command.Parameters.Add(new SqlParameter("@DeviceId", deviceId));
                await connection.OpenAsync();
                using SqlDataReader reader = await command.ExecuteReaderAsync();

                var subscriptions = new List<DeviceSubscription>();
                while (await reader.ReadAsync())
                {
                    subscriptions.Add(await ReadDeviceSubscription(logger, reader));
                }

                logger.Info("Found {subscriptionCount} subscriptions for device {deviceId}", subscriptions.Count, deviceId);
                return subscriptions;
            }
            catch (Exception e)
            {
                throw TranslateSqlException(e);
            }
        }

        /// <summary>
        /// Gets an active subscription of the specified type for a device, if one exists.
        /// </summary>
        /// <param name="logger">Logger to be used.</param>
        /// <param name="deviceId">Id of the device to get the subscription for.</param>
        /// <param name="subscriptionType">Type of the subscription to get.</param>
        /// <param name="cancellationToken">Cancellation token, flowed to the connection open and the query.</param>
        /// <returns>The subscription, if exists. Null otherwise.</returns>
        public async Task<DeviceSubscription> GetDeviceSubscription(Logger logger, string deviceId, DeviceSubscriptionType subscriptionType, CancellationToken cancellationToken)
        {
            try
            {
                logger.Info("Getting {subscriptionType} subscription for device {deviceId}", subscriptionType, deviceId);
                var sql = "SELECT * FROM DeviceSubscriptions WHERE DeviceId = @DeviceId AND SubscriptionType = @SubscriptionType";
                using SqlConnection connection = new SqlConnection(_connectionString);
                using SqlCommand command = new SqlCommand(sql, connection);
                command.Parameters.Add(new SqlParameter("@DeviceId", deviceId));
                command.Parameters.Add(new SqlParameter("@SubscriptionType", subscriptionType.ToString()));
                await connection.OpenAsync(cancellationToken);
                using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken);

                if (!await reader.ReadAsync(cancellationToken))
                {
                    logger.Info("No {subscriptionType} subscription found for device {deviceId}", subscriptionType, deviceId);
                    return null;
                }

                logger.Info("Got {subscriptionType} for device {deviceId}", subscriptionType, deviceId);
                return await ReadDeviceSubscription(logger, reader);
            }
            catch (Exception e)
            {
                throw TranslateSqlException(e);
            }
        }

        /// <summary>
        /// Creates a subscription of the given type for the given device. If one already exists, it's updated with a new creation time and callback URL.
        /// Returns the created or updated subscription.
        /// </summary>
        /// <remarks>
        /// Runs the upsertDeviceSubscription stored procedure with the encrypted callback URL; the creation time is read
        /// back from its @CreatedAt output parameter.
        /// </remarks>
        /// <param name="logger">Logger to be used.</param>
        /// <param name="deviceId">Id of the device to create the subscription for.</param>
        /// <param name="subscriptionType">Type of the subscription to be created.</param>
        /// <param name="callbackUrl">Callback URL of the subscription (plain text; encrypted before storage).</param>
        /// <param name="cancellationToken">Cancellation token, flowed to the connection open and the command.</param>
        /// <returns>The created subscription, carrying the plain-text callback URL.</returns>
        public async Task<DeviceSubscription> CreateOrUpdateDeviceSubscription(Logger logger, string deviceId, DeviceSubscriptionType subscriptionType, string callbackUrl, CancellationToken cancellationToken)
        {
            try
            {
                logger.Info("Creating or updating {subscriptionType} subscription for device {deviceId}", subscriptionType, deviceId);
                using SqlConnection connection = new SqlConnection(_connectionString);
                using SqlCommand command = new SqlCommand("upsertDeviceSubscription", connection)
                {
                    CommandType = CommandType.StoredProcedure,
                };
                command.Parameters.Add(new SqlParameter("@DeviceId", deviceId));
                command.Parameters.Add(new SqlParameter("@SubscriptionType", subscriptionType.ToString()));
                command.Parameters.Add(new SqlParameter("@CallbackUrl", await _encryptionService.Encrypt(logger, callbackUrl)));
                command.Parameters.Add(new SqlParameter("@CreatedAt", SqlDbType.DateTime)).Direction = ParameterDirection.Output;
                await connection.OpenAsync(cancellationToken);
                await command.ExecuteNonQueryAsync(cancellationToken);
                logger.Info("Created or updated {subscriptionType} subscription for device {deviceId}", subscriptionType, deviceId);

                return new DeviceSubscription()
                {
                    DeviceId = deviceId,
                    SubscriptionType = subscriptionType,
                    CallbackUrl = callbackUrl,
                    CreatedAt = (DateTime)command.Parameters["@CreatedAt"].Value,
                };
            }
            catch (Exception e)
            {
                throw TranslateSqlException(e);
            }
        }

        /// <summary>
        /// Deletes the subscription of the given type for a device, if one exists.
        /// </summary>
        /// <param name="logger">Logger to be used.</param>
        /// <param name="deviceId">Id of the device to delete the subscription for.</param>
        /// <param name="subscriptionType">Type of the subscription to be deleted.</param>
        /// <param name="cancellationToken">Cancellation token, flowed to the connection open and the command.</param>
        public async Task DeleteDeviceSubscription(Logger logger, string deviceId, DeviceSubscriptionType subscriptionType, CancellationToken cancellationToken)
        {
            try
            {
                logger.Info("Deleting {subscriptionType} subscription for device {deviceId}", subscriptionType, deviceId);
                var sql = "DELETE FROM DeviceSubscriptions WHERE DeviceId = @DeviceId AND SubscriptionType = @SubscriptionType";
                using SqlConnection connection = new SqlConnection(_connectionString);
                using SqlCommand command = new SqlCommand(sql, connection);
                command.Parameters.Add(new SqlParameter("@DeviceId", deviceId));
                command.Parameters.Add(new SqlParameter("@SubscriptionType", subscriptionType.ToString()));
                await connection.OpenAsync(cancellationToken);
                await command.ExecuteNonQueryAsync(cancellationToken);
                logger.Info("Deleted {subscriptionType} subscription for device {deviceId}", subscriptionType, deviceId);
            }
            catch (Exception e)
            {
                throw TranslateSqlException(e);
            }
        }

        /// <summary>
        /// Deletes from the hub cache any device that doesn't have a subscription and hasn't attempted to open a connection in the past week,
        /// i.e. whose RenewedAt timestamp is more than 7 days older than the DB's current UTC time.
        /// </summary>
        /// <param name="logger">Logger to be used.</param>
        public async Task GcHubCache(Logger logger)
        {
            try
            {
                logger.Info("Running Hub cache GC");
                var sql = @"DELETE c FROM HubCache c
                            LEFT JOIN DeviceSubscriptions s ON s.DeviceId = c.DeviceId
                            WHERE (s.DeviceId IS NULL) AND (c.RenewedAt < DATEADD(day, -7, GETUTCDATE()))";
                using SqlConnection connection = new SqlConnection(_connectionString);
                using SqlCommand command = new SqlCommand(sql, connection);
                await connection.OpenAsync();
                var affectedRows = await command.ExecuteNonQueryAsync();
                logger.Info("Successfully cleaned up {hubCount} Hubs during Hub cache GC", affectedRows);
            }
            catch (Exception e)
            {
                throw TranslateSqlException(e);
            }
        }

        /// <summary>
        /// Renews the Hub cache timestamp for a list of devices.
        /// </summary>
        /// <remarks>
        /// The Ids are bulk-copied into a temp table and joined against HubCache in a single UPDATE. Ids with no
        /// existing HubCache row are ignored. An empty list is a no-op that does not touch the DB.
        /// </remarks>
        /// <param name="logger">The logger instance to use.</param>
        /// <param name="deviceIds">List of device Ids to renew. Duplicates are allowed.</param>
        public async Task RenewHubCacheEntries(Logger logger, List<string> deviceIds)
        {
            try
            {
                logger.Info("Renewing Hub cache entries for {count} devices", deviceIds.Count);

                // The temp table's PRIMARY KEY would reject a repeated Id and fail the entire bulk copy, so deduplicate first.
                // NOTE(review): comparison is ordinal; if the DB collation is case-insensitive, Ids differing only by case
                // would still collide on the key — confirm against the HubCache.DeviceId collation.
                var uniqueDeviceIds = new HashSet<string>(deviceIds, StringComparer.Ordinal);
                if (uniqueDeviceIds.Count == 0)
                {
                    return;
                }

                // Stage the device Ids in a Data Table that we'll bulk copy to the DB.
                using var dt = new DataTable();
                dt.Columns.Add("DeviceId");
                foreach (var deviceId in uniqueDeviceIds)
                {
                    var row = dt.NewRow();
                    row["DeviceId"] = deviceId;
                    dt.Rows.Add(row);
                }

                using SqlConnection connection = new SqlConnection(_connectionString);
                using SqlCommand command = new SqlCommand(string.Empty, connection);
                await connection.OpenAsync();

                // Create a target temp table; it lives only for this connection's session.
                command.CommandText = "CREATE TABLE #CacheEntriesToRenewTmpTable(DeviceId VARCHAR(255) NOT NULL PRIMARY KEY)";
                await command.ExecuteNonQueryAsync();

                // Bulk copy the device Ids to renew into the temp table, BulkCopyBatchSize records at a time.
                using SqlBulkCopy bulkcopy = new SqlBulkCopy(connection)
                {
                    BulkCopyTimeout = BulkCopyTimeoutSeconds,
                    BatchSize = BulkCopyBatchSize,
                    DestinationTableName = "#CacheEntriesToRenewTmpTable",
                };
                await bulkcopy.WriteToServerAsync(dt);

                // Renew the Hub cache timestamp for every device Id in the temp table, then drop it.
                command.CommandTimeout = RenewHubCacheCommandTimeoutSeconds;
                command.CommandText = @"UPDATE HubCache SET RenewedAt = GETUTCDATE()
                                        FROM HubCache INNER JOIN #CacheEntriesToRenewTmpTable Temp ON (Temp.DeviceId = HubCache.DeviceId);
                                        DROP TABLE #CacheEntriesToRenewTmpTable;";
                var affectedRows = await command.ExecuteNonQueryAsync();
                logger.Info("Renewed {renewedCount} Hub cache entries", affectedRows);
            }
            catch (Exception e)
            {
                throw TranslateSqlException(e);
            }
        }

        /// <summary>
        /// Adds or updates a Hub cache entry for a device, via the upsertHubCacheEntry stored procedure.
        /// </summary>
        /// <param name="logger">Logger to be used.</param>
        /// <param name="deviceId">Id of the device for the new cache entry.</param>
        /// <param name="hub">Hub to be added to the cache entry for the device.</param>
        public async Task AddOrUpdateHubCacheEntry(Logger logger, string deviceId, string hub)
        {
            try
            {
                logger.Info("Adding or updating Hub cache entry for device {deviceId} ({hub})", deviceId, hub);
                using SqlConnection connection = new SqlConnection(_connectionString);
                using SqlCommand command = new SqlCommand("upsertHubCacheEntry", connection)
                {
                    CommandType = CommandType.StoredProcedure,
                };
                command.Parameters.Add(new SqlParameter("@DeviceId", deviceId));
                command.Parameters.Add(new SqlParameter("@Hub", hub));
                await connection.OpenAsync();
                await command.ExecuteNonQueryAsync();
                logger.Info("Added or updated Hub cache entry for device {deviceId}", deviceId);
            }
            catch (Exception e)
            {
                throw TranslateSqlException(e);
            }
        }

        /// <summary>
        /// Lists all entries in the Hub cache.
        /// </summary>
        /// <remarks>
        /// Pages through the getHubCacheEntriesPaged stored procedure until it returns an empty page.
        /// </remarks>
        /// <param name="logger">Logger to be used.</param>
        /// <returns>List of all entries in the DB hub cache.</returns>
        public async Task<List<HubCacheEntry>> ListHubCacheEntries(Logger logger)
        {
            try
            {
                logger.Info("Getting all entries in the hub cache");
                using SqlConnection connection = new SqlConnection(_connectionString);
                await connection.OpenAsync();

                var allEntries = new List<HubCacheEntry>();
                int lastPageSize;
                int pageIndex = 0;

                do
                {
                    logger.Info("Fetching page {pageIndex} of Hub cache entries", pageIndex);
                    using SqlCommand command = new SqlCommand("getHubCacheEntriesPaged", connection)
                    {
                        CommandType = CommandType.StoredProcedure,
                    };
                    command.Parameters.Add(new SqlParameter("@PageIndex", pageIndex++));
                    command.Parameters.Add(new SqlParameter("@RowsPerPage", DefaultPageSize));

                    // Disposed at the end of each iteration so the connection is free for the next page's command.
                    using SqlDataReader reader = await command.ExecuteReaderAsync();
                    var itemCountBeforePage = allEntries.Count;

                    while (await reader.ReadAsync())
                    {
                        allEntries.Add(new HubCacheEntry()
                        {
                            DeviceId = reader["DeviceId"].ToString(),
                            Hub = reader["Hub"].ToString(),
                        });
                    }

                    lastPageSize = allEntries.Count - itemCountBeforePage;
                }
                while (lastPageSize > 0);

                logger.Info("Found {hubCacheEntriesCount} Hub cache entries", allEntries.Count);
                return allEntries;
            }
            catch (Exception e)
            {
                throw TranslateSqlException(e);
            }
        }

        /// <summary>
        /// Executes an arbitrary SQL command against the DB.
        /// </summary>
        /// <remarks>
        /// The text is executed as-is with no parameters, so it must come from a trusted source (never from user input).
        /// </remarks>
        /// <param name="logger">Logger instance to use.</param>
        /// <param name="sql">SQL command to run.</param>
        public async Task Exec(Logger logger, string sql)
        {
            try
            {
                logger.Info("Executing SQL command");
                using SqlConnection connection = new SqlConnection(_connectionString);
                using SqlCommand command = new SqlCommand(sql, connection);
                await connection.OpenAsync();
                await command.ExecuteNonQueryAsync();
                logger.Info("SQL command executed successfully");
            }
            catch (Exception e)
            {
                throw TranslateSqlException(e);
            }
        }

        /// <summary>
        /// Builds a <see cref="DeviceSubscription"/> from the reader's current row, decrypting the stored callback URL.
        /// </summary>
        /// <param name="logger">Logger passed through to the encryption service.</param>
        /// <param name="reader">Reader positioned on a row with DeviceId, SubscriptionType, CallbackUrl and CreatedAt columns.</param>
        /// <returns>The subscription represented by the current row.</returns>
        private async Task<DeviceSubscription> ReadDeviceSubscription(Logger logger, SqlDataReader reader)
        {
            return new DeviceSubscription()
            {
                DeviceId = reader["DeviceId"].ToString(),
                SubscriptionType = DeviceSubscriptionType.FromString(reader["SubscriptionType"].ToString()),
                CallbackUrl = await _encryptionService.Decrypt(logger, reader["CallbackUrl"].ToString()),
                CreatedAt = (DateTime)reader["CreatedAt"],
            };
        }

        /// <summary>
        /// Translates SQL exceptions into service exceptions.
        /// </summary>
        /// <param name="e">Original exception (SQL or otherwise).</param>
        /// <returns>
        /// <see cref="StorageSetupIncompleteException"/> when SQL Server reports a missing table or stored procedure,
        /// <see cref="UnknownStorageException"/> for everything else.
        /// </returns>
        private static BridgeException TranslateSqlException(Exception e)
        {
            if (e is SqlException sqlException &&
                (sqlException.Number == StoredProcedureNotFoundErrorNumber || sqlException.Number == TableNotFoundErrorNumber))
            {
                return new StorageSetupIncompleteException(e);
            }

            return new UnknownStorageException(e);
        }
    }
}