CredentialProvider.Microsoft/RequestHandlers/GetAuthenticationCredentialsRequestHandler.cs (129 lines of code) (raw):

// Copyright (c) Microsoft. All rights reserved.
//
// Licensed under the MIT license.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using NuGet.Protocol.Plugins;
using NuGetCredentialProvider.CredentialProviders;
using NuGetCredentialProvider.Logging;
using NuGetCredentialProvider.Util;

namespace NuGetCredentialProvider.RequestHandlers
{
    /// <summary>
    /// Handles a <see cref="GetAuthenticationCredentialsRequest"/> and replies with credentials.
    /// </summary>
    internal class GetAuthenticationCredentialsRequestHandler : RequestHandlerBase<GetAuthenticationCredentialsRequest, GetAuthenticationCredentialsResponse>
    {
        // May be null when the caller passes no cache to the three-argument constructor;
        // every use below must tolerate that.
        private readonly ICache<Uri, string> cache;
        private readonly IReadOnlyCollection<ICredentialProvider> credentialProviders;
        private readonly TimeSpan progressReporterTimeSpan = TimeSpan.FromSeconds(2);

        /// <summary>
        /// Initializes a new instance of the <see cref="GetAuthenticationCredentialsRequestHandler"/> class.
        /// </summary>
        /// <param name="logger">A <see cref="ILogger"/> to use for logging.</param>
        /// <param name="credentialProviders">An <see cref="IReadOnlyCollection{ICredentialProviders}"/> containing credential providers.</param>
        /// <param name="cache">An <see cref="ICache{TKey, TValue}"/> cache to store found credentials. May be <c>null</c>, in which case nothing is cached.</param>
        /// <exception cref="ArgumentNullException"><paramref name="credentialProviders"/> is <c>null</c>.</exception>
        public GetAuthenticationCredentialsRequestHandler(ILogger logger, IReadOnlyCollection<ICredentialProvider> credentialProviders, ICache<Uri, string> cache)
            : base(logger)
        {
            this.credentialProviders = credentialProviders ?? throw new ArgumentNullException(nameof(credentialProviders));
            this.cache = cache;
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="GetAuthenticationCredentialsRequestHandler"/> class
        /// whose cache is obtained from <see cref="GetSessionTokenCache"/>.
        /// </summary>
        /// <param name="logger">A <see cref="ILogger"/> to use for logging.</param>
        /// <param name="credentialProviders">An <see cref="IReadOnlyCollection{ICredentialProviders}"/> containing credential providers.</param>
        /// <param name="cancellationToken">A <see cref="CancellationToken"/> handed to the session token cache when one is created.</param>
        /// <exception cref="ArgumentNullException"><paramref name="credentialProviders"/> is <c>null</c>.</exception>
        public GetAuthenticationCredentialsRequestHandler(ILogger logger, IReadOnlyCollection<ICredentialProvider> credentialProviders, CancellationToken cancellationToken)
            : this(logger, credentialProviders, null)
        {
            this.cache = GetSessionTokenCache(logger, cancellationToken);
        }

        /// <summary>
        /// Asks each credential provider in turn for credentials for the requested URI.
        /// </summary>
        /// <param name="request">The incoming request.</param>
        /// <returns>
        /// A response with <see cref="MessageResponseCode.Success"/> from the cache or from the first provider that succeeds;
        /// <see cref="MessageResponseCode.Error"/> when the request or its URI is <c>null</c>, or when a provider throws;
        /// <see cref="MessageResponseCode.NotFound"/> when no provider yields credentials.
        /// </returns>
        public override async Task<GetAuthenticationCredentialsResponse> HandleRequestAsync(GetAuthenticationCredentialsRequest request)
        {
            // Validate before touching any member of the request: the log statements below
            // dereference request.Uri and would otherwise throw on a null request or URI.
            if (request?.Uri == null)
            {
                return new GetAuthenticationCredentialsResponse(
                    username: null,
                    password: null,
                    message: Resources.RequestUriNull,
                    authenticationTypes: null,
                    responseCode: MessageResponseCode.Error);
            }

            Logger.Verbose(string.Format(Resources.HandlingAuthRequest, request.Uri.AbsoluteUri, request.IsRetry, request.IsNonInteractive, request.CanShowDialog));
            Logger.Verbose(string.Format(Resources.Uri, request.Uri.AbsoluteUri));

            foreach (ICredentialProvider credentialProvider in credentialProviders)
            {
                if (await credentialProvider.CanProvideCredentialsAsync(request.Uri).ConfigureAwait(continueOnCapturedContext: false) == false)
                {
                    Logger.Verbose(string.Format(Resources.SkippingCredentialProvider, credentialProvider, request.Uri.AbsoluteUri));
                    continue;
                }

                Logger.Verbose(string.Format(Resources.UsingCredentialProvider, credentialProvider, request.Uri.AbsoluteUri));

                if (credentialProvider.IsCachable && TryCache(request, out string cachedToken))
                {
                    return new GetAuthenticationCredentialsResponse(
                        username: "VssSessionToken",
                        password: cachedToken,
                        message: null,
                        authenticationTypes: new List<string> { "Basic" },
                        responseCode: MessageResponseCode.Success);
                }

                try
                {
                    GetAuthenticationCredentialsResponse response = await credentialProvider.HandleRequestAsync(request, CancellationToken).ConfigureAwait(continueOnCapturedContext: false);

                    if (response != null && response.ResponseCode == MessageResponseCode.Success)
                    {
                        if (cache != null && credentialProvider.IsCachable)
                        {
                            Logger.Verbose(string.Format(Resources.CachingSessionToken, request.Uri.AbsoluteUri));
                            cache[request.Uri] = response.Password;
                        }

                        return response;
                    }
                    else if (!string.IsNullOrWhiteSpace(response?.Message))
                    {
                        // Unsuccessful response: record its message and fall through to the next provider.
                        Logger.Verbose(response.Message);
                    }
                }
                catch (Exception e)
                {
                    // An exception from a provider ends the search; remaining providers are not tried.
                    Logger.Error(string.Format(Resources.AcquireSessionTokenFailed, e.ToString()));

                    return new GetAuthenticationCredentialsResponse(
                        username: null,
                        password: null,
                        message: e.Message,
                        authenticationTypes: null,
                        responseCode: MessageResponseCode.Error);
                }
            }

            Logger.Verbose(Resources.CredentialsNotFound);
            return new GetAuthenticationCredentialsResponse(
                username: null,
                password: null,
                message: null,
                authenticationTypes: null,
                responseCode: MessageResponseCode.NotFound);
        }

        /// <summary>
        /// Creates the progress reporter for this handler using the handler's fixed reporting interval.
        /// </summary>
        /// <param name="connection">The connection passed through to <see cref="AutomaticProgressReporter.Create"/>.</param>
        /// <param name="message">The message passed through to <see cref="AutomaticProgressReporter.Create"/>.</param>
        /// <param name="cancellationToken">The token passed through to <see cref="AutomaticProgressReporter.Create"/>.</param>
        /// <returns>The created <see cref="AutomaticProgressReporter"/>.</returns>
        protected override AutomaticProgressReporter GetProgressReporter(IConnection connection, Message message, CancellationToken cancellationToken)
        {
            Logger.Verbose(string.Format(Resources.CreatingProgressReporter, progressReporterTimeSpan.ToString()));
            return AutomaticProgressReporter.Create(connection, message, progressReporterTimeSpan, cancellationToken);
        }

        /// <summary>
        /// Chooses the cache implementation based on <see cref="EnvUtil.SessionTokenCacheEnabled"/>.
        /// </summary>
        /// <param name="logger">Logger used to report the decision and handed to the created cache.</param>
        /// <param name="cancellationToken">Token handed to the created <see cref="SessionTokenCache"/>.</param>
        /// <returns>
        /// A <see cref="SessionTokenCache"/> at <see cref="EnvUtil.SessionTokenCacheLocation"/> when caching is enabled;
        /// otherwise a <see cref="NoOpCache{TKey, TValue}"/>.
        /// </returns>
        private static ICache<Uri, string> GetSessionTokenCache(ILogger logger, CancellationToken cancellationToken)
        {
            if (EnvUtil.SessionTokenCacheEnabled())
            {
                logger.Verbose(string.Format(Resources.SessionTokenCacheLocation, EnvUtil.SessionTokenCacheLocation));
                return new SessionTokenCache(EnvUtil.SessionTokenCacheLocation, logger, cancellationToken);
            }

            logger.Verbose(Resources.SessionTokenCacheDisabled);
            return new NoOpCache<Uri, string>();
        }

        /// <summary>
        /// Looks up a cached token for the request's URI, removing any existing entry when the request is a retry.
        /// </summary>
        /// <param name="request">The request whose URI is the cache key. Must have a non-null URI.</param>
        /// <param name="cachedToken">The cached token when found; otherwise <c>null</c>.</param>
        /// <returns><c>true</c> when a cached token was found; <c>false</c> on a retry, a miss, or when there is no cache.</returns>
        private bool TryCache(GetAuthenticationCredentialsRequest request, out string cachedToken)
        {
            cachedToken = null;

            Logger.Verbose(string.Format(Resources.IsRetry, request.IsRetry));
            if (request.IsRetry)
            {
                Logger.Verbose(string.Format(Resources.InvalidatingCachedSessionToken, request.Uri.AbsoluteUri));
                cache?.Remove(request.Uri);
                return false;
            }

            // The cache can be null (three-argument constructor); treat that as a miss
            // instead of dereferencing it, mirroring the null-tolerant Remove above.
            if (cache != null && cache.TryGetValue(request.Uri, out string password))
            {
                Logger.Verbose(string.Format(Resources.FoundCachedSessionToken, request.Uri.AbsoluteUri));
                cachedToken = password;
                return true;
            }

            Logger.Verbose(string.Format(Resources.CachedSessionTokenNotFound, request.Uri.AbsoluteUri));
            return false;
        }
    }
}