DeviceBridgeTests/Providers/StorageProviderTests.cs (325 lines of code) (raw):

// Copyright (c) Microsoft Corporation. All rights reserved.

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;
using System.Linq;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using DeviceBridge.Models;
using DeviceBridge.Services;
using Microsoft.QualityTools.Testing.Fakes;
using Moq;
using NLog;
using NUnit.Framework;

namespace DeviceBridge.Providers.Tests
{
    /// <summary>
    /// Unit tests for <see cref="StorageProvider"/>. SQL client calls are intercepted with
    /// Microsoft Fakes shims so each test can assert on the command text, command type and
    /// parameters that the provider issues, and can feed canned rows back to it.
    /// </summary>
    [TestFixture]
    public class StorageProviderTests
    {
        /// <summary>
        /// Number of times the shim installed by <see cref="ShimExecuteNonQuery"/> has run since the
        /// last <see cref="Setup"/>. Tests whose only assertions live inside the shim callback use it
        /// to prove the callback actually ran, so they cannot pass vacuously.
        /// </summary>
        private static int _executeNonQueryCalls;

        private Mock<IEncryptionService> _encryptionServiceMock;
        private StorageProvider _storageProvider;

        /// <summary>
        /// Creates a fresh provider for every test. Encrypt and Decrypt are stubbed to return their
        /// input unchanged, so tests can assert on plain-text callback URLs.
        /// </summary>
        [SetUp]
        public void Setup()
        {
            _executeNonQueryCalls = 0;

            _encryptionServiceMock = new Mock<IEncryptionService>();
            _encryptionServiceMock
                .Setup(p => p.Encrypt(It.IsAny<Logger>(), It.IsAny<string>()))
                .Returns((Logger _, string s) => Task.FromResult(s));
            _encryptionServiceMock
                .Setup(p => p.Decrypt(It.IsAny<Logger>(), It.IsAny<string>()))
                .Returns((Logger _, string s) => Task.FromResult(s));

            _storageProvider = new StorageProvider("", _encryptionServiceMock.Object);
        }

        [Test]
        [Description("Continuously calls getDeviceSubscriptionsPaged stored procedure to return all subscriptions in the DB")]
        public async Task ListAllSubscriptionsOrderedByDeviceId()
        {
            using (ShimsContext.Create())
            {
                _encryptionServiceMock.Invocations.Clear();

                // Return 55 subscriptions, served in pages of 10.
                var testDateTime = DateTime.Now;
                var allSubs = Enumerable.Repeat(GetTestSubscription(testDateTime), 55).ToList();

                ShimOpen();
                ShimPagedReader("getDeviceSubscriptionsPaged", allSubs);

                var result = await _storageProvider.ListAllSubscriptionsOrderedByDeviceId(LogManager.GetCurrentClassLogger());

                Assert.AreEqual(55, result.FindAll(s => s.DeviceId == "test-device" && s.CallbackUrl == "http://test" && s.SubscriptionType == DeviceSubscriptionType.DesiredProperties && s.CreatedAt == testDateTime).Count);
                _encryptionServiceMock.Verify(p => p.Decrypt(It.IsAny<Logger>(), It.IsAny<string>()), Times.Exactly(55));
            }
        }

        [Test]
        [Description("Executes a select to get subscriptions of all types for the device Id received as parameter")]
        public async Task ListDeviceSubscriptions()
        {
            using (ShimsContext.Create())
            {
                _encryptionServiceMock.Invocations.Clear();
                ShimOpen();
                ShimExecuteReader("SELECT * FROM DeviceSubscriptions WHERE DeviceId = @DeviceId", new Dictionary<string, string>() { { "DeviceId", "test-device" } });
                ShimRead(1);
                var testDateTime = DateTime.Now;
                ShimItemGetString(GetTestSubscription(testDateTime));

                var result = await _storageProvider.ListDeviceSubscriptions(LogManager.GetCurrentClassLogger(), "test-device");

                Assert.AreEqual(1, result.Count);

                // One assertion per field so a failure names the field that mismatched.
                Assert.True(result[0].DeviceId == "test-device", "Unexpected DeviceId");
                Assert.True(result[0].CallbackUrl == "http://test", "Unexpected CallbackUrl");
                Assert.True(result[0].SubscriptionType == DeviceSubscriptionType.DesiredProperties, "Unexpected SubscriptionType");
                Assert.True(result[0].CreatedAt == testDateTime, "Unexpected CreatedAt");
                _encryptionServiceMock.Verify(p => p.Decrypt(It.IsAny<Logger>(), "http://test"), Times.Once());
            }
        }

        [Test]
        [Description("Executes a select to get the subscription for the type and device Id received as parameter")]
        public async Task GetDeviceSubscription()
        {
            using (ShimsContext.Create())
            {
                _encryptionServiceMock.Invocations.Clear();
                ShimOpen();
                ShimExecuteReader("SELECT * FROM DeviceSubscriptions WHERE DeviceId = @DeviceId AND SubscriptionType = @SubscriptionType", new Dictionary<string, string>() { { "DeviceId", "test-device" }, { "SubscriptionType", "DesiredProperties" } });
                ShimRead(1);
                var testDateTime = DateTime.Now;
                ShimItemGetString(GetTestSubscription(testDateTime));

                var result = await _storageProvider.GetDeviceSubscription(LogManager.GetCurrentClassLogger(), "test-device", DeviceSubscriptionType.DesiredProperties, default);

                // One assertion per field so a failure names the field that mismatched.
                Assert.True(result.DeviceId == "test-device", "Unexpected DeviceId");
                Assert.True(result.CallbackUrl == "http://test", "Unexpected CallbackUrl");
                Assert.True(result.SubscriptionType == DeviceSubscriptionType.DesiredProperties, "Unexpected SubscriptionType");
                Assert.True(result.CreatedAt == testDateTime, "Unexpected CreatedAt");
                _encryptionServiceMock.Verify(p => p.Decrypt(It.IsAny<Logger>(), "http://test"), Times.Once());
            }
        }

        [Test]
        [Description("Calls upsertDeviceSubscription to create a subscription of the given type, device Id, and callback URL")]
        public async Task CreateOrUpdateDeviceSubscription()
        {
            using (ShimsContext.Create())
            {
                _encryptionServiceMock.Invocations.Clear();
                ShimOpen();
                var testDateTime = DateTime.Now;
                ShimExecuteNonQuery(
                    "upsertDeviceSubscription",
                    new Dictionary<string, string>() { { "@DeviceId", "test-device" }, { "@SubscriptionType", "DesiredProperties" }, { "@CallbackUrl", "http://test" } },
                    cmd =>
                    {
                        Assert.AreEqual(CommandType.StoredProcedure, cmd.CommandType);

                        // Swap the @CreatedAt parameter for one carrying a known value, which the
                        // assertions below expect to find on the returned subscription.
                        cmd.Parameters.RemoveAt("@CreatedAt");
                        cmd.Parameters.Add(new SqlParameter("@CreatedAt", testDateTime));
                    });

                var result = await _storageProvider.CreateOrUpdateDeviceSubscription(LogManager.GetCurrentClassLogger(), "test-device", DeviceSubscriptionType.DesiredProperties, "http://test", default);

                // One assertion per field so a failure names the field that mismatched.
                Assert.True(result.DeviceId == "test-device", "Unexpected DeviceId");
                Assert.True(result.CallbackUrl == "http://test", "Unexpected CallbackUrl");
                Assert.True(result.SubscriptionType == DeviceSubscriptionType.DesiredProperties, "Unexpected SubscriptionType");
                Assert.True(result.CreatedAt == testDateTime, "Unexpected CreatedAt");
                _encryptionServiceMock.Verify(p => p.Encrypt(It.IsAny<Logger>(), "http://test"), Times.Once());
            }
        }

        [Test]
        [Description("Executes a delete to remove a single subscription of a given type for the given device Id")]
        public async Task DeleteDeviceSubscription()
        {
            using (ShimsContext.Create())
            {
                ShimOpen();
                ShimExecuteNonQuery("DELETE FROM DeviceSubscriptions WHERE DeviceId = @DeviceId AND SubscriptionType = @SubscriptionType", new Dictionary<string, string>() { { "DeviceId", "test-device" }, { "SubscriptionType", "DesiredProperties" } });

                await _storageProvider.DeleteDeviceSubscription(LogManager.GetCurrentClassLogger(), "test-device", DeviceSubscriptionType.DesiredProperties, default);

                AssertNonQueryExecuted();
            }
        }

        [Test]
        [Description("Executes a filtered delete to remove all subscriptions older than a 7 days")]
        public async Task GcHubCache()
        {
            using (ShimsContext.Create())
            {
                ShimOpen();
                ShimExecuteNonQuery(null, null, cmd =>
                {
                    // Whitespace is normalized on both sides so the comparison ignores SQL layout.
                    var expected = @"DELETE c FROM HubCache c LEFT JOIN DeviceSubscriptions s ON s.DeviceId = c.DeviceId WHERE (s.DeviceId IS NULL) AND (c.RenewedAt < DATEADD(day, -7, GETUTCDATE()))";
                    Assert.AreEqual(Regex.Replace(expected, @"\s+", " "), Regex.Replace(cmd.CommandText, @"\s+", " "));
                });

                await _storageProvider.GcHubCache(LogManager.GetCurrentClassLogger());

                AssertNonQueryExecuted();
            }
        }

        [Test]
        [Description("Uses a bulk copy and update to renew the timestamp field of all device Ids received as parameter")]
        public async Task RenewHubCacheEntries()
        {
            using (ShimsContext.Create())
            {
                var bulkCopyInvoked = false;

                ShimOpen();
                ShimExecuteNonQuery("CREATE TABLE #CacheEntriesToRenewTmpTable(DeviceId VARCHAR(255) NOT NULL PRIMARY KEY)");

                // Once data is sent to server, check that we remove the temp table.
                System.Data.SqlClient.Fakes.ShimSqlBulkCopy.AllInstances.WriteToServerAsyncDataTable = (bulkCopy, dt) =>
                {
                    bulkCopyInvoked = true;

                    Assert.AreEqual(60, bulkCopy.BulkCopyTimeout);
                    Assert.AreEqual(1000, bulkCopy.BatchSize);
                    Assert.AreEqual("#CacheEntriesToRenewTmpTable", bulkCopy.DestinationTableName);
                    Assert.AreEqual("test-device-1", dt.Rows[0]["DeviceId"]);
                    Assert.AreEqual("test-device-2", dt.Rows[1]["DeviceId"]);
                    Assert.AreEqual("test-device-3", dt.Rows[2]["DeviceId"]);

                    // Replace the CREATE TABLE expectation: the next non-query must be the update + drop.
                    ShimExecuteNonQuery(null, null, cmd =>
                    {
                        var expected = @"UPDATE HubCache SET RenewedAt = GETUTCDATE() FROM HubCache INNER JOIN #CacheEntriesToRenewTmpTable Temp ON (Temp.DeviceId = HubCache.DeviceId) DROP TABLE #CacheEntriesToRenewTmpTable";
                        Assert.AreEqual(Regex.Replace(expected, @"\s+", " "), Regex.Replace(cmd.CommandText, @"\s+", " "));
                        Assert.AreEqual(300, cmd.CommandTimeout);
                    });

                    return Task.CompletedTask;
                };

                await _storageProvider.RenewHubCacheEntries(LogManager.GetCurrentClassLogger(), new List<string>() { "test-device-1", "test-device-2", "test-device-3" });

                Assert.True(bulkCopyInvoked, "Expected the device Ids to be bulk-copied to the temp table");
                AssertNonQueryExecuted();
            }
        }

        [Test]
        [Description("Calls upsertHubCacheEntry to insert a hub cache entry for a single device")]
        public async Task AddOrUpdateHubCacheEntry()
        {
            using (ShimsContext.Create())
            {
                ShimOpen();
                ShimExecuteNonQuery(
                    "upsertHubCacheEntry",
                    new Dictionary<string, string>() { { "@DeviceId", "test-device" }, { "@Hub", "test-hub" } },
                    cmd => Assert.AreEqual(CommandType.StoredProcedure, cmd.CommandType));

                await _storageProvider.AddOrUpdateHubCacheEntry(LogManager.GetCurrentClassLogger(), "test-device", "test-hub");

                AssertNonQueryExecuted();
            }
        }

        [Test]
        [Description("Continuously calls getHubCacheEntriesPaged to get all entries from the HubCache table")]
        public async Task ListHubCacheEntries()
        {
            using (ShimsContext.Create())
            {
                // Return 55 hubs, served in pages of 10.
                var testHub = new Dictionary<string, object>()
                {
                    { "DeviceId", "test-device" },
                    { "Hub", "test-hub" },
                };
                var allHubs = Enumerable.Repeat(testHub, 55).ToList();

                ShimOpen();
                ShimPagedReader("getHubCacheEntriesPaged", allHubs);

                var result = await _storageProvider.ListHubCacheEntries(LogManager.GetCurrentClassLogger());

                Assert.AreEqual(55, result.FindAll(s => s.DeviceId == "test-device" && s.Hub == "test-hub").Count);
            }
        }

        [Test]
        [Description("Executes an arbitrary SQL statement in a connection")]
        public async Task Exec()
        {
            using (ShimsContext.Create())
            {
                ShimOpen();
                ShimExecuteNonQuery("my test query");

                await _storageProvider.Exec(LogManager.GetCurrentClassLogger(), "my test query");

                AssertNonQueryExecuted();
            }
        }

        /// <summary>
        /// Fails the current test if the shim installed by <see cref="ShimExecuteNonQuery"/> has not
        /// run since <see cref="Setup"/>.
        /// </summary>
        private static void AssertNonQueryExecuted()
        {
            Assert.Greater(_executeNonQueryCalls, 0, "Expected the provider to execute at least one non-query command");
        }

        /// <summary>
        /// Asserts that every expected parameter is present on the command with the expected value.
        /// Does nothing when <paramref name="parameters"/> is null.
        /// </summary>
        private static void AssertParameters(SqlCommand cmd, Dictionary<string, string> parameters)
        {
            if (parameters == null)
            {
                return;
            }

            foreach (var entry in parameters)
            {
                Assert.AreEqual(entry.Value, cmd.Parameters[entry.Key].Value);
            }
        }

        /// <summary>Makes SqlConnection.OpenAsync(CancellationToken) complete immediately.</summary>
        private static void ShimOpen()
        {
            System.Data.SqlClient.Fakes.ShimSqlConnection.AllInstances.OpenAsyncCancellationToken = (_, __) => Task.CompletedTask;
        }

        /// <summary>
        /// Shims both ExecuteReaderAsync overloads. The shim asserts the command text and the expected
        /// parameters (when given), runs <paramref name="onExecute"/> and returns a shimmed reader.
        /// </summary>
        private static void ShimExecuteReader(string cmdText, Dictionary<string, string> parameters = null, Action<SqlCommand> onExecute = null)
        {
            Func<SqlCommand, Task<SqlDataReader>> shim = (SqlCommand cmd) =>
            {
                Assert.AreEqual(cmdText, cmd.CommandText);
                AssertParameters(cmd, parameters);
                onExecute?.Invoke(cmd);
                return Task.FromResult<SqlDataReader>(new System.Data.SqlClient.Fakes.ShimSqlDataReader());
            };

            System.Data.SqlClient.Fakes.ShimSqlCommand.AllInstances.ExecuteReaderAsync = cmd => shim(cmd);
            System.Data.SqlClient.Fakes.ShimSqlCommand.AllInstances.ExecuteReaderAsyncCancellationToken = (cmd, _) => shim(cmd);
        }

        /// <summary>
        /// Serves <paramref name="pendingRows"/> through shimmed readers, one page per ExecuteReaderAsync
        /// call. Rows are removed from <paramref name="pendingRows"/> as they are paged out, so the list
        /// is empty once everything has been served. Each call must be a stored-procedure command named
        /// <paramref name="storedProcedure"/>.
        /// </summary>
        private static void ShimPagedReader(string storedProcedure, List<Dictionary<string, object>> pendingRows, int pageSize = 10)
        {
            List<Dictionary<string, object>> currentPage = null;
            Dictionary<string, object> currentRow = null;

            // Get the next page of items when ExecuteReaderAsync is called.
            ShimExecuteReader(storedProcedure, null, cmd =>
            {
                var nextPageSize = Math.Min(pageSize, pendingRows.Count);
                currentPage = pendingRows.Take(nextPageSize).ToList();
                pendingRows.RemoveRange(0, nextPageSize);
                Assert.AreEqual(CommandType.StoredProcedure, cmd.CommandType);
            });

            // Get the next item when ReadAsync is called.
            ShimRead(() =>
            {
                if (currentPage.Count == 0)
                {
                    return false;
                }

                currentRow = currentPage[0];
                currentPage.RemoveAt(0);
                return true;
            });

            ShimItemGetString(() => currentRow);
        }

        /// <summary>Makes ReadAsync return true for the first <paramref name="times"/> calls, then false.</summary>
        private static void ShimRead(int times)
        {
            System.Data.SqlClient.Fakes.ShimSqlDataReader.AllInstances.ReadAsyncCancellationToken = (_, __) => Task.FromResult(times-- > 0);
        }

        /// <summary>Makes ReadAsync return whatever <paramref name="onRead"/> returns on each call.</summary>
        private static void ShimRead(Func<bool> onRead)
        {
            System.Data.SqlClient.Fakes.ShimSqlDataReader.AllInstances.ReadAsyncCancellationToken = (_, __) => Task.FromResult(onRead());
        }

        /// <summary>Makes the reader's string indexer return the matching entry of <paramref name="item"/>.</summary>
        private static void ShimItemGetString(Dictionary<string, object> item)
        {
            System.Data.SqlClient.Fakes.ShimSqlDataReader.AllInstances.ItemGetString = (_, name) => item[name];
        }

        /// <summary>
        /// Makes the reader's string indexer return the matching entry of the dictionary supplied by
        /// <paramref name="getItem"/>, which is evaluated on every access.
        /// </summary>
        private static void ShimItemGetString(Func<Dictionary<string, object>> getItem)
        {
            System.Data.SqlClient.Fakes.ShimSqlDataReader.AllInstances.ItemGetString = (_, name) => getItem()[name];
        }

        /// <summary>
        /// Shims ExecuteNonQueryAsync(CancellationToken). The shim counts the invocation, asserts the
        /// command text and the expected parameters (each only when given), runs
        /// <paramref name="onExecute"/> and reports one affected row.
        /// </summary>
        private static void ShimExecuteNonQuery(string cmdText = null, Dictionary<string, string> parameters = null, Action<SqlCommand> onExecute = null)
        {
            System.Data.SqlClient.Fakes.ShimSqlCommand.AllInstances.ExecuteNonQueryAsyncCancellationToken = (cmd, __) =>
            {
                _executeNonQueryCalls++;

                if (cmdText != null)
                {
                    Assert.AreEqual(cmdText, cmd.CommandText);
                }

                AssertParameters(cmd, parameters);
                onExecute?.Invoke(cmd);
                return Task.FromResult(1);
            };
        }

        /// <summary>
        /// Builds the column-name to value map of a DeviceSubscriptions row for "test-device" with the
        /// given creation time.
        /// </summary>
        private static Dictionary<string, object> GetTestSubscription(DateTime createdAt)
        {
            return new Dictionary<string, object>()
            {
                { "DeviceId", "test-device" },
                { "SubscriptionType", "DesiredProperties" },
                { "CallbackUrl", "http://test" },
                { "CreatedAt", createdAt },
            };
        }
    }
}