src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs (306 lines of code) (raw):

// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Microsoft.Azure.SignalR;

/// <summary>
/// Base <see cref="HubLifetimeManager{THub}"/> that turns hub send and group operations into
/// service messages and writes them through <see cref="ServiceConnectionContainer"/>.
/// </summary>
/// <typeparam name="THub">The hub type this lifetime manager serves.</typeparam>
internal abstract class ServiceLifetimeManagerBase<THub> : HubLifetimeManager<THub> where THub : Hub
{
    /// <summary>Message used when a required string or list argument fails validation.</summary>
    protected const string NullOrEmptyStringErrorMessage = "Argument cannot be null or empty.";

    /// <summary>Message for a negative TTL; not used in this class, exposed for derived types.</summary>
    protected const string TtlOutOfRangeErrorMessage = "Ttl cannot be less than 0.";

    /// <summary>The connection manager every outgoing service message is written to.</summary>
    protected readonly IServiceConnectionManager<THub> ServiceConnectionContainer;

    /// <summary>Logger passed to <c>MessageLog</c> for start/success/failure events.</summary>
    protected ILogger Logger { get; set; }

    private readonly DefaultHubMessageSerializer _messageSerializer;

    /// <summary>
    /// Creates the lifetime manager and the serializer used to encode hub messages.
    /// </summary>
    /// <param name="serviceConnectionManager">Target for all written service messages.</param>
    /// <param name="protocolResolver">Resolver handed to the hub message serializer.</param>
    /// <param name="globalHubOptions">Global hub options; its supported protocols are passed to the serializer.</param>
    /// <param name="hubOptions">Per-hub options; its supported protocols are passed to the serializer.</param>
    /// <param name="logger">Logger for message tracing. Must not be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="logger"/> is null.</exception>
    public ServiceLifetimeManagerBase(
        IServiceConnectionManager<THub> serviceConnectionManager,
        IHubProtocolResolver protocolResolver,
        IOptions<HubOptions> globalHubOptions,
        IOptions<HubOptions<THub>> hubOptions,
        ILogger logger)
    {
        Logger = logger ?? throw new ArgumentNullException(nameof(logger));
        ServiceConnectionContainer = serviceConnectionManager;
        _messageSerializer = new DefaultHubMessageSerializer(
            protocolResolver,
            globalHubOptions.Value.SupportedProtocols,
            hubOptions.Value.SupportedProtocols);
    }

    /// <summary>No work is done here when a connection starts; returns a completed task.</summary>
    public override Task OnConnectedAsync(HubConnectionContext connection)
    {
        return Task.CompletedTask;
    }

    /// <summary>No work is done here when a connection ends; returns a completed task.</summary>
    public override Task OnDisconnectedAsync(HubConnectionContext connection)
    {
        return Task.CompletedTask;
    }

    /// <summary>Writes a broadcast message invoking <paramref name="methodName"/> with no exclusions.</summary>
    /// <exception cref="ArgumentException"><paramref name="methodName"/> is null or empty.</exception>
    public override Task SendAllAsync(string methodName, object[] args, CancellationToken cancellationToken = default)
    {
        if (IsInvalidArgument(methodName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
        }

        var message = AppendMessageTracingId(new BroadcastDataMessage(null, SerializeAllProtocols(methodName, args)));
        if (message.TracingId != null)
        {
            MessageLog.StartToBroadcastMessage(Logger, message);
        }
        return WriteAsync(message);
    }

    /// <summary>Writes a broadcast message that carries <paramref name="excludedIds"/> as its exclusion list.</summary>
    /// <exception cref="ArgumentException"><paramref name="methodName"/> is null or empty.</exception>
    public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds, CancellationToken cancellationToken = default)
    {
        if (IsInvalidArgument(methodName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
        }

        var message = AppendMessageTracingId(new BroadcastDataMessage(excludedIds, SerializeAllProtocols(methodName, args)));
        if (message.TracingId != null)
        {
            MessageLog.StartToBroadcastMessage(Logger, message);
        }
        return WriteAsync(message);
    }

    /// <summary>Writes a multi-connection message that targets the single <paramref name="connectionId"/>.</summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="connectionId"/> or <paramref name="methodName"/> is null or empty.
    /// </exception>
    public override Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default)
    {
        if (IsInvalidArgument(connectionId))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
        }

        if (IsInvalidArgument(methodName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
        }

        var message = AppendMessageTracingId(new MultiConnectionDataMessage(new[] { connectionId }, SerializeAllProtocols(methodName, args)));
        if (message.TracingId != null)
        {
            MessageLog.StartToSendMessageToConnections(Logger, message);
        }
        return WriteAsync(message);
    }

    /// <summary>
    /// Writes a multi-connection message for <paramref name="connectionIds"/>.
    /// An empty list completes immediately without writing anything.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="connectionIds"/> is null, or <paramref name="methodName"/> is null or empty.
    /// </exception>
    public override Task SendConnectionsAsync(IReadOnlyList<string> connectionIds, string methodName, object[] args, CancellationToken cancellationToken = default)
    {
        if (IsInvalidArgument(connectionIds))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionIds));
        }

        if (IsInvalidArgument(methodName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
        }

        if (connectionIds.Count == 0)
        {
            return Task.CompletedTask;
        }

        var message = AppendMessageTracingId(new MultiConnectionDataMessage(connectionIds, SerializeAllProtocols(methodName, args)));
        if (message.TracingId != null)
        {
            MessageLog.StartToSendMessageToConnections(Logger, message);
        }
        return WriteAsync(message);
    }

    /// <summary>Writes a group broadcast message for <paramref name="groupName"/> with no exclusions.</summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="groupName"/> or <paramref name="methodName"/> is null or empty.
    /// </exception>
    public override Task SendGroupAsync(string groupName, string methodName, object[] args, CancellationToken cancellationToken = default)
    {
        if (IsInvalidArgument(groupName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(groupName));
        }

        if (IsInvalidArgument(methodName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
        }

        var message = AppendMessageTracingId(new GroupBroadcastDataMessage(groupName, null, SerializeAllProtocols(methodName, args)));
        if (message.TracingId != null)
        {
            MessageLog.StartToBroadcastMessageToGroup(Logger, message);
        }
        return WriteAsync(message);
    }

    /// <summary>
    /// Writes a multi-group broadcast message for <paramref name="groupNames"/>.
    /// An empty list completes immediately without writing anything.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="groupNames"/> is null, or <paramref name="methodName"/> is null or empty.
    /// </exception>
    public override Task SendGroupsAsync(IReadOnlyList<string> groupNames, string methodName, object[] args, CancellationToken cancellationToken = default)
    {
        if (IsInvalidArgument(groupNames))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(groupNames));
        }

        if (IsInvalidArgument(methodName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
        }

        if (groupNames.Count == 0)
        {
            return Task.CompletedTask;
        }

        var message = AppendMessageTracingId(new MultiGroupBroadcastDataMessage(groupNames, SerializeAllProtocols(methodName, args)));
        if (message.TracingId != null)
        {
            MessageLog.StartToBroadcastMessageToGroups(Logger, message);
        }

        // Send this message from a random service connection because it involves multiple groups.
        // Unless we send a message for each group one by one, we cannot guarantee the message order for all groups.
        return WriteAsync(message);
    }

    /// <summary>
    /// Writes a group broadcast message for <paramref name="groupName"/> that carries
    /// <paramref name="excludedIds"/> as its exclusion list.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="groupName"/> or <paramref name="methodName"/> is null or empty.
    /// </exception>
    public override Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList<string> excludedIds, CancellationToken cancellationToken = default)
    {
        if (IsInvalidArgument(groupName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(groupName));
        }

        if (IsInvalidArgument(methodName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
        }

        var message = AppendMessageTracingId(new GroupBroadcastDataMessage(groupName, excludedIds, SerializeAllProtocols(methodName, args)));
        if (message.TracingId != null)
        {
            MessageLog.StartToBroadcastMessageToGroup(Logger, message);
        }
        return WriteAsync(message);
    }

    /// <summary>Writes a user data message for <paramref name="userId"/>.</summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="userId"/> or <paramref name="methodName"/> is null or empty.
    /// </exception>
    public override Task SendUserAsync(string userId, string methodName, object[] args, CancellationToken cancellationToken = default)
    {
        if (IsInvalidArgument(userId))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(userId));
        }

        if (IsInvalidArgument(methodName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
        }

        var message = AppendMessageTracingId(new UserDataMessage(userId, SerializeAllProtocols(methodName, args)));
        if (message.TracingId != null)
        {
            MessageLog.StartToSendMessageToUser(Logger, message);
        }
        return WriteAsync(message);
    }

    /// <summary>
    /// Writes a multi-user data message for <paramref name="userIds"/>.
    /// An empty list completes immediately without writing anything.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="userIds"/> is null, or <paramref name="methodName"/> is null or empty.
    /// </exception>
    public override Task SendUsersAsync(IReadOnlyList<string> userIds, string methodName, object[] args, CancellationToken cancellationToken = default)
    {
        if (IsInvalidArgument(userIds))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(userIds));
        }

        if (IsInvalidArgument(methodName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
        }

        if (userIds.Count == 0)
        {
            return Task.CompletedTask;
        }

        var message = AppendMessageTracingId(new MultiUserDataMessage(userIds, SerializeAllProtocols(methodName, args)));
        if (message.TracingId != null)
        {
            MessageLog.StartToSendMessageToUsers(Logger, message);
        }
        return WriteAsync(message);
    }

    /// <summary>
    /// Writes an ackable join-group message for the given connection and group,
    /// passing <paramref name="cancellationToken"/> on to the ackable write.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="connectionId"/> or <paramref name="groupName"/> is null or empty.
    /// </exception>
    public override Task AddToGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default)
    {
        if (IsInvalidArgument(connectionId))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
        }

        if (IsInvalidArgument(groupName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(groupName));
        }

        var message = AppendMessageTracingId(new JoinGroupWithAckMessage(connectionId, groupName));
        if (message.TracingId != null)
        {
            MessageLog.StartToAddConnectionToGroup(Logger, message);
        }
        return WriteAckableMessageAsync(message, cancellationToken);
    }

    /// <summary>
    /// Writes an ackable leave-group message for the given connection and group,
    /// passing <paramref name="cancellationToken"/> on to the ackable write.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// <paramref name="connectionId"/> or <paramref name="groupName"/> is null or empty.
    /// </exception>
    public override Task RemoveFromGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default)
    {
        if (IsInvalidArgument(connectionId))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
        }

        if (IsInvalidArgument(groupName))
        {
            throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(groupName));
        }

        var message = AppendMessageTracingId(new LeaveGroupWithAckMessage(connectionId, groupName));
        if (message.TracingId != null)
        {
            MessageLog.StartToRemoveConnectionFromGroup(Logger, message);
        }
        return WriteAckableMessageAsync(message, cancellationToken);
    }

    /// <summary>
    /// Writes <paramref name="message"/> to <see cref="ServiceConnectionContainer"/>,
    /// logging failure (and success when the message carries a tracing id).
    /// </summary>
    protected Task WriteAsync<T>(T message) where T : ServiceMessage, IMessageWithTracingId =>
        // Use the lambda parameter rather than capturing the outer variable, matching WriteAckableMessageAsync.
        WriteCoreAsync(message, m => ServiceConnectionContainer.WriteAsync(m));

    /// <summary>
    /// Writes <paramref name="message"/> as an ackable message and returns the result
    /// produced by <see cref="ServiceConnectionContainer"/>.
    /// </summary>
    protected Task<bool> WriteAckableMessageAsync<T>(T message, CancellationToken cancellation) where T : ServiceMessage, IMessageWithTracingId =>
        WriteAckableCoreAsync(message, m => ServiceConnectionContainer.WriteAckableMessageAsync(m, cancellation));

    /// <summary>Returns <c>true</c> when <paramref name="value"/> is null or an empty string.</summary>
    protected static bool IsInvalidArgument(string value)
    {
        return string.IsNullOrEmpty(value);
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="list"/> is null. An empty list is not
    /// reported as invalid; the multi-target send methods in this class treat it as a no-op.
    /// </summary>
    protected static bool IsInvalidArgument(IReadOnlyList<object> list)
    {
        return list == null;
    }

    /// <summary>
    /// Builds an <see cref="InvocationMessage"/> for <paramref name="method"/> and serializes it
    /// with <see cref="SerializeAllProtocols(HubMessage)"/>.
    /// </summary>
    /// <param name="method">Hub method name to invoke.</param>
    /// <param name="args">Arguments of the invocation.</param>
    /// <param name="invocationId">Optional invocation id; when null the id-less constructor is used.</param>
    /// <returns>Serialized payloads keyed by protocol name.</returns>
    protected IDictionary<string, ReadOnlyMemory<byte>> SerializeAllProtocols(string method, object[] args, string invocationId = null)
    {
        var message = invocationId == null
            ? new InvocationMessage(method, args)
            : new InvocationMessage(invocationId, method, args);

        // Share the payload-building logic with the HubMessage overload instead of duplicating it.
        return SerializeAllProtocols((HubMessage)message);
    }

    /// <summary>
    /// Serializes <paramref name="message"/> with the hub message serializer and collects
    /// the results into a dictionary keyed by protocol name.
    /// </summary>
    protected IDictionary<string, ReadOnlyMemory<byte>> SerializeAllProtocols(HubMessage message)
    {
        var serializedHubMessages = _messageSerializer.SerializeMessage(message);
        var payloads = new ArrayDictionary<string, ReadOnlyMemory<byte>>(serializedHubMessages.Count);
        foreach (var serializedMessage in serializedHubMessages)
        {
            payloads.Add(serializedMessage.ProtocolName, serializedMessage.Serialized);
        }
        return payloads;
    }

    /// <summary>Serializes an invocation of <paramref name="method"/> for a single <paramref name="protocol"/>.</summary>
    protected ReadOnlyMemory<byte> SerializeProtocol(string protocol, string method, object[] args) =>
        _messageSerializer.SerializeMessage(protocol, new InvocationMessage(method, args));

    /// <summary>Serializes a completion <paramref name="message"/> for a single <paramref name="protocol"/>.</summary>
    protected ReadOnlyMemory<byte> SerializeCompletionMessage(CompletionMessage message, string protocol) =>
        _messageSerializer.SerializeMessage(protocol, message);

    /// <summary>
    /// Hook applied to every outgoing message before it is logged and written.
    /// The default implementation returns <c>message.WithTracingId()</c>.
    /// </summary>
    protected virtual T AppendMessageTracingId<T>(T message) where T : ServiceMessage, IMessageWithTracingId
    {
        return message.WithTracingId();
    }

    /// <summary>
    /// Awaits <paramref name="task"/> for <paramref name="message"/>. Failures are logged and rethrown;
    /// success is logged only when the message has a tracing id.
    /// </summary>
    private async Task WriteCoreAsync<T>(T message, Func<T, Task> task) where T : ServiceMessage, IMessageWithTracingId
    {
        try
        {
            await task(message);
        }
        catch (Exception ex)
        {
            MessageLog.FailedToSendMessage(Logger, message, ex);
            throw;
        }

        if (message.TracingId != null)
        {
            MessageLog.SucceededToSendMessage(Logger, message);
        }
    }

    /// <summary>
    /// Awaits <paramref name="task"/> for <paramref name="message"/> and returns its result.
    /// Failures (including one thrown while logging success) are logged and rethrown.
    /// </summary>
    private async Task<bool> WriteAckableCoreAsync<T>(T message, Func<T, Task<bool>> task) where T : ServiceMessage, IMessageWithTracingId
    {
        try
        {
            var result = await task(message);
            if (message.TracingId != null)
            {
                MessageLog.SucceededToSendMessage(Logger, message);
            }
            return result;
        }
        catch (Exception ex)
        {
            MessageLog.FailedToSendMessage(Logger, message, ex);
            throw;
        }
    }
}