src/Transport/Tcp/TcpTransportFactory.cs (347 lines of code) (raw):

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

using System;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Net;
using System.Net.Sockets;
using Apache.NMS.ActiveMQ.Util;
using Apache.NMS.ActiveMQ.OpenWire;
using Apache.NMS.Util;

namespace Apache.NMS.ActiveMQ.Transport.Tcp
{
    /// <summary>
    /// Transport factory registered under the "tcp" name.  It opens a TCP socket to
    /// the broker named by a Uri and wraps it in the chain of transport filters
    /// built by <see cref="CompositeConnect"/> and <see cref="CreateTransport"/>.
    /// </summary>
    [ActiveMQTransportFactory("tcp")]
    public class TcpTransportFactory : ITransportFactory
    {
        public TcpTransportFactory()
        {
        }

        #region Properties

        private bool useLogging = false;

        /// <summary>
        /// When true, the created transport is wrapped in a LoggingTransport.
        /// </summary>
        public bool UseLogging
        {
            get { return useLogging; }
            set { useLogging = value; }
        }

        private bool useInactivityMonitor = true;

        /// <summary>
        /// Should the Inactivity Monitor be enabled on this Transport.
        /// </summary>
        public bool UseInactivityMonitor
        {
            get { return this.useInactivityMonitor; }
            set { this.useInactivityMonitor = value; }
        }

        private int receiveBufferSize = 8192;

        /// <summary>
        /// Size in bytes of the receive buffer.
        /// </summary>
        public int ReceiveBufferSize
        {
            get { return receiveBufferSize; }
            set { receiveBufferSize = value; }
        }

        private int sendBufferSize = 8192;

        /// <summary>
        /// Size in bytes of send buffer.
        /// </summary>
        public int SendBufferSize
        {
            get { return sendBufferSize; }
            set { sendBufferSize = value; }
        }

        private int receiveTimeout = 0;

        /// <summary>
        /// The time-out value, in milliseconds. The default value is 0, which indicates
        /// an infinite time-out period. Specifying -1 also indicates an infinite time-out period.
        /// </summary>
        public int ReceiveTimeout
        {
            get { return receiveTimeout; }
            set { receiveTimeout = value; }
        }

        private int sendTimeout = 0;

        /// <summary>
        /// The time-out value, in milliseconds. If you set the property with a value between 1 and 499,
        /// the value will be changed to 500. The default value is 0, which indicates an infinite
        /// time-out period. Specifying -1 also indicates an infinite time-out period.
        /// </summary>
        public int SendTimeout
        {
            get { return sendTimeout; }
            set { sendTimeout = value; }
        }

        private int connectTimeout = 30000;

        /// <summary>
        /// How long, in milliseconds, a single connection attempt waits for the
        /// socket to connect before the attempt is abandoned.
        /// </summary>
        public int ConnectTimeout
        {
            get { return connectTimeout; }
            set { this.connectTimeout = value; }
        }

        #endregion

        #region ITransportFactory Members

        /// <summary>
        /// Connects a socket to the broker identified by <paramref name="location"/> and
        /// returns it wrapped as: base transport (from <see cref="DoCreateTransport"/>),
        /// optionally a LoggingTransport, optionally an InactivityMonitor, and finally a
        /// WireFormatNegotiator.
        /// </summary>
        /// <param name="location">
        /// Broker Uri.  Query options prefixed with "transport." are applied to this
        /// factory and those prefixed with "wireFormat." to the wire format info.  A
        /// path of the form "/localAddress:localPort" requests a local bind.
        /// </param>
        /// <returns>The connected transport chain.</returns>
        public ITransport CompositeConnect(Uri location)
        {
            // Extract query parameters from broker Uri
            StringDictionary map = URISupport.ParseQuery(location.Query);

            // Set transport. properties on this (the factory)
            URISupport.SetProperties(this, map, "transport.");

            // See if there is a local address and port specified
            string localAddress;
            int localPort;
            TryParseLocalBinding(location, out localAddress, out localPort);

            Tracer.Debug("Opening socket to: " + location.Host + " on port: " + location.Port);
            Socket socket = DoConnect(location.Host, location.Port, localAddress, localPort);

#if !NETCF
            socket.ReceiveBufferSize = ReceiveBufferSize;
            socket.SendBufferSize = SendBufferSize;
            socket.ReceiveTimeout = ReceiveTimeout;
            socket.SendTimeout = SendTimeout;
#endif

            OpenWireFormat wireformat = new OpenWireFormat();
            // Set wireformat. properties on the wireformat owned by the tcpTransport
            URISupport.SetProperties(wireformat.PreferredWireFormatInfo, map, "wireFormat.");
            ITransport transport = DoCreateTransport(location, socket, wireformat);

            wireformat.Transport = transport;

            if(UseLogging)
            {
                transport = new LoggingTransport(transport);
            }

            if(UseInactivityMonitor)
            {
                transport = new InactivityMonitor(transport);
            }

            transport = new WireFormatNegotiator(transport, wireformat);

            return transport;
        }

        /// <summary>
        /// Builds the transport from <see cref="CompositeConnect"/> and wraps it in a
        /// MutexTransport and then a ResponseCorrelator.
        /// </summary>
        /// <param name="location">Broker Uri; see <see cref="CompositeConnect"/>.</param>
        /// <returns>The fully wrapped transport.</returns>
        public ITransport CreateTransport(Uri location)
        {
            ITransport transport = CompositeConnect(location);

            transport = new MutexTransport(transport);
            transport = new ResponseCorrelator(transport);

            return transport;
        }

        #endregion

        /// <summary>
        /// Extracts the optional local bind address and port from the path portion of
        /// the broker Uri, which is expected to look like "/address:port".
        /// </summary>
        /// <param name="location">The broker Uri.</param>
        /// <param name="localAddress">Parsed local address, or null when absent or invalid.</param>
        /// <param name="localPort">Parsed local port, or -1 when absent or invalid.</param>
        /// <returns>True when a local binding was parsed successfully.</returns>
        private static bool TryParseLocalBinding(Uri location, out string localAddress, out int localPort)
        {
            localAddress = null;
            localPort = -1;

            string path = location.AbsolutePath;
            if(String.IsNullOrEmpty(path) || path.Equals("/"))
            {
                // No local bind option on the Uri.
                return false;
            }

            try
            {
                int index = path.IndexOf(':');

                // Parse as Int32: TCP ports range up to 65535, which does not fit in an Int16.
                int port = Int32.Parse(path.Substring(index + 1), System.Globalization.CultureInfo.InvariantCulture);
                string address = path.Substring(1, index - 1);

                if(port >= IPEndPoint.MinPort && port <= IPEndPoint.MaxPort)
                {
                    localAddress = address;
                    localPort = port;
                    Tracer.DebugFormat("Binding Socket to {0} on port: {1}", localAddress, localPort);
                    return true;
                }
            }
            catch
            {
                // Malformed path (missing ':' or non-numeric port); reported below.
            }

            localAddress = null;
            localPort = -1;
            Tracer.Warn("Invalid Port value on URI for local bind option, ignoring.");
            return false;
        }

        /// <summary>
        /// Override in a subclass to create the specific type of transport that is
        /// being implemented.
        /// </summary>
        protected virtual ITransport DoCreateTransport(Uri location, Socket socket, IWireFormat wireFormat)
        {
            TcpTransport transport = new TcpTransport(location, socket, wireFormat);

            // Apply the buffer sizes to the transport also so that it can buffer above the
            // TCP level which can eagerly send causing sparse packets.
            transport.SendBufferSize = SendBufferSize;
            transport.ReceiveBufferSize = ReceiveBufferSize;

            return transport;
        }

        // DISCUSSION: Caching host entries may not be the best strategy when using the
        // failover protocol.  The failover protocol needs to be very dynamic when looking
        // up hostnames at runtime.  If old hostname->IP mappings are kept around, this may
        // lead to runtime failures that could have been avoided by dynamically looking up
        // the new hostname IP.
#if CACHE_HOSTENTRIES
        private static IDictionary<string, IPHostEntry> CachedIPHostEntries = new Dictionary<string, IPHostEntry>();
        private static readonly object _syncLock = new object();
#endif

        /// <summary>
        /// Resolves <paramref name="host"/> through DNS.
        /// </summary>
        /// <param name="host">Host name to resolve.</param>
        /// <returns>The resolved host entry, or null when the lookup throws.</returns>
        public static IPHostEntry GetIPHostEntry(string host)
        {
            IPHostEntry ipEntry;

#if CACHE_HOSTENTRIES
            string hostUpperName = host.ToUpper();

            lock (_syncLock)
            {
                if (!CachedIPHostEntries.TryGetValue(hostUpperName, out ipEntry))
                {
                    try
                    {
                        ipEntry = Dns.GetHostEntry(hostUpperName);
                        CachedIPHostEntries.Add(hostUpperName, ipEntry);
                    }
                    catch
                    {
                        ipEntry = null;
                    }
                }
            }
#else
            try
            {
                ipEntry = Dns.GetHostEntry(host);
            }
            catch
            {
                // A failed lookup is reported to the caller as null.
                ipEntry = null;
            }
#endif

            return ipEntry;
        }

        /// <summary>
        /// Makes one connection attempt to <paramref name="address"/>:<paramref name="port"/>,
        /// optionally binding to a local address first, and waits up to
        /// <see cref="ConnectTimeout"/> milliseconds for it to complete.
        /// </summary>
        /// <returns>
        /// The connected socket, or null when the address is null, the attempt failed
        /// or it timed out.  A socket that did not connect is always closed.
        /// </returns>
        private Socket TryConnectSocket(IPAddress address, int port, string localAddress, int localPort)
        {
            if(null == address)
            {
                return null;
            }

            Socket socket = null;

            try
            {
                socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);

                if(!String.IsNullOrEmpty(localAddress))
                {
                    DoBind(socket, localAddress, localPort);
                }

                IAsyncResult result = socket.BeginConnect(new IPEndPoint(address, port), null, null);
                result.AsyncWaitHandle.WaitOne(ConnectTimeout, true);

                if(socket.Connected)
                {
                    // Complete the asynchronous request so that its resources are released.
                    socket.EndConnect(result);
                    return socket;
                }
            }
            catch(Exception ex)
            {
                // Best effort: the caller moves on to the next candidate address.
                Tracer.DebugFormat("Connection attempt to {0}:{1} failed: {2}", address, port, ex.Message);
            }

            // Not connected (failure or timeout): make sure the socket is not leaked.
            if(null != socket)
            {
                socket.Close();
            }

            return null;
        }

        /// <summary>
        /// Attempts to interpret <paramref name="host"/> as a literal IP address.
        /// </summary>
        /// <returns>True when <paramref name="ipaddress"/> was set to a parsed address.</returns>
        public static bool TryParseIPAddress(string host, out IPAddress ipaddress)
        {
#if !NETCF
            return IPAddress.TryParse(host, out ipaddress);
#else
            try
            {
                ipaddress = IPAddress.Parse(host);
            }
            catch
            {
                ipaddress = null;
            }

            return (null != ipaddress);
#endif
        }

        /// <summary>
        /// Resolves <paramref name="hostname"/> and returns its first address of the
        /// requested family, or null when resolution fails or no such address exists.
        /// </summary>
        public static IPAddress GetIPAddress(string hostname, AddressFamily addressFamily)
        {
            IPAddress ipaddress = null;
            IPHostEntry hostEntry = GetIPHostEntry(hostname);

            if(null != hostEntry)
            {
                ipaddress = GetIPAddress(hostEntry, addressFamily);
            }

            return ipaddress;
        }

        /// <summary>
        /// Returns the first address in <paramref name="hostEntry"/> whose family is
        /// <paramref name="addressFamily"/>, or null when there is none (or the entry is null).
        /// </summary>
        public static IPAddress GetIPAddress(IPHostEntry hostEntry, AddressFamily addressFamily)
        {
            if(null != hostEntry)
            {
                foreach(IPAddress address in hostEntry.AddressList)
                {
                    if(address.AddressFamily == addressFamily)
                    {
                        return address;
                    }
                }
            }

            return null;
        }

        /// <summary>
        /// Connects a socket to <paramref name="host"/>:<paramref name="port"/>.  A literal
        /// IP address is used directly; otherwise the host is resolved and the first IPv6
        /// address, then the first IPv4 address, then each address of any other family is
        /// tried in turn.
        /// </summary>
        /// <param name="host">Host name or literal IP address of the broker.</param>
        /// <param name="port">Broker port.</param>
        /// <param name="localAddress">Optional local address to bind to; null or empty for none.</param>
        /// <param name="localPort">Local port used when <paramref name="localAddress"/> is given.</param>
        /// <returns>The connected socket.</returns>
        /// <exception cref="NMSConnectionException">No connection could be established.</exception>
        protected Socket DoConnect(string host, int port, string localAddress, int localPort)
        {
            Socket socket = null;
            IPAddress ipaddress;

            try
            {
                if(TryParseIPAddress(host, out ipaddress))
                {
                    socket = TryConnectSocket(ipaddress, port, localAddress, localPort);
                }
                else
                {
                    // Looping through the AddressList allows different type of connections to be tried
                    // (IPv6, IPv4 and whatever else may be available).
                    IPHostEntry hostEntry = GetIPHostEntry(host);

                    if(null != hostEntry)
                    {
                        // Prefer IPv6 first.
                        ipaddress = GetIPAddress(hostEntry, AddressFamily.InterNetworkV6);
                        socket = TryConnectSocket(ipaddress, port, localAddress, localPort);

                        if(null == socket)
                        {
                            // Try IPv4 next.
                            ipaddress = GetIPAddress(hostEntry, AddressFamily.InterNetwork);
                            socket = TryConnectSocket(ipaddress, port, localAddress, localPort);
                        }

                        if(null == socket)
                        {
                            // Try whatever else there is.
                            foreach(IPAddress address in hostEntry.AddressList)
                            {
                                if(AddressFamily.InterNetworkV6 == address.AddressFamily
                                    || AddressFamily.InterNetwork == address.AddressFamily)
                                {
                                    // Already tried these protocols.
                                    continue;
                                }

                                // Connect to the address being iterated, not the stale
                                // IPv4 candidate left over from the previous step.
                                socket = TryConnectSocket(address, port, localAddress, localPort);
                                if(null != socket)
                                {
                                    ipaddress = address;
                                    break;
                                }
                            }
                        }
                    }
                }

                if(null == socket)
                {
                    const int RTSSL_HANDSHAKE_FAILURE = -2;

                    throw new SocketException(RTSSL_HANDSHAKE_FAILURE);
                }
            }
            catch(Exception ex)
            {
                throw new NMSConnectionException(String.Format("Error connecting to {0}:{1}.", host, port), ex);
            }

            Tracer.DebugFormat("Connected to {0}:{1} using {2} protocol.", host, port, ipaddress.AddressFamily.ToString());
            return socket;
        }

        /// <summary>
        /// Binds <paramref name="socket"/> to the local <paramref name="host"/>:<paramref name="port"/>.
        /// A literal IP address is used directly; otherwise the host is resolved and the
        /// first IPv6 address, then the first IPv4 address, then each address of any other
        /// family is tried in turn.
        /// </summary>
        /// <exception cref="NMSConnectionException">The socket could not be bound.</exception>
        protected void DoBind(Socket socket, string host, int port)
        {
            IPAddress ipaddress;

            try
            {
                if(TryParseIPAddress(host, out ipaddress))
                {
                    TryBindSocket(socket, ipaddress, port);
                }
                else
                {
                    // Looping through the AddressList allows different type of connections to be tried
                    // (IPv6, IPv4 and whatever else may be available).
                    IPHostEntry hostEntry = GetIPHostEntry(host);

                    if(null != hostEntry)
                    {
                        // Prefer IPv6 first.
                        ipaddress = GetIPAddress(hostEntry, AddressFamily.InterNetworkV6);
                        bool bound = TryBindSocket(socket, ipaddress, port);

                        if(!bound)
                        {
                            // Try IPv4 next.
                            ipaddress = GetIPAddress(hostEntry, AddressFamily.InterNetwork);
                            bound = TryBindSocket(socket, ipaddress, port);
                        }

                        if(!bound)
                        {
                            // Try whatever else there is.
                            foreach(IPAddress address in hostEntry.AddressList)
                            {
                                if(AddressFamily.InterNetworkV6 == address.AddressFamily
                                    || AddressFamily.InterNetwork == address.AddressFamily)
                                {
                                    // Already tried these protocols.
                                    continue;
                                }

                                // Bind to the address being iterated, not the stale
                                // IPv4 candidate left over from the previous step.
                                if(TryBindSocket(socket, address, port))
                                {
                                    ipaddress = address;
                                    break;
                                }
                            }
                        }
                    }
                }

                if(!socket.IsBound)
                {
                    throw new SocketException();
                }
            }
            catch(Exception ex)
            {
                throw new NMSConnectionException(String.Format("Error binding to {0}:{1}.", host, port), ex);
            }

            Tracer.DebugFormat("Bound to {0}:{1} using.", host, port);
        }

        /// <summary>
        /// Makes one attempt to bind <paramref name="socket"/> to the given local endpoint.
        /// </summary>
        /// <returns>
        /// True when the socket reports itself bound; false when either argument is
        /// null or the bind throws.
        /// </returns>
        private bool TryBindSocket(Socket socket, IPAddress address, int port)
        {
            if(null != socket && null != address)
            {
                try
                {
                    socket.Bind(new IPEndPoint(address, port));
                    if(socket.IsBound)
                    {
                        return true;
                    }
                }
                catch
                {
                    // Best effort: the caller tries the next candidate address.
                }
            }

            return false;
        }
    }
}