src/Elastic.Transport/Components/TransportClient/HttpWebRequestInvoker.cs (353 lines of code) (raw):

// Licensed to Elasticsearch B.V under one or more agreements.
// Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.IO;
using System.IO.Compression;
using System.Linq;
using System.Net;
using System.Net.NetworkInformation;
using System.Net.Security;
using System.Threading;
using System.Threading.Tasks;
using Elastic.Transport.Diagnostics;
using Elastic.Transport.Extensions;

namespace Elastic.Transport;

/// <summary>
/// This provides an <see cref="IRequestInvoker"/> implementation that targets <see cref="HttpWebRequest"/>.
/// <para>
/// On .NET full framework <see cref="HttpRequestInvoker"/> is an alias to this.
/// </para>
/// <para/>
/// <para>Do NOT use this class directly on .NET Core. <see cref="HttpWebRequest"/> is monkey patched
/// over HttpClient and does not reuse its instances of HttpClient
/// </para>
/// </summary>
#if !NETFRAMEWORK
[Obsolete("CoreFX HttpWebRequest uses HttpClient under the covers but does not reuse HttpClient instances, do NOT use on .NET core only used as the default on Full Framework")]
#endif
public class HttpWebRequestInvoker : IRequestInvoker
{
	// Comparable (separator-free) form of ConnectionSettings.CertificateFingerprint, computed lazily
	// the first time the fingerprint validation callback runs and reused afterwards.
	private string _expectedCertificateFingerprint;

	static HttpWebRequestInvoker()
	{
		// Not available under mono.
		// -1 lifts the default limit on how much of an error response body is read.
		if (!IsMono)
			HttpWebRequest.DefaultMaximumErrorResponseLength = -1;
	}

	/// <summary>
	/// Create a new instance of the <see cref="HttpWebRequestInvoker"/>.
	/// </summary>
	public HttpWebRequestInvoker() : this(new DefaultResponseFactory()) { }

	internal HttpWebRequestInvoker(ResponseFactory responseFactory) => ResponseFactory = responseFactory;

	/// <inheritdoc />
	public ResponseFactory ResponseFactory { get; }

	/// <summary> True when running on the Mono runtime (detected via the presence of the <c>Mono.Runtime</c> type). </summary>
	internal static bool IsMono { get; } = Type.GetType("Mono.Runtime") != null;

	// This invoker holds no disposable state of its own.
	void IDisposable.Dispose() { }

	/// <inheritdoc cref="IRequestInvoker.Request{TResponse}"/>
	public TResponse Request<TResponse>(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData)
		where TResponse : TransportResponse, new() =>
		RequestCoreAsync<TResponse>(false, endpoint, boundConfiguration, postData).EnsureCompleted();

	/// <inheritdoc cref="IRequestInvoker.RequestAsync{TResponse}"/>
	public Task<TResponse> RequestAsync<TResponse>(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData,
		CancellationToken cancellationToken = default)
		where TResponse : TransportResponse, new() =>
		RequestCoreAsync<TResponse>(true, endpoint, boundConfiguration, postData, cancellationToken).AsTask();

	/// <summary>
	/// Shared implementation of <see cref="Request{TResponse}"/> and <see cref="RequestAsync{TResponse}"/>.
	/// When <paramref name="isAsync"/> is <c>false</c> no awaited operation is truly asynchronous, so the
	/// returned <see cref="ValueTask{TResult}"/> completes synchronously.
	/// </summary>
	/// <param name="isAsync">Whether to use the APM async APIs (true) or the blocking APIs (false).</param>
	/// <param name="endpoint">Where to send the request.</param>
	/// <param name="boundConfiguration">How to send the request.</param>
	/// <param name="postData">Optional request body.</param>
	/// <param name="cancellationToken">Aborts the underlying <see cref="HttpWebRequest"/> when cancelled.</param>
	/// <returns>The response built by <see cref="ResponseFactory"/>.</returns>
	private async ValueTask<TResponse> RequestCoreAsync<TResponse>(bool isAsync, Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData,
		CancellationToken cancellationToken = default)
		where TResponse : TransportResponse, new()
	{
		Action unregisterWaitHandle = null;
		int? statusCode = null;
		Stream responseStream = null;
		Exception ex = null;
		string contentType = null;
		long contentLength = -1;
		IDisposable receivedResponse = DiagnosticSources.SingletonDisposable;
		ReadOnlyDictionary<TcpState, int> tcpStats = null;
		ReadOnlyDictionary<string, ThreadPoolStatistics> threadPoolStats = null;
		Dictionary<string, IEnumerable<string>> responseHeaders = null;

		var beforeTicks = Stopwatch.GetTimestamp();

		try
		{
			var data = postData;
			var request = CreateHttpWebRequest(endpoint, boundConfiguration, postData, isAsync);

			using (cancellationToken.Register(() => request.Abort()))
			{
				if (data is not null)
				{
					if (isAsync)
					{
						var apmGetRequestStreamTask = Task.Factory.FromAsync(request.BeginGetRequestStream, request.EndGetRequestStream, null);
						unregisterWaitHandle = RegisterApmTaskTimeout(apmGetRequestStreamTask, request, boundConfiguration);

						using (var stream = await apmGetRequestStreamTask.ConfigureAwait(false))
						{
							if (boundConfiguration.HttpCompression)
							{
								using var zipStream = new GZipStream(stream, CompressionMode.Compress);
								await data.WriteAsync(zipStream, boundConfiguration.ConnectionSettings, boundConfiguration.DisableDirectStreaming, cancellationToken)
									.ConfigureAwait(false);
							}
							else
								await data.WriteAsync(stream, boundConfiguration.ConnectionSettings, boundConfiguration.DisableDirectStreaming, cancellationToken)
									.ConfigureAwait(false);
						}

						// The request-stream timeout is no longer needed; clear it so the finally block does not unregister it twice.
						unregisterWaitHandle?.Invoke();
						unregisterWaitHandle = null;
					}
					else
					{
						using var stream = request.GetRequestStream();

						if (boundConfiguration.HttpCompression)
						{
							using var zipStream = new GZipStream(stream, CompressionMode.Compress);
							data.Write(zipStream, boundConfiguration.ConnectionSettings, boundConfiguration.DisableDirectStreaming);
						}
						else
							data.Write(stream, boundConfiguration.ConnectionSettings, boundConfiguration.DisableDirectStreaming);
					}
				}

				var prepareRequestMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000);
				if (prepareRequestMs > OpenTelemetry.MinimumMillisecondsToEmitTimingSpanAttribute &&
					OpenTelemetry.CurrentSpanIsElasticTransportOwnedHasListenersAndAllDataRequested)
					Activity.Current?.SetTag(OpenTelemetryAttributes.ElasticTransportPrepareRequestMs, prepareRequestMs);

				//http://msdn.microsoft.com/en-us/library/system.net.httpwebresponse.getresponsestream.aspx
				//Either the stream or the response object needs to be closed but not both, although it won't
				//throw any errors if both are closed; at least one of them has to be closed.
				//Since we expose the stream we let closing the stream determine when to close the connection.

				if (boundConfiguration.EnableTcpStats)
					tcpStats = TcpStats.GetStates();

				if (boundConfiguration.EnableThreadPoolStats)
					threadPoolStats = ThreadPoolStats.GetStats();

				HttpWebResponse httpWebResponse;
				if (isAsync)
				{
					var apmGetResponseTask = Task.Factory.FromAsync(request.BeginGetResponse, r => request.EndGetResponse(r), null);
					unregisterWaitHandle = RegisterApmTaskTimeout(apmGetResponseTask, request, boundConfiguration);
					httpWebResponse = (HttpWebResponse)await apmGetResponseTask.ConfigureAwait(false);
				}
				else
				{
					httpWebResponse = (HttpWebResponse)request.GetResponse();
				}

				receivedResponse = httpWebResponse;
				CaptureResponse(boundConfiguration, httpWebResponse, out statusCode, out responseStream, out contentType, out contentLength, out responseHeaders);
			}
		}
		catch (WebException e)
		{
			ex = e;
			// Non-success status codes surface as a WebException that still carries the server's response;
			// capture it exactly like a successful one so headers/length are available and the response is released.
			if (e.Response is HttpWebResponse httpWebResponse)
			{
				receivedResponse = httpWebResponse;
				CaptureResponse(boundConfiguration, httpWebResponse, out statusCode, out responseStream, out contentType, out contentLength, out responseHeaders);
			}
		}
		finally
		{
			unregisterWaitHandle?.Invoke();
		}

		try
		{
			TResponse response;

			if (isAsync)
				response = await ResponseFactory.CreateAsync<TResponse>
						(endpoint, boundConfiguration, postData, ex, statusCode, responseHeaders, responseStream, contentType, contentLength, threadPoolStats, tcpStats, cancellationToken)
					.ConfigureAwait(false);
			else
				response = ResponseFactory.Create<TResponse>
					(endpoint, boundConfiguration, postData, ex, statusCode, responseHeaders, responseStream, contentType, contentLength, threadPoolStats, tcpStats);

			// Unless indicated otherwise by the TransportResponse, we've now handled the response stream, so we can dispose of the HttpWebResponse
			// to release the connection. In cases where the derived response works directly on the stream, it can be left open and additional IDisposable
			// resources can be linked such that their disposal is deferred.
			if (response.LeaveOpen)
			{
				response.LinkedDisposables = [receivedResponse, responseStream];
			}
			else
			{
				responseStream?.Dispose();
				receivedResponse?.Dispose();
			}

			RequestInvokerHelpers.SetOtelAttributes(boundConfiguration, response);

			return response;
		}
		catch
		{
			// if there's an exception, ensure we always release the stream and response so that the connection is freed.
			responseStream?.Dispose();
			receivedResponse?.Dispose();
			throw;
		}
	}

	/// <summary>
	/// Reads headers, content length, status code, content type and body stream from <paramref name="response"/>.
	/// Headers and content length are read first because <see cref="HandleResponse"/> disposes responses without a body.
	/// </summary>
	private static void CaptureResponse(BoundConfiguration boundConfiguration, HttpWebResponse response, out int? statusCode, out Stream responseStream,
		out string contentType, out long contentLength, out Dictionary<string, IEnumerable<string>> responseHeaders)
	{
		responseHeaders = ParseHeaders(boundConfiguration, response, null);
		contentLength = response.ContentLength;
		HandleResponse(response, out statusCode, out responseStream, out contentType);
	}

	/// <summary>
	/// Collects the response headers requested by the product registration's default headers, by
	/// <c>ParseAllHeaders</c>, or by <c>ResponseHeadersToParse</c> on <paramref name="boundConfiguration"/>.
	/// </summary>
	/// <param name="boundConfiguration">Supplies which headers to collect.</param>
	/// <param name="responseMessage">The response whose headers are read.</param>
	/// <param name="responseHeaders">An existing dictionary to add to, or <c>null</c> to allocate one on demand.</param>
	/// <returns>The collected headers, or <c>null</c> when nothing was collected and none was passed in.</returns>
	private static Dictionary<string, IEnumerable<string>> ParseHeaders(BoundConfiguration boundConfiguration, HttpWebResponse responseMessage,
		Dictionary<string, IEnumerable<string>> responseHeaders)
	{
		// Short-circuit so Headers is never touched when the response does not support headers.
		if (!responseMessage.SupportsHeaders || !responseMessage.Headers.HasKeys())
			return responseHeaders;

		var headers = responseMessage.Headers;

		// A header may be requested by more than one source (e.g. a product default AND ParseAllHeaders),
		// so entries are set via the indexer instead of Add, which would throw on the duplicate key.
		// HTTP header names are case-insensitive, hence the OrdinalIgnoreCase comparer.
		void Set(string name, string[] values)
		{
			responseHeaders ??= new Dictionary<string, IEnumerable<string>>(StringComparer.OrdinalIgnoreCase);
			responseHeaders[name] = values;
		}

		var defaultHeadersForProduct = boundConfiguration.ConnectionSettings.ProductRegistration.DefaultHeadersToParse();
		foreach (var headerToParse in defaultHeadersForProduct)
		{
			if (headers.AllKeys.Contains(headerToParse, StringComparer.OrdinalIgnoreCase))
				Set(headerToParse, headers.GetValues(headerToParse));
		}

		if (boundConfiguration.ParseAllHeaders)
		{
			foreach (var key in headers.AllKeys)
				Set(key, headers.GetValues(key));
		}
		else if (boundConfiguration.ResponseHeadersToParse is { Count: > 0 })
		{
			foreach (var headerToParse in boundConfiguration.ResponseHeadersToParse)
			{
				if (headers.AllKeys.Contains(headerToParse, StringComparer.OrdinalIgnoreCase))
					Set(headerToParse, headers.GetValues(headerToParse));
			}
		}

		return responseHeaders;
	}

	/// <summary>
	/// Allows subclasses to modify the <see cref="HttpWebRequest"/> instance that is going to be used for the API call
	/// </summary>
	/// <param name="endpoint">An instance of <see cref="Endpoint"/> describing where to call out to</param>
	/// <param name="boundConfiguration">An instance of <see cref="BoundConfiguration"/> describing how to call out to</param>
	/// <param name="postData">Optional data to send over the wire</param>
	/// <param name="isAsync">Whether the request will be sent asynchronously; forwarded to meta header producers.</param>
	protected virtual HttpWebRequest CreateHttpWebRequest(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData, bool isAsync)
	{
		var request = CreateWebRequest(endpoint, boundConfiguration, postData, isAsync);
		SetAuthenticationIfNeeded(endpoint, boundConfiguration, request);
		SetProxyIfNeeded(request, boundConfiguration);
		SetServerCertificateValidationCallBackIfNeeded(request, boundConfiguration);
		SetClientCertificates(request, boundConfiguration);
		AlterServicePoint(request.ServicePoint, boundConfiguration);
		return request;
	}

	/// <summary> Hook for subclasses to set additional client certificates on <paramref name="request"/> </summary>
	protected virtual void SetClientCertificates(HttpWebRequest request, BoundConfiguration boundConfiguration)
	{
		if (boundConfiguration.ClientCertificates != null)
			request.ClientCertificates.AddRange(boundConfiguration.ClientCertificates);
	}

	/// <summary> Hook for subclasses override the certificate validation on <paramref name="request"/> </summary>
	/// <remarks>
	/// A callback already present on <paramref name="request"/> (e.g. set by a subclass) is never replaced.
	/// Otherwise the configured ServerCertificateValidationCallback is used, and failing that a validator that
	/// accepts the connection only when a certificate in the chain (or the leaf) matches CertificateFingerprint.
	/// </remarks>
	protected virtual void SetServerCertificateValidationCallBackIfNeeded(HttpWebRequest request, BoundConfiguration boundConfiguration)
	{
		var callback = boundConfiguration?.ConnectionSettings?.ServerCertificateValidationCallback;
#if !__MonoCS__
		// Only assign if a subclass has not already set one.
		if (request.ServerCertificateValidationCallback != null)
			return;

		if (callback != null)
		{
			request.ServerCertificateValidationCallback = new RemoteCertificateValidationCallback(callback);
			return;
		}

		var certificateFingerprint = boundConfiguration?.ConnectionSettings?.CertificateFingerprint;
		if (string.IsNullOrEmpty(certificateFingerprint))
			return;

		request.ServerCertificateValidationCallback = new RemoteCertificateValidationCallback((sender, certificate, chain, policyErrors) =>
		{
			if (certificate is null && chain is null)
				return false;

			// The "cleaned", expected fingerprint is cached to avoid repeated cost of converting it to a comparable form.
			_expectedCertificateFingerprint ??= CertificateHelpers.ComparableFingerprint(certificateFingerprint);

			// If there is a chain, check each certificate up to the root
			if (chain is not null)
			{
				foreach (var element in chain.ChainElements)
				{
					if (CertificateHelpers.ValidateCertificateFingerprint(element.Certificate, _expectedCertificateFingerprint))
						return true;
				}
			}

			// Otherwise, check the certificate
			return CertificateHelpers.ValidateCertificateFingerprint(certificate, _expectedCertificateFingerprint);
		});
#else
		if (callback != null)
			throw new Exception("Mono misses ServerCertificateValidationCallback on HttpWebRequest");
#endif
	}

	/// <summary>
	/// Creates the <see cref="HttpWebRequest"/> for <paramref name="endpoint"/> and applies content negotiation,
	/// pipelining, chunking, compression, user agent, run-as, custom/meta headers, timeouts and method.
	/// </summary>
	private static HttpWebRequest CreateWebRequest(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData, bool isAsync)
	{
		var request = (HttpWebRequest)WebRequest.Create(endpoint.Uri);

		request.Accept = boundConfiguration.Accept;
		request.ContentType = boundConfiguration.ContentType;
#if NETFRAMEWORK
		// on netstandard/netcoreapp2.0 this throws argument exception
		request.MaximumResponseHeadersLength = -1;
#endif
		request.Pipelined = boundConfiguration.HttpPipeliningEnabled;

		if (boundConfiguration.TransferEncodingChunked)
			request.SendChunked = true;

		if (boundConfiguration.HttpCompression)
		{
			request.AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate;
			request.Headers.Add("Accept-Encoding", "gzip,deflate");
			request.Headers.Add("Content-Encoding", "gzip");
		}

		var userAgent = boundConfiguration.UserAgent?.ToString();
		if (!string.IsNullOrWhiteSpace(userAgent))
			request.UserAgent = userAgent;

		if (!string.IsNullOrWhiteSpace(boundConfiguration.RunAs))
			request.Headers.Add(BoundConfiguration.RunAsSecurityHeader, boundConfiguration.RunAs);

		if (boundConfiguration.Headers != null && boundConfiguration.Headers.HasKeys())
			request.Headers.Add(boundConfiguration.Headers);

		if (boundConfiguration.MetaHeaderProvider is not null)
		{
			foreach (var producer in boundConfiguration.MetaHeaderProvider.Producers)
			{
				var value = producer.ProduceHeaderValue(boundConfiguration, isAsync);

				if (!string.IsNullOrEmpty(value))
					request.Headers.Add(producer.HeaderName, value);
			}
		}

		var timeout = (int)boundConfiguration.RequestTimeout.TotalMilliseconds;
		request.Timeout = timeout;
		request.ReadWriteTimeout = timeout;

		//WebRequest won't send Content-Length: 0 for empty bodies
		//which goes against RFC's and might break i.e IIS when used as a proxy.
		//see: https://github.com/elastic/elasticsearch-net/issues/562
		var m = endpoint.Method.GetStringValue();
		request.Method = m;
		if (m != "HEAD" && m != "GET" && postData == null)
			request.ContentLength = 0;

		return request;
	}

	/// <summary> Hook for subclasses override <see cref="ServicePoint"/> behavior</summary>
	protected virtual void AlterServicePoint(ServicePoint requestServicePoint, BoundConfiguration boundConfiguration)
	{
		requestServicePoint.UseNagleAlgorithm = false;
		requestServicePoint.Expect100Continue = false;
		requestServicePoint.ConnectionLeaseTimeout = (int)boundConfiguration.DnsRefreshTimeout.TotalMilliseconds;
		if (boundConfiguration.ConnectionSettings.ConnectionLimit > 0)
			requestServicePoint.ConnectionLimit = boundConfiguration.ConnectionSettings.ConnectionLimit;
		//looking at http://referencesource.microsoft.com/#System/net/System/Net/ServicePoint.cs
		//this method only sets internal values and wont actually cause timers and such to be reset
		//So it should be idempotent if called with the same parameters
		requestServicePoint.SetTcpKeepAlive(true, boundConfiguration.KeepAliveTime, boundConfiguration.KeepAliveInterval);
	}

	/// <summary> Hook for subclasses to set proxy on <paramref name="request"/> </summary>
	/// <remarks> When automatic proxy detection is disabled the proxy is cleared, even if one was configured. </remarks>
	protected virtual void SetProxyIfNeeded(HttpWebRequest request, BoundConfiguration boundConfiguration)
	{
		if (!string.IsNullOrWhiteSpace(boundConfiguration.ProxyAddress))
		{
			var uri = new Uri(boundConfiguration.ProxyAddress);
			var proxy = new WebProxy(uri);
			var credentials = new NetworkCredential(boundConfiguration.ProxyUsername, boundConfiguration.ProxyPassword);
			proxy.Credentials = credentials;
			request.Proxy = proxy;
		}

		if (boundConfiguration.DisableAutomaticProxyDetection)
			request.Proxy = null!;
	}

	/// <summary> Hook for subclasses to set authentication on <paramref name="request"/></summary>
	protected virtual void SetAuthenticationIfNeeded(Endpoint endpoint, BoundConfiguration boundConfiguration, HttpWebRequest request)
	{
		// If user manually specifies an Authorization Header give it preference.
		// Header names are case-insensitive, so "authorization" must be honoured as well.
		var userAuthorizationKey = boundConfiguration.Headers is not null && boundConfiguration.Headers.HasKeys()
			? boundConfiguration.Headers.AllKeys.FirstOrDefault(k => string.Equals(k, "Authorization", StringComparison.OrdinalIgnoreCase))
			: null;

		if (userAuthorizationKey is not null)
		{
			var header = boundConfiguration.Headers[userAuthorizationKey];
			request.Headers["Authorization"] = header;
			return;
		}

		SetBasicAuthenticationIfNeeded(endpoint, boundConfiguration, request);
	}

	private static void SetBasicAuthenticationIfNeeded(Endpoint endpoint, BoundConfiguration boundConfiguration, HttpWebRequest request)
	{
		// Credentials take the following precedence (highest -> lowest):
		// 1 - Basic auth user info specified with the URI
		// 2 - The AuthenticationHeader resolved on the bound configuration (request or global TransportClientSettings level)
		string parameters = null;
		string scheme = null;

		if (!endpoint.Uri.UserInfo.IsNullOrEmpty())
		{
			parameters = BasicAuthentication.GetBase64String(Uri.UnescapeDataString(endpoint.Uri.UserInfo));
			scheme = BasicAuthentication.BasicAuthenticationScheme;
		}
		else if (boundConfiguration.AuthenticationHeader != null && boundConfiguration.AuthenticationHeader.TryGetAuthorizationParameters(out var v))
		{
			parameters = v;
			scheme = boundConfiguration.AuthenticationHeader.AuthScheme;
		}

		if (parameters.IsNullOrEmpty())
			return;

		request.Headers["Authorization"] = $"{scheme} {parameters}";
	}

	/// <summary>
	/// Registers an APM async task cancellation on the threadpool
	/// </summary>
	/// <returns>An unregister action that can be used to remove the waithandle prematurely</returns>
	private static Action RegisterApmTaskTimeout(IAsyncResult result, WebRequest request, BoundConfiguration boundConfiguration)
	{
		var waitHandle = result.AsyncWaitHandle;
		var registeredWaitHandle =
			ThreadPool.RegisterWaitForSingleObject(waitHandle, TimeoutCallback, request, boundConfiguration.RequestTimeout, true);
		// Pass null: the argument to Unregister is a handle to *signal* once unregistered, and the task's own
		// AsyncWaitHandle is owned by the task and must not be signalled by us.
		return () => registeredWaitHandle.Unregister(null);
	}

	/// <summary> Aborts the <see cref="WebRequest"/> passed as <paramref name="state"/> when the wait timed out. </summary>
	private static void TimeoutCallback(object state, bool timedOut)
	{
		if (!timedOut)
			return;

		(state as WebRequest)?.Abort();
	}

	/// <summary>
	/// Extracts the status code, body stream and content type from <paramref name="response"/>.
	/// The response is disposed immediately when it has no body stream (null or <see cref="Stream.Null"/>).
	/// </summary>
	private static void HandleResponse(HttpWebResponse response, out int? statusCode, out Stream responseStream, out string contentType)
	{
		statusCode = (int)response.StatusCode;
		responseStream = response.GetResponseStream();
		contentType = response.ContentType;
		// https://github.com/elastic/elasticsearch-net/issues/2311
		// if stream is null call dispose on response instead.
		if (responseStream == null || responseStream == Stream.Null)
			response.Dispose();
	}
}