nlsCppSdk/transport/SSLconnect.cpp (324 lines of code) (raw):
/*
* Copyright 2021 Alibaba Group Holding Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "SSLconnect.h"
#include <stdio.h>
#include <string.h>
#include "connectNode.h"
#include "nlog.h"
#include "nlsGlobal.h"
#include "openssl/err.h"
#include "utility.h"
namespace AlibabaNls {
// Definition of the static SSL context member. It is a single object shared
// by all SSLconnect instances: init() creates it, sslHandshake() derives each
// per-connection SSL object from it, and destroy() frees it and resets it to
// NULL.
SSL_CTX *SSLconnect::_sslCtx = NULL;
/**
 * @brief Build an idle wrapper: no SSL object attached, retry counter at
 *        zero, error text value-initialized, and the per-object mutex
 *        created (Win32 mutex handle on MSVC builds, pthread mutex otherwise).
 */
SSLconnect::SSLconnect() : _ssl(NULL), _sslTryAgain(0), _errorMsg() {
#if !defined(_MSC_VER)
  pthread_mutex_init(&_mtxSSL, NULL);
#else
  _mtxSSL = CreateMutex(NULL, FALSE, NULL);
#endif

  LOG_DEBUG("Create SSLconnect:%p.", this);
}
/**
 * @brief Tear the wrapper down: reset the retry counter, close the SSL
 *        session if one is still attached, then release the mutex.
 *
 * sslClose() takes the mutex itself, so it has to run before the mutex is
 * destroyed.
 */
SSLconnect::~SSLconnect() {
  _sslTryAgain = 0;
  sslClose();

#if !defined(_MSC_VER)
  pthread_mutex_destroy(&_mtxSSL);
#else
  CloseHandle(_mtxSSL);
#endif

  LOG_DEBUG("SSL(%p) Destroy SSLconnect done.", this);
}
int SSLconnect::init() {
if (_sslCtx == NULL) {
_sslCtx = SSL_CTX_new(SSLv23_client_method());
if (_sslCtx == NULL) {
LOG_ERROR("SSL: couldn't create a context!");
exit(1);
}
}
SSL_CTX_set_verify(_sslCtx, SSL_VERIFY_NONE, NULL);
SSL_CTX_set_mode(_sslCtx, SSL_MODE_ENABLE_PARTIAL_WRITE |
SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
SSL_MODE_AUTO_RETRY);
LOG_DEBUG("SSLconnect::init() done.");
return Success;
}
/**
 * @brief Free the shared SSL context, if there is one, and reset the pointer
 *        to NULL. Calling it when no context exists only writes the log line.
 */
void SSLconnect::destroy() {
  if (_sslCtx != NULL) {
    SSL_CTX_free(_sslCtx);
    _sslCtx = NULL;
  }

  LOG_DEBUG("SSLconnect::destroy() done.");
}
int SSLconnect::sslHandshake(int socketFd, const char *hostname) {
// LOG_DEBUG("Begin sslHandshake.");
if (_sslCtx == NULL) {
LOG_ERROR("SSL(%p) _sslCtx has been released.", this);
return -(SslCtxEmpty);
}
MUTEX_LOCK(_mtxSSL);
int ret;
if (_ssl == NULL) {
_ssl = SSL_new(_sslCtx);
if (_ssl == NULL) {
memset(_errorMsg, 0x0, MaxSslErrorLength);
const char *SSL_new_ret = "return of SSL_new: ";
const int SSL_new_str_size = strnlen(SSL_new_ret, 24);
memcpy(_errorMsg, SSL_new_ret, SSL_new_str_size);
ERR_error_string_n(ERR_get_error(), _errorMsg + SSL_new_str_size,
MaxSslErrorLength - SSL_new_str_size - 1);
LOG_ERROR("SSL(%p) Invoke SSL_new failed:%s.", this, _errorMsg);
MUTEX_UNLOCK(_mtxSSL);
return -(SslNewFailed);
} else {
if (hostname) {
if (!SSL_set_tlsext_host_name(_ssl, hostname)) {
LOG_ERROR("Error setting SNI host name");
} else {
LOG_INFO("Set SNI %s success", hostname);
}
}
}
ret = SSL_set_fd(_ssl, socketFd);
if (ret == 0) {
memset(_errorMsg, 0x0, MaxSslErrorLength);
const char *SSL_set_fd_ret = "return of SSL_set_fd: ";
const int SSL_set_fd_str_size = strnlen(SSL_set_fd_ret, 24);
memcpy(_errorMsg, SSL_set_fd_ret, SSL_set_fd_str_size);
ERR_error_string_n(ERR_get_error(), _errorMsg + SSL_set_fd_str_size,
MaxSslErrorLength - SSL_set_fd_str_size - 1);
LOG_ERROR("SSL(%p) Invoke SSL_set_fd failed:%s.", this, _errorMsg);
MUTEX_UNLOCK(_mtxSSL);
return -(SslSetFailed);
}
SSL_set_mode(_ssl, SSL_MODE_ENABLE_PARTIAL_WRITE |
SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
SSL_MODE_AUTO_RETRY);
SSL_set_connect_state(_ssl);
} else {
// LOG_DEBUG("SSL has existed.");
}
ret = SSL_connect(_ssl);
if (ret < 0) {
int sslError = SSL_get_error(_ssl, ret);
/* err == SSL_ERROR_ZERO_RETURN
"SSL_connect: close notify received from peer" */
// sslError == SSL_ERROR_WANT_X509_LOOKUP
// SSL_ERROR_SYSCALL
if (sslError == SSL_ERROR_WANT_READ || sslError == SSL_ERROR_WANT_WRITE) {
// LOG_DEBUG("sslHandshake continue.");
MUTEX_UNLOCK(_mtxSSL);
return sslError;
} else if (sslError == SSL_ERROR_SYSCALL) {
int errno_code = utility::getLastErrorCode();
LOG_INFO("SSL(%p) SSL connect error_syscall failed, errno:%d.", this,
errno_code);
if (NLS_ERR_CONNECT_RETRIABLE(errno_code) ||
NLS_ERR_RW_RETRIABLE(errno_code)) {
MUTEX_UNLOCK(_mtxSSL);
return SSL_ERROR_WANT_READ;
} else if (errno_code == 0) {
LOG_DEBUG("SSL(%p) SSL connect syscall success.", this);
MUTEX_UNLOCK(_mtxSSL);
return Success;
} else {
MUTEX_UNLOCK(_mtxSSL);
return -(SslConnectFailed);
}
} else {
memset(_errorMsg, 0x0, MaxSslErrorLength);
const char *SSL_connect_ret = "return of SSL_connect: ";
const int SSL_connect_str_size = strnlen(SSL_connect_ret, 64);
memcpy(_errorMsg, SSL_connect_ret, SSL_connect_str_size);
ERR_error_string_n(ERR_get_error(), _errorMsg + SSL_connect_str_size,
MaxSslErrorLength - SSL_connect_str_size - 1);
LOG_ERROR("SSL(%p) SSL connect failed:%s.", this, _errorMsg);
MUTEX_UNLOCK(_mtxSSL);
this->sslClose();
return -(SslConnectFailed);
}
} else {
// LOG_DEBUG("sslHandshake success.");
MUTEX_UNLOCK(_mtxSSL);
return Success;
}
MUTEX_UNLOCK(_mtxSSL);
return Success;
}
int SSLconnect::sslWrite(const uint8_t *buffer, size_t len) {
MUTEX_LOCK(_mtxSSL);
if (_ssl == NULL) {
LOG_ERROR("SSL(%p) ssl has been closed.", this);
MUTEX_UNLOCK(_mtxSSL);
return -(SslWriteFailed);
}
int wLen = SSL_write(_ssl, (void *)buffer, (int)len);
if (wLen < 0) {
int sslError = SSL_get_error(_ssl, wLen);
int errno_code = utility::getLastErrorCode();
char sslErrMsg[MaxSslErrorLength] = {0};
const char *SSL_write_ret = "return of SSL_write: ";
const int SSL_write_str_size = strnlen(SSL_write_ret, 64);
if (sslError == SSL_ERROR_WANT_READ || sslError == SSL_ERROR_WANT_WRITE) {
LOG_DEBUG("SSL(%p) Write could not complete. Will be invoked later.",
this);
MUTEX_UNLOCK(_mtxSSL);
return 0;
} else if (sslError == SSL_ERROR_SYSCALL) {
LOG_INFO("SSL(%p) SSL_write error_syscall failed, errno:%d.", this,
errno_code);
if (NLS_ERR_CONNECT_RETRIABLE(errno_code) ||
NLS_ERR_RW_RETRIABLE(errno_code)) {
MUTEX_UNLOCK(_mtxSSL);
return 0;
} else if (errno_code == 0) {
LOG_DEBUG("SSL(%p) SSL_write syscall success.", this);
MUTEX_UNLOCK(_mtxSSL);
return 0;
#ifdef _MSC_VER
} else if (errno_code == WSAECONNRESET) {
#else
} else if (errno_code == ECONNRESET) {
#endif
memset(_errorMsg, 0x0, MaxSslErrorLength);
memcpy(sslErrMsg, SSL_write_ret, SSL_write_str_size);
ERR_error_string_n(ERR_get_error(), sslErrMsg + SSL_write_str_size,
MaxSslErrorLength - SSL_write_str_size - 1);
snprintf(_errorMsg, MaxSslErrorLength,
"%s. It's mean the remote end was "
"closed because of bad network. errno_code:%d, ssl_eCode:%d.",
sslErrMsg, errno_code, sslError);
LOG_ERROR("SSL(%p) SSL_ERROR_SYSCALL Write failed, %s.", this,
_errorMsg);
MUTEX_UNLOCK(_mtxSSL);
return -(SslWriteFailed);
} else {
memset(_errorMsg, 0x0, MaxSslErrorLength);
memcpy(sslErrMsg, SSL_write_ret, SSL_write_str_size);
ERR_error_string_n(ERR_get_error(), sslErrMsg + SSL_write_str_size,
MaxSslErrorLength - SSL_write_str_size - 1);
snprintf(_errorMsg, MaxSslErrorLength,
"%s. errno_code:%d ssl_eCode:%d.", sslErrMsg, errno_code,
sslError);
LOG_ERROR("SSL(%p) SSL_ERROR_SYSCALL Write failed: %s.", this,
_errorMsg);
MUTEX_UNLOCK(_mtxSSL);
return -(SslWriteFailed);
}
} else {
memset(_errorMsg, 0x0, MaxSslErrorLength);
memcpy(sslErrMsg, SSL_write_ret, SSL_write_str_size);
ERR_error_string_n(ERR_get_error(), sslErrMsg + SSL_write_str_size,
MaxSslErrorLength - SSL_write_str_size - 1);
if (sslError == SSL_ERROR_ZERO_RETURN && errno_code == 0) {
snprintf(
_errorMsg, MaxSslErrorLength,
"%s. errno_code:%d ssl_eCode:%d. It's mean this connection was "
"closed or shutdown because of bad network.",
sslErrMsg, errno_code, sslError);
} else {
snprintf(_errorMsg, MaxSslErrorLength,
"%s. errno_code:%d ssl_eCode:%d.", sslErrMsg, errno_code,
sslError);
}
LOG_ERROR("SSL(%p) SSL_write failed: %s.", this, _errorMsg);
MUTEX_UNLOCK(_mtxSSL);
return -(SslWriteFailed);
}
}
MUTEX_UNLOCK(_mtxSSL);
return wLen;
}
int SSLconnect::sslRead(uint8_t *buffer, size_t len) {
MUTEX_LOCK(_mtxSSL);
if (_ssl == NULL) {
LOG_ERROR("SSL(%p) ssl has been closed.", this);
MUTEX_UNLOCK(_mtxSSL);
return -(SslReadFailed);
}
int rLen = SSL_read(_ssl, (void *)buffer, (int)len);
if (rLen <= 0) {
int sslError = SSL_get_error(_ssl, rLen);
int errno_code = utility::getLastErrorCode();
char sslErrMsg[MaxSslErrorLength] = {0};
const char *SSL_read_ret = "return of SSL_read: ";
const int SSL_read_str_size = strnlen(SSL_read_ret, 64);
// LOG_WARN("Read maybe failed, get_ssl_error:%d", sslError);
if (sslError == SSL_ERROR_WANT_READ || sslError == SSL_ERROR_WANT_WRITE ||
sslError == SSL_ERROR_WANT_X509_LOOKUP) {
// LOG_DEBUG("SSL(%p) Read could not complete. Will be invoked later.",
// this);
MUTEX_UNLOCK(_mtxSSL);
return 0;
} else if (sslError == SSL_ERROR_SYSCALL) {
LOG_INFO("SSL(%p) SSL_read error_syscall failed, errno:%d.", this,
errno_code);
if (NLS_ERR_CONNECT_RETRIABLE(errno_code) ||
NLS_ERR_RW_RETRIABLE(errno_code)) {
LOG_WARN("SSL(%p) Retry read...", this);
MUTEX_UNLOCK(_mtxSSL);
return 0;
} else if (errno_code == 0) {
LOG_DEBUG("SSL(%p) SSL_read syscall success.", this);
MUTEX_UNLOCK(_mtxSSL);
return 0;
#ifdef _MSC_VER
} else if (errno_code == WSAECONNRESET) {
#else
} else if (errno_code == ECONNRESET) {
#endif
memset(_errorMsg, 0x0, MaxSslErrorLength);
memcpy(sslErrMsg, SSL_read_ret, SSL_read_str_size);
ERR_error_string_n(ERR_get_error(), sslErrMsg + SSL_read_str_size,
MaxSslErrorLength - SSL_read_str_size - 1);
snprintf(_errorMsg, MaxSslErrorLength,
"%s. It's mean the remote end was "
"closed because of bad network. errno_code:%d, ssl_eCode:%d.",
sslErrMsg, errno_code, sslError);
LOG_ERROR("SSL(%p) SSL_ERROR_SYSCALL Read failed, %s.", this,
_errorMsg);
MUTEX_UNLOCK(_mtxSSL);
return -(SslReadSysError);
} else {
memset(_errorMsg, 0x0, MaxSslErrorLength);
memcpy(sslErrMsg, SSL_read_ret, SSL_read_str_size);
ERR_error_string_n(ERR_get_error(), sslErrMsg + SSL_read_str_size,
MaxSslErrorLength - SSL_read_str_size - 1);
snprintf(_errorMsg, MaxSslErrorLength,
"%s. errno_code:%d, ssl_eCode:%d.", sslErrMsg, errno_code,
sslError);
LOG_ERROR("SSL(%p) SSL_ERROR_SYSCALL Read failed, %s.", this,
_errorMsg);
MUTEX_UNLOCK(_mtxSSL);
return -(SslReadSysError);
}
} else {
memset(_errorMsg, 0x0, MaxSslErrorLength);
memcpy(sslErrMsg, SSL_read_ret, strnlen(SSL_read_ret, 64));
ERR_error_string_n(ERR_get_error(), sslErrMsg + SSL_read_str_size,
MaxSslErrorLength - SSL_read_str_size - 1);
if (sslError == SSL_ERROR_ZERO_RETURN && errno_code == 0 &&
++_sslTryAgain <= MaxSslTryAgain) {
snprintf(
_errorMsg, MaxSslErrorLength,
"%s. errno_code:%d ssl_eCode:%d. It's mean this connection was "
"closed or shutdown because of bad network, Try again ...",
sslErrMsg, errno_code, sslError);
LOG_WARN("SSL(%p) SSL_read failed: %s.", this, _errorMsg);
MUTEX_UNLOCK(_mtxSSL);
return 0;
} else {
snprintf(_errorMsg, MaxSslErrorLength,
"%s. errno_code:%d ssl_eCode:%d.", sslErrMsg, errno_code,
sslError);
}
LOG_ERROR("SSL(%p) SSL_read failed: %s.", this, _errorMsg);
MUTEX_UNLOCK(_mtxSSL);
return -(SslReadFailed);
}
}
_sslTryAgain = 0;
MUTEX_UNLOCK(_mtxSSL);
return rLen;
}
/**
 * @brief Close the TLS/SSL connection.
 *
 * Under the object's mutex: if an SSL object is attached it is shut down,
 * freed and the pointer reset to NULL; otherwise only a debug line is logged,
 * so repeated calls are harmless.
 */
void SSLconnect::sslClose() {
  MUTEX_LOCK(_mtxSSL);

  if (_ssl == NULL) {
    LOG_DEBUG("SSL(%p) connect has closed.", this);
  } else {
    LOG_INFO("SSL(%p) ssl connect close.", this);
    SSL_shutdown(_ssl);
    SSL_free(_ssl);
    _ssl = NULL;
  }

  MUTEX_UNLOCK(_mtxSSL);
}
/**
 * @brief Give access to the text of the last recorded failure.
 * @return pointer to this object's internal error buffer; the buffer is
 *         rewritten by later failing handshake/read/write calls.
 */
const char *SSLconnect::getFailedMsg() {
  const char *lastFailure = _errorMsg;
  return lastFailure;
}
} // namespace AlibabaNls