host/cxpslib/connection.h (429 lines of code) (raw):

/// \file connection.h /// /// \brief tcp connection handling classes that can be used for both client and server side /// #ifndef CONNECTION_H #define CONNECTION_H #include <cstddef> #include <vector> #include <boost/asio.hpp> #include <boost/asio/ssl.hpp> #include <boost/function.hpp> #include <boost/shared_ptr.hpp> #include <boost/lexical_cast.hpp> #include "errorexception.h" #include "cxps.h" #include "cxpssslcontext.h" #include "setsockettimeoutsminor.h" /// \brief maximun size of read buffers #define MAX_READ_BUFFER 1024 /// \brief abstract connection for managing tcp connections class ConnectionAbc { public: /// shared pointer type for holding the actual instantiated connection object typedef boost::shared_ptr<ConnectionAbc> ptr; typedef std::vector<boost::asio::const_buffer> writeBuffer_t; // if any of these callback types fail to compile consult boost.function documentation // about possible work arounds for the failing compiler /// async read callback function type typedef boost::function<void (boost::system::error_code const & error, size_t bytesTransferred)> readHandler_t; /// async write callback function type typedef boost::function<void (boost::system::error_code const & error, size_t bytesTransferred)> writeHandler_t; /// async ssl hand shake callback function type typedef boost::function<void (boost::system::error_code const & error)> handshakeHandler_t; /// async connect callback function type typedef boost::function<void (boost::system::error_code const & error)> connectHandler_t; explicit ConnectionAbc() : m_timedOut(false), m_sslShutdownNeeded(false) {} virtual ~ConnectionAbc() { } /// \brief returns eof state /// /// This function needs to be implemented by a subclass. It should return the current /// eof state /// /// \returns /// \li true: eof /// \li false: not eof virtual bool eof() = 0; /// \brief pure virtual function that should synchronously write N bytes to the socket /// /// This function needs to be implemented by a subclass. It should block until it has /// written N bytes to the socket or an error occurs. /// /// \return /// \li \c length on succes or throws an exception on error virtual std::size_t writeN(char const * ptr, std::size_t length) = 0; /// \brief pure virtual function that should synchronously write N bytes to the socket /// /// This function needs to be implemented by a subclass. It should block until it has /// written all the buffer data to the socket or an error occurs. /// /// \return /// \li \c length on succes or throws an exception on error virtual std::size_t writeN(ConnectionAbc::writeBuffer_t const& buffer) = 0; /// \brief pure virtual function that should start an asynchronous write to the socket /// /// This function needs to be implemented by a subclass. It should start an asynchronous write /// to the socket and return immediately virtual void asyncWriteN(char const * ptr, size_t length, writeHandler_t writeHandler) = 0; /// \brief pure virtual function that should start an asynchronous write to the socket /// /// This function needs to be implemented by a subclass. It should start an asynchronous write /// to the socket and return immediately virtual void asyncWriteN(ConnectionAbc::writeBuffer_t const& buffer, writeHandler_t writeHandler) = 0; /// \brief pure virtual function that should synchronously read data from the socket until the delimitor is found /// /// This function needs to be implemented by a subclass. It should block until it has /// read 1 or more bytes until the delimotor is read from the socket or an error occurs. /// Note: there can be more data read beyon the delimitor, so the caller needs to handle /// that case /// /// \return /// \li \c bytes read on succes or throws an exception on error virtual std::size_t readUntil(boost::asio::streambuf& streamBuf, char const* delimitor) = 0; /// \brief pure virtual function that should start an asynchronously read from the socket /// /// This function needs to be implemented by a subclass. It should start an asynchronous read /// from the socket and return immediately virtual void asyncReadUntil(boost::asio::streambuf& streamBuf, char const* delimitor, readHandler_t readHandler) = 0; /// \brief pure virtual function that should synchronously read data from the socket /// /// This function needs to be implemented by a subclass. It should block until it has /// read 1 or more (up to length) bytes from the socket or an error occurs. /// /// \return /// \li \c bytes read on succes or throws an exception on error virtual std::size_t readSome(char * ptr, std::size_t length) = 0; /// \brief pure virtual function that should start an asynchronously read from the socket /// /// This function needs to be implemented by a subclass. It should start an asynchronous read /// from the socket and return immediately virtual void asyncReadSome(char * ptr, std::size_t length, readHandler_t readHandler) = 0; /// \brief pure virtual function that should start an asynchronously read from the socket /// /// This function needs to be implemented by a subclass. It should start an asynchronous read /// from the socket and return immediately virtual void asyncRead(char * ptr, std::size_t length, readHandler_t readHandler) = 0; virtual std::size_t read(char * ptr, std::size_t length) = 0; /// \brief get the lowest layer socket /// /// \return /// \li \c lowest layer socket virtual boost::asio::ip::tcp::socket::lowest_layer_type& lowestLayerSocket() = 0; /// \brief does ssl hand shake virtual void handshake(boost::asio::ssl::stream_base::handshake_type handshakeType ///< client or server type ) = 0; /// \brief does ansynchronous ssl hand shake virtual void asyncHandshake(boost::asio::ssl::stream_base::handshake_type handshakeType, ///< client or server type handshakeHandler_t asyncHandshakeHandler ///< callback when completed ) = 0; /// \brief open the socket and set window sizes as needed virtual void socketOpen(boost::asio::ip::tcp::endpoint const& endpoint, ///< peer to connect to int sendWindowSizeBytes, ///< tcp send window size to use (overrides system setting) int receiveWindowSizeBytes ///< tcp receive winow size to use (overrides system setting) ) = 0; /// \brief connects to the given endpoint virtual void connect(boost::asio::ip::tcp::endpoint const& endpoint ///< endpoint to connect to ) = 0; /// \brief connects to the given endpoint /// /// before connecting sets the provided window sizes on the socket as that needs to be done /// before connection /// /// \note /// \li \c 0 for a window size means use system default value virtual void connect(boost::asio::ip::tcp::endpoint const& endpoint, int sendWindowSizeBytes, int receiveWindowSizeBytes) = 0; /// \brief async connects to the given endpoint /// /// before connecting sets the provided window sizes on the socket as that needs to be done /// before connection /// /// \note /// \li \c 0 for a window size means use system default value virtual void asyncConnect(boost::asio::ip::tcp::endpoint const& endpoint, int sendWindowSizeBytes, int receiveWindowSizeBytes, connectHandler_t connectHandler) = 0; /// \brief disconnects from peer virtual void disconnect() = 0; /// \brief close socket /// (should only used by cfs connections to make sure socket gets closed with out /// affecting the duplicated one passed back to agent) virtual void closeSocket() = 0; /// \brief returns the ip address of the peer virtual std::string peerIpAddress() = 0; /// \brief returns the connection endpoint info formated as string virtual std::string endpointInfoAsString() = 0; /// \brief returns the ip adresos of the peer virtual void connectionEndpoints(ConnectionEndpoints& connectionEndpoints) = 0; /// \brief sets the peer ip addrees for internal tracking /// /// if using async connect, needs to be called in the connectHandler_t on successful connction /// sync connect will automatically call it so no need for callers of connect to use it virtual void setPeerIpAddress() = 0; virtual std::string getCertBiosId() = 0; virtual bool isOpen() = 0; /// \brief cancel async socket request virtual boost::system::error_code cancel() = 0; /// \brief returns true if connection has timed out else false bool isTimedOut() { return m_timedOut; } /// \brief sets connection to timed out void setTimedOut() { m_timedOut = true; } /// \brief sets connection to timed out virtual void clearTimedOut() { m_timedOut = false; } void sslHandshakeCompleted() { m_sslShutdownNeeded = true; } void sslShutdownCompleted() { m_sslShutdownNeeded = false; } bool sslShutdownNeeded() { return m_sslShutdownNeeded; } virtual bool usingSsl() = 0; boost::asio::io_service& ioService() { return (boost::asio::io_context&)lowestLayerSocket().get_executor().context(); } virtual std::string getCurrentCipherSuite() = 0; virtual bool setSocketTimeouts(int milliseconds) = 0; protected: private: bool m_timedOut; ///< for timing out connections bool m_sslShutdownNeeded; ///< detetimes if ssl shutdown should be issued }; /// \brief basic connection for managing tcp connections template<typename SOCKET> class BasicConnection : public ConnectionAbc { public: /// \brief constructor explicit BasicConnection(boost::asio::io_service& ioService) ///< io service for the connection : m_strand(ioService), m_eof(false) {} virtual ~BasicConnection() { } /// \brief returns eof state /// /// \returns /// \li true: eof /// \li false: not eof bool eof() { return m_eof; } /// \brief writes length bytes /// /// blocks until length bytes written or an error /// /// \returns /// \li \c number of bytes written (which will always be length) /// \exception boost::system::system_error on error virtual std::size_t writeN(char const * ptr, ///< points to data to write (must be at least length bytes) std::size_t length) ///< length of buffer to write { return boost::asio::write(socket(), boost::asio::buffer(ptr, length)); } /// \brief writes length bytes /// /// blocks until length bytes written or an error /// /// \returns /// \li \c number of bytes written (which will always be length) /// \exception boost::system::system_error on error virtual std::size_t writeN(ConnectionAbc::writeBuffer_t const& buffer) ///< holds data to be written { return boost::asio::write(socket(), buffer); } /// \brief synchronously read data from the socket until the delimitor is found /// /// This function needs to be implemented by a subclass. It should block until it has /// read 1 or more bytes until the delimotor is read from the socket or an error occurs. /// Note: there can be more data read beyon the delimitor, so the caller needs to handle /// that case /// /// \return /// \li \c bytes read on succes or throws an exception on error virtual std::size_t readUntil(boost::asio::streambuf& buffer, char const* delimitor) { return boost::asio::read_until(socket(), buffer, delimitor); } /// \brief start an asynchronously read from the socket /// virtual void asyncReadUntil(boost::asio::streambuf& streamBuf, char const* delimitor, readHandler_t readHandler) { async_read_until(socket(), streamBuf, delimitor, m_strand.wrap(readHandler)); } /// \brief reads some bytes from the socket /// /// blocks until at least 1 byte has been read or an error /// /// \returns /// \li \c number of bytes read (0 indicates eof) /// \exception ERROR_EXCEPTION on error virtual std::size_t readSome(char * ptr, ///< pointer to buffer to receive the read data std::size_t length) ///< length of the buffer { boost::system::error_code error; std::size_t bytesRead = socket().read_some(boost::asio::buffer(ptr, length), error); if (0 == bytesRead) { if (boost::asio::error::eof == error) { m_eof = true; } else { throw ERROR_EXCEPTION << "error reading data from socket: " << error; } } return bytesRead; } /// \brief asynchronously read some bytes from the socket /// /// read some bytes from the socket and then calls the readHandler. use this function when you do not /// know how many bytes are expected, but that there are some bytes expected. virtual void asyncReadSome(char * ptr, ///< pointer to hold the read data std::size_t length, ///< length of buffer readHandler_t readHandler) ///< function to call when the aysnc read completes { socket().async_read_some(boost::asio::buffer(ptr, length), m_strand.wrap(readHandler)); } /// \brief asynchronously read up to length bytes from the socket /// /// reads up to length bytes before calling readHandler. use this function when you know /// how many bytes are expected virtual void asyncRead(char * ptr, ///< pointer to hold the read data std::size_t length, ///< length of buffer readHandler_t readHandler) ///< function to call when the aysnc read completes { boost::asio::async_read(socket(), boost::asio::buffer(ptr, length), m_strand.wrap(readHandler)); } virtual std::size_t read(char * ptr, ///< pointer to hold the read data std::size_t length) ///< length of buffer { return boost::asio::read(socket(), boost::asio::buffer(ptr, length)); } /// \brief asynchronously write N bytes to the socket virtual void asyncWriteN(char const * ptr, ///< points to buffer holding data to write (must be at least legnth size) size_t length, ///< length of buffer to write writeHandler_t writeHandler) ///< function to call when the async write completes { boost::asio::async_write(socket(), boost::asio::buffer(ptr, length), m_strand.wrap(writeHandler)); } /// \brief asynchronously write N bytes to the socket virtual void asyncWriteN(ConnectionAbc::writeBuffer_t const& buffer, ///< points to buffer holding data to write (must be at least legnth size) writeHandler_t writeHandler) ///< function to call when the async write completes { if (buffer.size() > 0) { boost::asio::async_write(socket(), buffer, m_strand.wrap(writeHandler)); } } /// \brief get a reference to the lowest layer socket virtual boost::asio::ip::tcp::socket::lowest_layer_type& lowestLayerSocket() = 0; /// \brief perform an ssl handshake virtual void handshake(boost::asio::ssl::stream_base::handshake_type handshakeType ///< client or server type ) = 0; /// \brief perform an asynchronous ssl handshake virtual void asyncHandshake(boost::asio::ssl::stream_base::handshake_type handshakeType, ///< client or server type handshakeHandler_t asyncHandshakeHandler ///< callback when completed ) = 0; /// \brief get a copy of the peer ip address as a string virtual std::string peerIpAddress() { return m_peerIpAddress; } /// \brief get the connection endpoint info as a string virtual std::string endpointInfoAsString() { return m_endpointInfoAsString; } virtual void connectionEndpoints(ConnectionEndpoints& connectionEndpoints) { connectionEndpoints.m_remoteIpAddress = lowestLayerSocket().remote_endpoint().address().to_string(); connectionEndpoints.m_remotePort = lowestLayerSocket().remote_endpoint().port(); connectionEndpoints.m_localIpAddress = lowestLayerSocket().local_endpoint().address().to_string(); connectionEndpoints.m_localPort = lowestLayerSocket().local_endpoint().port(); } virtual void setPeerIpAddress() { if (lowestLayerSocket().is_open()) { // can be open but not connected, remote_endpoint will throw if not connected try { m_peerIpAddress = lowestLayerSocket().remote_endpoint().address().to_string(); m_endpointInfoAsString = "local ip: "; m_endpointInfoAsString += lowestLayerSocket().local_endpoint().address().to_string(); m_endpointInfoAsString += ", local port: "; m_endpointInfoAsString += boost::lexical_cast<std::string>(lowestLayerSocket().local_endpoint().port()); m_endpointInfoAsString += ", remote ip: "; m_endpointInfoAsString += lowestLayerSocket().remote_endpoint().address().to_string(); m_endpointInfoAsString += ", remote port: "; m_endpointInfoAsString += boost::lexical_cast<std::string>(lowestLayerSocket().remote_endpoint().port()); m_endpointInfoAsString += " "; } catch (...) { } } } virtual std::string getCertBiosId() { return std::string(); } virtual bool isOpen() { return lowestLayerSocket().is_open(); } /// \brief disconnects from peer virtual void disconnect() { boost::system::error_code error; if (lowestLayerSocket().is_open()) { lowestLayerSocket().shutdown(boost::asio::ip::tcp::socket::shutdown_both, error); lowestLayerSocket().close(); } m_eof = false; } /// \brief disconnects from peer without calling shutdown to avoid issues when socket is passed between prosesses virtual void closeSocket() { boost::system::error_code error; if (lowestLayerSocket().is_open()) { lowestLayerSocket().close(); } m_eof = false; } /// \brief cancel async socket request virtual boost::system::error_code cancel() { boost::system::error_code ec; lowestLayerSocket().cancel(ec); return ec; } /// \brief open the socket and set window sizes as needed void socketOpen(boost::asio::ip::tcp::endpoint const& endpoint, ///< peer to connect to int sendWindowSizeBytes, ///< tcp send window size to use (overrides system setting) int receiveWindowSizeBytes) ///< tcp receive winow size to use (overrides system setting) { if (!lowestLayerSocket().is_open()) { lowestLayerSocket().open(endpoint.protocol()); } if (0 != receiveWindowSizeBytes) { lowestLayerSocket().set_option(boost::asio::socket_base::receive_buffer_size(receiveWindowSizeBytes)); } if (0 != sendWindowSizeBytes) { lowestLayerSocket().set_option(boost::asio::socket_base::send_buffer_size(sendWindowSizeBytes)); } } /// \brief connect to the given endpoint virtual void connect(boost::asio::ip::tcp::endpoint const& endpoint) ///< peer to connect to { connect(endpoint, 0, 0); } /// \brief connects to the given endpoint /// /// before connecting sets the provided window sizes on the socket as that needs to be done /// before connection /// /// \note /// \li \c 0 for a window size means use system default value virtual void connect(boost::asio::ip::tcp::endpoint const& endpoint, ///< peer to connect to int sendWindowSizeBytes, ///< tcp send window size to use (overrides system setting) int receiveWindowSizeBytes) ///< tcp receive window size to use (overrides system setting) { socketOpen(endpoint, sendWindowSizeBytes, receiveWindowSizeBytes); lowestLayerSocket().connect(endpoint); setPeerIpAddress(); } /// \brief async connects to the given endpoint /// /// before connecting sets the provided window sizes on the socket as that needs to be done /// before connection /// /// \note /// \li \c 0 for a window size means use system default value virtual void asyncConnect(boost::asio::ip::tcp::endpoint const& endpoint, int sendWindowSizeBytes, int receiveWindowSizeBytes, connectHandler_t connectHandler) { socketOpen(endpoint, sendWindowSizeBytes, receiveWindowSizeBytes); lowestLayerSocket().async_connect(endpoint, connectHandler); } virtual std::string getCurrentCipherSuite() { return std::string("Cipher: non ssl no cipher suite in use"); } virtual bool setSocketTimeouts(int milliseconds) { return setSocketTimeoutOptions(lowestLayerSocket().native_handle(), milliseconds); } protected: /// \brief get a reference to the socket virtual SOCKET& socket() = 0; private: boost::asio::io_service::strand m_strand; ///< used to protect agains bool m_eof; ///< eof indicator true: eof, false: not eof std::string m_peerIpAddress; ///< holds peer ip address to allow logging it even after the connection is no longer valid std::string m_endpointInfoAsString; ///< holds endpoint info formated for printing (ip: local ip, port: local port , peer ip: peer ip, peer port: peer port) }; /// \brief non-ssl socket type used when creating a Connection typedef boost::asio::ip::tcp::socket socket_t; /// \brief managing tcp connections that do not use ssl class Connection : public BasicConnection<socket_t> { public: explicit Connection(boost::asio::io_service& ioService) ///< io service for the connection : BasicConnection<socket_t>(ioService), m_socket(ioService) { } virtual ~Connection() { disconnect(); } /// \brief does nothing for non-ssl connections virtual void handshake(boost::asio::ssl::stream_base::handshake_type handshakeType) ///< client or server type { // nothing needed for non ssl } /// \brief does nothing for non-ssl connections virtual void asyncHandshake(boost::asio::ssl::stream_base::handshake_type handshakeType, ///< client or server type handshakeHandler_t asyncHandshakeHandler) ///< callback when completed { // nothing needed for non ssl } /// \brief get a reference to the lowest layer socket virtual socket_t::lowest_layer_type& lowestLayerSocket() { return socket().lowest_layer(); } virtual bool usingSsl() { return false; } protected: /// \brief get a reference to the socket virtual socket_t& socket() { return m_socket; } private: socket_t m_socket; //< holds the non-ssl socket object }; /// \brief managing tcp connections using openssl class SslConnection : public BasicConnection<sslSocket_t> { public: /// async ssl disconnect callback function type typedef boost::function<void (boost::system::error_code const & error)> sslShutdownHandler_t; #if 0 explicit SslConnection(boost::asio::io_service& ioService, ///< io service to use boost::asio::ssl::context& sslContext) ///< ssl context information (e.g. pem files) : BasicConnection<sslSocket_t>(ioService), m_socket(ioService, sslContext) { } #endif explicit SslConnection(boost::asio::io_service& ioService, std::string const& certFile, std::string const& keyFile, std::string const& dhFile, std::string const& passphrase, std::string const& caCertThumbPrint) : BasicConnection<sslSocket_t>(ioService), m_sslContext(ioService, certFile, keyFile, dhFile, passphrase, caCertThumbPrint), m_socket(ioService, m_sslContext.context()) { } explicit SslConnection(boost::asio::io_service& ioService, std::string const& clientFile) : BasicConnection<sslSocket_t>(ioService), m_sslContext(ioService, clientFile), m_socket(ioService, m_sslContext.context()) { } explicit SslConnection(boost::asio::io_service& ioService, std::string const& certFile, std::string const& keyFile, std::string const& serverCertThumbprint) : BasicConnection<sslSocket_t>(ioService), m_sslContext(ioService, certFile, keyFile, serverCertThumbprint), m_socket(ioService, m_sslContext.context()) { } virtual ~SslConnection() { disconnect(); } /// \brief perform ssl handshake virtual void handshake(boost::asio::ssl::stream_base::handshake_type handshakeType) ///< client or server type { m_socket.handshake(handshakeType); } /// \brief perform asynchronous ssl handshake virtual void asyncHandshake(boost::asio::ssl::stream_base::handshake_type handshakeType, ///< client or server type handshakeHandler_t asyncHandshakeHandler) ///< callback when completed { m_socket.async_handshake(handshakeType, asyncHandshakeHandler); } /// \brief get a reference to the lowest layer socket virtual sslSocket_t::lowest_layer_type& lowestLayerSocket() { return socket().lowest_layer(); } /// \brief connects to the given endpoint /// /// before connecting sets the provided window sizes on the socket as that needs to be done /// before connection /// /// \note /// \li \c 0 for a window size means use system default value virtual void sslHandshake() { handshake(boost::asio::ssl::stream_base::client); } /// \brief connects to the given endpoint /// /// before connecting sets the provided window sizes on the socket as that needs to be done /// before connection /// /// \note /// \li \c 0 for a window size means use system default value virtual void asyncSslHandshake(handshakeHandler_t asyncHandshakeHandler) ///< callback when complete { asyncHandshake(boost::asio::ssl::stream_base::client, asyncHandshakeHandler); } /// \brief async ssl shutdown /// /// \return /// \li \c true if the async hand shake was issued /// \li \c fals if the async hand shake was not issued (because socket was not opened bool asyncSslShutdown(sslShutdownHandler_t asyncSslShutdownHandler) ///< callback when complete { if (m_socket.lowest_layer().is_open() && sslShutdownNeeded()) { m_socket.async_shutdown(asyncSslShutdownHandler); return true; } return false; } /// \brief sync ssl shutdown /// /// \return /// \li \c true if the async hand shake was issued /// \li \c fals if the async hand shake was not issued (because socket was not opened bool sslShutdown() { if (m_socket.lowest_layer().is_open() && sslShutdownNeeded()) { m_socket.shutdown(); return true; } return false; } virtual bool usingSsl() { return true; } /// \brief get a reference to the socket virtual sslSocket_t& socket() { return m_socket; } bool verifyFingerprint() { return m_sslContext.verifyFingerprint(m_socket.native_handle()); } std::string getFingerprint() { return m_sslContext.getFingerprint(m_socket.native_handle()); } std::string getCertificate() { return m_sslContext.getCertificate(m_socket.native_handle()); } virtual std::string getCurrentCipherSuite() { return m_sslContext.getCurrentCipherSuite(m_socket.native_handle()); } virtual std::string getCertBiosId() { return m_sslContext.getCertBiosId(); } protected: private: // must alway be before m_socket CxpsSslContext m_sslContext; ///< ssl context for this connection sslSocket_t m_socket; ///< holds the ssl socket object }; #endif // CONNECTION_H