wangle/example/broadcast/BroadcastProxy.cpp (135 lines of code) (raw):

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <folly/portability/GFlags.h>

#include <folly/init/Init.h>
#include <wangle/bootstrap/AcceptRoutingHandler.h>
#include <wangle/bootstrap/RoutingDataHandler.h>
#include <wangle/bootstrap/ServerBootstrap.h>
#include <wangle/channel/AsyncSocketHandler.h>
#include <wangle/channel/broadcast/BroadcastHandler.h>
#include <wangle/channel/broadcast/BroadcastPool.h>
#include <wangle/channel/broadcast/ObservingHandler.h>
#include <wangle/codec/ByteToMessageDecoder.h>
#include <wangle/codec/MessageToByteEncoder.h>

using namespace folly;
using namespace wangle;

// Port the proxy listens on for client connections (passed to
// ServerBootstrap::bind in main).
DEFINE_int32(port, 8080, "Broadcast proxy port");
// Port of the upstream server on this machine; SimpleServerPool builds the
// upstream address from it with SocketAddress::setFromLocalPort.
DEFINE_int32(upstream_port, 8081, "Upstream server port");

/**
 * Steps to run:
 * 1) Run an upstream server that can broadcast messages:
 *
 *      nc -l localhost 8081
 *
 *    This starts a server on localhost:8081.
 *
 * 2) Start the broadcast proxy with the upstream_port set to 8081:
 *
 *      ./broadcast_proxy --port 8080 --upstream_port 8081
 *
 *    This starts the proxy on localhost:8080 and sets the upstream server
 *    to localhost:8081.
 *
 * 3) Start one or more telnet clients that connect to the broadcast proxy
 *    and listen for the messages broadcast by the upstream server:
 *
 *      telnet localhost 8080
 *
 *    Send some bytes in each telnet terminal so that broadcast_proxy kicks
 *    off the connection.
*
 * 4) Type something in the nc terminal and notice that it is broadcast to
 *    all the telnet clients.
 */

/**
 * Decoder that turns the bytes buffered in an IOBufQueue into a std::string.
 *
 * Used in the BroadcastPipeline to convert bytes read from the upstream
 * server's socket into messages that can be broadcast to all the
 * clients/observers. No framing is applied: every successful call drains the
 * entire queue into a single message.
 */
class ByteToStringDecoder : public ByteToMessageDecoder<std::string> {
 public:
  bool decode(Context*, IOBufQueue& buf, std::string& result, size_t&)
      override {
    // Nothing buffered yet, so no message can be produced.
    if (buf.chainLength() == 0) {
      return false;
    }
    // Take ownership of the whole chain and flatten it into one string.
    auto chain = buf.move();
    result = chain->moveToFbString().toStdString();
    return true;
  }
};

/**
 * Encoder that copies each string message into a freshly allocated IOBuf.
 *
 * Used in the ObservingPipeline so that messages broadcast by the upstream
 * server can be written to the client socket.
 */
class StringToByteEncoder : public MessageToByteEncoder<std::string> {
 public:
  std::unique_ptr<folly::IOBuf> encode(std::string& msg) override {
    // copyBuffer copies the bytes, leaving `msg` itself untouched.
    return IOBuf::copyBuffer(msg.data(), msg.size());
  }
};

/**
 * RoutingDataHandler that uses the client's IP address as the routing data,
 * so all connections from the same client IP are hashed to the same worker
 * thread.
*/
class ClientIPRoutingDataHandler : public RoutingDataHandler<std::string> {
 public:
  ClientIPRoutingDataHandler(uint64_t connId, Callback* cob)
      : RoutingDataHandler<std::string>(connId, cob) {}

  /**
   * Records the remote (client) address string as the routing data and
   * appends everything buffered so far onto `routingData.bufQueue`.
   * Always reports success.
   */
  bool parseRoutingData(
      folly::IOBufQueue& bufQueue,
      RoutingData& routingData) override {
    auto info = getContext()->getPipeline()->getTransportInfo();
    routingData.routingData = info->remoteAddr->getAddressStr();
    LOG(INFO) << "Using client IP " << routingData.routingData
              << " as routing data to hash to a worker thread";
    routingData.bufQueue.append(bufQueue);
    return true;
  }
};

/**
 * Factory that creates a new ClientIPRoutingDataHandler for the given
 * connection id and callback.
 */
class ClientIPRoutingDataHandlerFactory
    : public RoutingDataHandlerFactory<std::string> {
 public:
  std::shared_ptr<RoutingDataHandler<std::string>> newHandler(
      uint64_t connId,
      RoutingDataHandler<std::string>::Callback* cob) override {
    auto handler = std::make_shared<ClientIPRoutingDataHandler>(connId, cob);
    return handler;
  }
};

/**
 * Broadcast ServerPool that connects to the upstream server on this machine
 * at --upstream_port. The routing data is ignored, so every routing key
 * resolves to the same upstream address.
 */
class SimpleServerPool : public ServerPool<std::string> {
 public:
  Future<DefaultPipeline*> connect(
      BaseClientBootstrap<DefaultPipeline>* client,
      const std::string& /* routingData */) noexcept override {
    SocketAddress upstream;
    upstream.setFromLocalPort(FLAGS_upstream_port);
    LOG(INFO) << "Connecting to upstream server " << upstream
              << " for subscribing to broadcast";
    return client->connect(upstream);
  }
};

/**
 * Factory for the BroadcastPipeline, which maintains the upstream connection
 * and broadcasts messages sent by the upstream server to all the
 * observers/clients.
*/
class SimpleBroadcastPipelineFactory
    : public BroadcastPipelineFactory<std::string, std::string> {
 private:
  // Handler that fans each decoded upstream message out to the observers.
  using StringBroadcastHandler = BroadcastHandler<std::string, std::string>;

 public:
  /**
   * Builds the upstream pipeline:
   * socket I/O -> ByteToStringDecoder -> StringBroadcastHandler.
   */
  DefaultPipeline::Ptr newPipeline(
      std::shared_ptr<AsyncTransport> socket) override {
    LOG(INFO) << "Creating a new BroadcastPipeline for upstream server";
    auto upstream = DefaultPipeline::create();
    upstream->addBack(AsyncSocketHandler(socket));
    upstream->addBack(ByteToStringDecoder());
    upstream->addBack(StringBroadcastHandler());
    upstream->finalize();
    return upstream;
  }

  /** Looks up the broadcast handler that newPipeline() installed. */
  StringBroadcastHandler* getBroadcastHandler(
      DefaultPipeline* pipeline) noexcept override {
    return pipeline->getHandler<StringBroadcastHandler>();
  }

  /** This pipeline does not use routing data, so this is a no-op. */
  void setRoutingData(
      DefaultPipeline* /* pipeline */,
      const std::string& /* routingData */) noexcept override {}
};

// Per-client pipeline type; messages flowing through it are std::strings.
using SimpleObservingPipeline = ObservingPipeline<std::string>;

/**
 * Factory for the per-client ObservingPipeline, which maintains the client
 * socket connection and subscribes to the BroadcastPipeline to receive
 * messages sent by the upstream server. A new ObservingPipeline is created
 * for each client connection.
*/ class SimpleObservingPipelineFactory : public ObservingPipelineFactory<std::string, std::string> { public: SimpleObservingPipelineFactory( std::shared_ptr<SimpleServerPool> serverPool, std::shared_ptr<SimpleBroadcastPipelineFactory> broadcastPipelineFactory) : ObservingPipelineFactory<std::string, std::string>( serverPool, broadcastPipelineFactory) {} SimpleObservingPipeline::Ptr newPipeline( std::shared_ptr<AsyncTransport> socket, const std::string& routingData, RoutingDataHandler<std::string>*, std::shared_ptr<TransportInfo> transportInfo) override { LOG(INFO) << "Creating a new ObservingPipeline for client " << *(transportInfo->remoteAddr); auto pipeline = SimpleObservingPipeline::create(); pipeline->addBack(AsyncSocketHandler(socket)); pipeline->addBack(StringToByteEncoder()); pipeline->addBack( std::make_shared<ObservingHandler<std::string, std::string>>( routingData, broadcastPool())); pipeline->finalize(); return pipeline; } }; int main(int argc, char** argv) { folly::Init init(&argc, &argv); auto serverPool = std::make_shared<SimpleServerPool>(); // A unique BroadcastPipeline for each upstream server to fan-out the // upstream messages to ObservingPipelines corresponding to each client. auto broadcastPipelineFactory = std::make_shared<SimpleBroadcastPipelineFactory>(); // A unique ObservingPipeline is created for each client to subscribe // to the broadcast. auto observingPipelineFactory = std::make_shared<SimpleObservingPipelineFactory>( serverPool, broadcastPipelineFactory); // RoutingDataHandlerFactory for creating the RoutingDataHandler that sets // client IP as the routing data. auto routingHandlerFactory = std::make_shared<ClientIPRoutingDataHandlerFactory>(); ServerBootstrap<SimpleObservingPipeline> server; // AcceptRoutingPipelineFactory for creating accept pipelines hash the // client connection to a worker thread based on client IP. 
auto acceptPipelineFactory = std::make_shared< AcceptRoutingPipelineFactory<SimpleObservingPipeline, std::string>>( &server, routingHandlerFactory, observingPipelineFactory); server.pipeline(acceptPipelineFactory); server.bind(FLAGS_port); server.waitForStop(); return 0; }