platform/networkstrate/async_acceptor.cpp (122 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "platform/networkstrate/async_acceptor.h"

#include <signal.h>

#include <cstring>
#include <new>
#include <thread>

#include <glog/logging.h>
namespace resdb {
AsyncAcceptor::Session::Session(boost::asio::io_service* io_service,
CallBack call_back_func)
: io_service_(io_service),
client_socket_(*io_service_),
recv_buffer_(nullptr),
status_(0),
call_back_func_(call_back_func) {}
// Tears the session down through the same path used on read errors.
AsyncAcceptor::Session::~Session() {
  Close();
}
// Returns a non-owning pointer to this session's socket (used as the target
// of async_accept). The pointer is valid for the lifetime of the session.
boost::asio::ip::tcp::socket* AsyncAcceptor::Session::GetSocket() {
  boost::asio::ip::tcp::socket& socket = client_socket_;
  return &socket;
}
void AsyncAcceptor::Session::Close() {
if (client_socket_.is_open()) {
client_socket_.cancel();
}
if (recv_buffer_ && status_ != 0) {
delete recv_buffer_;
recv_buffer_ = nullptr;
}
}
void AsyncAcceptor::Session::StartRead() {
if (status_ == 0) {
// read len
recv_buffer_ = reinterpret_cast<char*>(&data_size_);
data_size_ = 0;
need_size_ = sizeof(data_size_);
current_idx_ = 0;
memset(recv_buffer_, 0, need_size_);
OnRead();
} else {
need_size_ = data_size_;
current_idx_ = 0;
if (recv_buffer_) {
delete recv_buffer_;
}
recv_buffer_ = new char[need_size_];
memset(recv_buffer_, 0, need_size_);
OnRead();
}
}
// Called once the current phase's need_size_ bytes have arrived.
// After a body (status_ == 1): passes it to the callback and frees it.
// After a header (status_ == 0): validates the announced length and closes
// the session if it exceeds 1e10.
// It then flips phase and starts the next read.
void AsyncAcceptor::Session::ReadDone() {
  if (status_ == 1) {
    // The callback only borrows the buffer; it was allocated with new[] in
    // StartRead, so it must be released with delete[].
    call_back_func_(recv_buffer_, data_size_);
    delete[] recv_buffer_;
  } else {
    // recv_buffer_ points at data_size_ itself during the header phase.
    data_size_ = *reinterpret_cast<size_t*>(recv_buffer_);
    if (data_size_ > 1e10) {
      LOG(ERROR) << "read data size:" << data_size_
                 << " data size:" << sizeof(data_size_) << " close socket";
      Close();
      return;
    }
  }
  status_ ^= 1;
  recv_buffer_ = nullptr;
  // continue to read next msg.
  StartRead();
}
void AsyncAcceptor::Session::OnRead() {
client_socket_.async_read_some(
boost::asio::buffer(recv_buffer_ + current_idx_,
need_size_ - current_idx_),
[&](const boost::system::error_code& error, // Result of operation.
std::size_t bytes_transferred) {
if (error || bytes_transferred == 0) {
Close();
} else {
current_idx_ += bytes_transferred;
if (current_idx_ >= need_size_) {
ReadDone();
} else {
OnRead();
}
}
});
}
// Binds an acceptor to ip:port and starts `thread_num` threads, each running
// the shared io_service. Accepting begins only once StartAccept() is called.
AsyncAcceptor::AsyncAcceptor(const std::string& ip, int port, int thread_num,
                             CallBack call_back_func)
    : endpoint_(boost::asio::ip::address::from_string(ip), port),
      acceptor_(io_service_, endpoint_),
      call_back_func_(call_back_func) {
  // The work guard keeps run() from returning while nothing is queued.
  worker_ = std::make_unique<boost::asio::io_service::work>(io_service_);
  for (int t = 0; t < thread_num; ++t) {
    worker_thread_.emplace_back([this] { io_service_.run(); });
  }
}
// Stops the io_service, waits for every worker thread, then closes all
// sessions.
AsyncAcceptor::~AsyncAcceptor() {
  // Drop the work guard and stop: run() returns on every worker.
  worker_.reset();
  io_service_.stop();
  // Join before touching sessions_. Handlers (OnAccept pushing to sessions_,
  // Session read handlers) run on these threads, so closing sessions or
  // iterating sessions_ while they may still execute is a data race.
  for (auto& worker : worker_thread_) {
    if (worker.joinable()) {
      worker.join();
    }
  }
  for (auto& sess : sessions_) {
    if (sess) {
      sess->Close();
    }
  }
}
void AsyncAcceptor::StartAccept() {
boost::shared_ptr<Session> client_session(
new Session(&io_service_, call_back_func_));
acceptor_.async_accept(*client_session->GetSocket(),
std::bind(&AsyncAcceptor::OnAccept, this,
client_session, std::placeholders::_1));
}
// Completion handler for StartAccept().
// On success: records the session, arms the next accept, and starts reading
// from the new connection.
// On error: logs and stops accepting.
void AsyncAcceptor::OnAccept(boost::shared_ptr<Session> client_session,
                             const boost::system::error_code ec) {
  if (ec) {
    LOG(ERROR) << " accept fail, error:" << ec.message();
    return;
  }
  // Record the session BEFORE arming the next accept. Once StartAccept() has
  // run, the next OnAccept may execute on another worker thread, so a later
  // push_back here would race with it on sessions_.
  sessions_.push_back(client_session);
  StartAccept();  // Add the next accept event.
  client_session->StartRead();
}
} // namespace resdb