src/Proton.TestPeer/Network/PeerTcpClient.cs (138 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
using System;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using Apache.Qpid.Proton.Test.Driver.Utilities;
using Microsoft.Extensions.Logging;
namespace Apache.Qpid.Proton.Test.Driver.Network
{
/// <summary>
/// Client side connector used by the test peer to open a TCP (optionally TLS wrapped)
/// connection to a remote and hand it off as a <see cref="PeerTcpTransport"/>.
/// </summary>
public sealed class PeerTcpClient
{
    private readonly ILoggerFactory loggerFactory;
    private readonly ILogger<PeerTcpClient> logger;
    private readonly ProtonTestClientOptions options;

    /// <summary>
    /// Create a new peer Tcp client instance that can be used to connect to a remote.
    /// </summary>
    /// <param name="options">Socket and TLS settings applied to each connection this client creates</param>
    /// <param name="loggerFactory">Factory used to create loggers; may be null, in which case this client does not log</param>
    public PeerTcpClient(ProtonTestClientOptions options, in ILoggerFactory loggerFactory)
    {
        this.loggerFactory = loggerFactory;
        this.logger = loggerFactory?.CreateLogger<PeerTcpClient>();
        this.options = options;
    }

    /// <summary>
    /// Resolves the given address to an IPv4 address and connects to it. When SSL is enabled
    /// the given address string is used as the target host of the TLS handshake.
    /// </summary>
    /// <param name="address">An IPv4 literal or a host name to resolve</param>
    /// <param name="port">The remote port to connect to</param>
    /// <returns>A transport wrapping the connected socket and its I/O stream</returns>
    /// <exception cref="InvalidOperationException">If no IPv4 address could be found for the address</exception>
    public PeerTcpTransport Connect(string address, int port)
    {
        IPAddress ipv4Address = ResolveIPv4Address(address);
        return ConnectTo(new IPEndPoint(ipv4Address, port), address);
    }

    /// <summary>
    /// Connects to the given end point. When SSL is enabled the textual form of the end point's
    /// IP address is used as the target host of the TLS handshake.
    /// </summary>
    /// <param name="endpoint">The remote end point, must not be null</param>
    /// <returns>A transport wrapping the connected socket and its I/O stream</returns>
    public PeerTcpTransport Connect(IPEndPoint endpoint)
    {
        Statics.RequireNonNull(endpoint, "Cannot connect when the end point given is null");
        return ConnectTo(endpoint, endpoint.Address.ToString());
    }

    /// <summary>
    /// Returns the address itself when it is an IPv4 literal, otherwise the first IPv4 address
    /// returned by a DNS lookup of the address.
    /// </summary>
    private static IPAddress ResolveIPv4Address(string address)
    {
        if (IPAddress.TryParse(address, out IPAddress parsedAddress) &&
            parsedAddress.AddressFamily == AddressFamily.InterNetwork)
        {
            return parsedAddress;
        }

        foreach (IPAddress ipAddress in Dns.GetHostEntry(address).AddressList)
        {
            if (ipAddress.AddressFamily == AddressFamily.InterNetwork)
            {
                return ipAddress;
            }
        }

        throw new InvalidOperationException("Could not resolve the address into an IPV4 IP Address");
    }

    /// <summary>
    /// Creates, configures and connects the socket, optionally performs the TLS handshake, and
    /// wraps the result in a transport. Any resources created are released if a step fails.
    /// </summary>
    /// <param name="endpoint">The remote end point to connect to</param>
    /// <param name="targetHost">Host name presented to / verified against the server during TLS</param>
    private PeerTcpTransport ConnectTo(IPEndPoint endpoint, string targetHost)
    {
        // Match the socket family to the end point so IPv6 end points can be connected as well.
        Socket clientSocket = new(endpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
        Stream ioStream = null;

        try
        {
            // Buffer sizes must be in place before connecting: TCP window scaling is negotiated
            // during the handshake, so later receive buffer changes may not take effect.
            clientSocket.SendBufferSize = options.SendBufferSize;
            clientSocket.ReceiveBufferSize = options.ReceiveBufferSize;

            clientSocket.Connect(endpoint);

            clientSocket.NoDelay = options.TcpNoDelay;
            clientSocket.LingerState = new LingerOption(options.SoLinger > 0, (int)options.SoLinger);
            clientSocket.SendTimeout = (int)options.SendTimeout;
            clientSocket.ReceiveTimeout = (int)options.ReceiveTimeout;

            ioStream = new NetworkStream(clientSocket);
            if (options.SslEnabled)
            {
                ioStream = AuthenticateAsClient(ioStream, targetHost);
            }
        }
        catch (Exception)
        {
            // Nothing has been handed to a transport yet, so this method owns the cleanup.
            // NetworkStream does not own the socket by default, hence both are closed.
            CloseQuietly(ioStream);
            CloseQuietly(clientSocket);
            throw;
        }

        return new PeerTcpTransport(loggerFactory, PeerTransportRole.Client, clientSocket, ioStream);
    }

    /// <summary>
    /// Best effort release of a resource during failure cleanup; errors are logged and ignored
    /// so they do not mask the original failure being rethrown.
    /// </summary>
    private void CloseQuietly(IDisposable resource)
    {
        try
        {
            resource?.Dispose();
        }
        catch (Exception error)
        {
            logger?.LogTrace(error, "Ignoring error while releasing resources after a failed connect");
        }
    }

    /// <summary>
    /// Wraps the given stream in an <see cref="SslStream"/> and performs the client side TLS
    /// handshake. On failure the SSL stream (and with it the inner stream) is disposed.
    /// </summary>
    /// <param name="ioStream">The connected plain stream to secure</param>
    /// <param name="targetHost">Host name used for SNI and server certificate name checks</param>
    /// <returns>The authenticated SSL stream</returns>
    private Stream AuthenticateAsClient(Stream ioStream, string targetHost)
    {
        SslStream sslStream = new(ioStream, false, RemoteCertificateValidationCallback, LocalCertificateSelectionCallback);
        try
        {
            sslStream.AuthenticateAsClient(targetHost, options.ClientCertificates, options.CheckForCertificateRevocation);
        }
        catch (Exception)
        {
            sslStream.Dispose();
            throw;
        }

        return sslStream;
    }

    /// <summary>
    /// Validates the server certificate. Each reported policy error fails validation unless it is
    /// listed in the configured override flags; a name mismatch is only checked when host
    /// verification is enabled.
    /// </summary>
    private bool RemoteCertificateValidationCallback(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
    {
        if (sslPolicyErrors == SslPolicyErrors.None)
        {
            return true;
        }

        SslPolicyErrors allowed = options.AllowedSslPolicyErrorsOverride;
        bool validated = true;

        if (sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNotAvailable) &&
            !allowed.HasFlag(SslPolicyErrors.RemoteCertificateNotAvailable))
        {
            logger?.LogTrace("Server certificate authentication failed due lack of provided certificate: {0}", sslPolicyErrors);
            validated = false;
        }

        if (sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors) &&
            !allowed.HasFlag(SslPolicyErrors.RemoteCertificateChainErrors))
        {
            logger?.LogTrace("Server certificate authentication failed due certificate chain error: {0}", sslPolicyErrors);
            validated = false;
        }

        if (sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch) && options.VerifyHost &&
            !allowed.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch))
        {
            logger?.LogTrace("Server certificate authentication failed due remote certificate name mismatch: {0}", sslPolicyErrors);
            validated = false;
        }

        if (!validated)
        {
            logger?.LogDebug("Server authentication had SSL policy error(s): {0}", sslPolicyErrors);
        }

        return validated;
    }

    /// <summary>
    /// Selects the client certificate to present: the first local certificate whose issuer is in
    /// the server's acceptable issuers list, otherwise the first local certificate, or null when
    /// there are no local certificates.
    /// </summary>
    public static X509Certificate LocalCertificateSelectionCallback(object sender, string targetHost, X509CertificateCollection localCertificates, X509Certificate remoteCertificate, string[] acceptableIssuers)
    {
        if (localCertificates == null || localCertificates.Count == 0)
        {
            return null;
        }

        if (acceptableIssuers != null && acceptableIssuers.Length > 0)
        {
            foreach (X509Certificate certificate in localCertificates)
            {
                if (Array.IndexOf(acceptableIssuers, certificate.Issuer) != -1)
                {
                    return certificate;
                }
            }
        }

        // No issuer preference matched (or none was offered); fall back to the first certificate.
        return localCertificates[0];
    }
}
}