cpp-channel/Thrift/Channel/Lib/CppChannel.hsc (206 lines of code) (raw):

-- Copyright (c) Facebook, Inc. and its affiliates.

{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -fno-warn-unused-do-bind #-}
{-# OPTIONS_GHC -fprof-auto #-}
module Thrift.Channel.Lib.CppChannel
  ( WrappedChannel(..)
  , CppRequestChannelPtr
  , CppSocketAddress
  , CppAsyncTransport
  , withCppChannelIO
  , withCppChannel
  , getInnerCppRequestChannel
  ) where

import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.ByteString.Internal (ByteString(..))
import Data.ByteString.Unsafe
##if __GLASGOW_HASKELL__ < 804
import Data.Monoid
##endif
import Data.Text.Encoding
import Foreign hiding (void)
import Foreign.C
import GHC.Conc (newStablePtrPrimMVar, PrimMVar)
import TextShow

import Thrift.Channel
import Thrift.Monad
import Thrift.Protocol.Binary
import Util.Control.Exception
import Util.Log

-- | Encapsulation for using the C++ thrift libraries to make requests
--
-- NOTE, resource lifetime management is Hard with C++:
--   * These must be used only within the scope of `withFacebook` and
--     `withEventBaseDataplane`.
--   * Do not use the global IOExecutor with Haskell Thrift channels. This way
--     lies madness.

-- Things that exist in C++ (opaque on the Haskell side; only ever handled
-- through 'Ptr').
data CppWrappedChannel
data CppRequestChannelPtr
data CppSocketAddress
data CppAsyncTransport

#include <cpp/HsChannel.h>

-- | WrappedChannel is parameterized by a phantom type that represents the
-- specific CPP client channel we are using
newtype WrappedChannel t s = WrappedChannel
  { cppChannel :: Ptr CppWrappedChannel }

-- | Run a 'ThriftM' computation against the given C++ request channel.
--
-- This is 'withCppChannelIO' with the computation run through 'runThrift',
-- so the wrapper's lifetime is exactly the duration of the computation.
withCppChannel
  :: Ptr CppRequestChannelPtr
  -> ThriftM p (WrappedChannel t) s a
  -> IO a
withCppChannel channel thriftAction =
  withCppChannelIO channel (runThrift thriftAction)

-- | Wrap the given C++ request channel with @newWrapper@ and run the action
-- with the result, releasing the wrapper with @deleteWrapper@ when the action
-- returns or throws ('bracket').
--
-- The 'WrappedChannel' handed to the action must not escape it: the wrapper
-- is passed to @deleteWrapper@ on exit (presumably freeing it on the C++
-- side).
withCppChannelIO
  :: Ptr CppRequestChannelPtr
  -> (WrappedChannel t s -> IO a)
  -> IO a
withCppChannelIO channel action =
  bracket acquireWrapper c_deleteWrapper (action . WrappedChannel)
  where
    acquireWrapper = c_newWrapper channel

-- | Returns a raw pointer to the inner channel.
-- This is only valid while the wrapped channel is alive.
getInnerCppRequestChannel :: WrappedChannel t s -> IO (Ptr CppRequestChannelPtr)
getInnerCppRequestChannel WrappedChannel{..} =
  c_getInnerRequestChannel cppChannel

--------------------------------------------------------------------------------

instance ClientChannel (WrappedChannel t) where
  -- Two-way request. Two completion slots are created: one that the C++ side
  -- signals when the request has been sent, and one it signals when the
  -- response (or an error) is available. A forked collector thread waits on
  -- them in that order and runs the user callbacks; the receive callback is
  -- only run if the send callback was told the send succeeded.
  --
  -- 'mask_' keeps an asynchronous exception from arriving between forking the
  -- collector and handing its MVars to C++ in 'c_sendReq'; otherwise the
  -- collector could be left waiting on an MVar that nothing will fill.
  sendRequest WrappedChannel{..} Request{..} sendCob recvCob = mask_ $ do
    (send_mvar, send_sp, send_result) <- newCallbackMVar
    (recv_mvar, recv_sp, recv_result) <- newCallbackMVar
    -- The current capability is passed to C++, presumably so it can wake the
    -- collector on this capability (hs_try_putmvar) — confirm in HsChannel.
    (cap,_) <- threadCapability =<< myThreadId
    forkIO $ do
      cont <- sendCollector send_mvar send_result sendCob reqMsg
      when cont $ recvCollector recv_mvar recv_result recvCob
    -- NOTE(review): reqMsg is kept alive until the send completes (see
    -- touchReq in sendCollector), but the serialized options buffer is only
    -- pinned for the duration of this FFI call. This assumes the C++ side
    -- copies or decodes the options before returning — confirm in HsChannel.
    withForeignPtr send_result $ \send_result_p ->
      withForeignPtr recv_result $ \recv_result_p ->
        unsafeUseAsCStringLen reqMsg $ \(buf, len) ->
          unsafeUseAsCStringLen (serializeBinary reqOptions) $ \(oBuf, oLen) ->
            c_sendReq cppChannel buf (fromIntegral len) (fromIntegral cap)
              send_sp recv_sp send_result_p recv_result_p
              oBuf (fromIntegral oLen)

  -- One-way request: only the send completion is collected; no response is
  -- expected, so the collector's Bool result is discarded.
  sendOnewayRequest WrappedChannel{..} Request{..} sendCob = mask_ $ do
    (send_mvar, send_sp, send_result) <- newCallbackMVar
    forkIO $ void $ sendCollector send_mvar send_result sendCob reqMsg
    (cap,_) <- threadCapability =<< myThreadId
    withForeignPtr send_result $ \send_result_p ->
      unsafeUseAsCStringLen reqMsg $ \(buf, len) ->
        unsafeUseAsCStringLen (serializeBinary reqOptions) $ \(oBuf, oLen) ->
          c_sendOnewayReq cppChannel buf (fromIntegral len) (fromIntegral cap)
            send_sp send_result_p oBuf (fromIntegral oLen)

-- | Block until the C++ side signals the send slot, then report the outcome
-- to the send callback.
--
-- Returns 'True' only on @SEND_SUCCESS@, i.e. when a response is still
-- expected and the receive side should be collected. Exceptions thrown by the
-- callback are logged and swallowed in every branch, so the collector thread
-- always finishes cleanly.
sendCollector
  :: MVar ()
  -> ForeignPtr CFinishedRequest
  -> SendCallback
  -> ByteString
  -> IO Bool
sendCollector send_mvar send_result sendCob reqMsg = do
  takeMVar send_mvar
  touchReq reqMsg
  withForeignPtr send_result $ \ptr -> do
    statusi <- (#peek FinishedRequest, status) ptr :: IO CInt
    case statusi of
      (#const SEND_ERROR) -> do
        msg <- peekFinishedRequestMsg ptr
        catchAndLog $ sendCob $ Just $ ChannelException $
          "sendCob: " <> decodeMsg msg
        return False
      (#const SEND_SUCCESS) -> do
        catchAndLog $ sendCob Nothing
        return True
      _ -> do
        -- Guarded like the other branches so a throwing callback cannot
        -- escape the collector thread.
        catchAndLog $ sendCob $ Just $ ChannelException $
          "sendCollector: unexpected status: " <> showt statusi
        return False
  where
    catchAndLog io =
      io `catchAll` \e -> logError ("send callback threw: " ++ show e)
    -- Error text from C++ is not guaranteed to be valid UTF-8; substitute
    -- U+FFFD instead of throwing a UnicodeException when the text is forced.
    decodeMsg = decodeUtf8With (\_ _ -> Just '\xFFFD')

-- | Block until the C++ side signals the receive slot, then deliver the
-- response (@RECV_SUCCESS@) or a 'ChannelException' to the receive callback.
--
-- The message buffer is only taken over (see 'peekFinishedRequestMsg') on the
-- two statuses that carry a message; on an unexpected status the slot's
-- buffer field may never have been written. Exceptions thrown by the callback
-- are logged and swallowed.
recvCollector :: MVar () -> ForeignPtr CFinishedRequest -> RecvCallback -> IO ()
recvCollector recv_mvar recv_result recvCob = do
  takeMVar recv_mvar
  withForeignPtr recv_result $ \ptr -> do
    statusi <- (#peek FinishedRequest, status) ptr :: IO CInt
    case statusi of
      (#const RECV_ERROR) -> do
        msg <- peekFinishedRequestMsg ptr
        catchAndLog $ recvCob $ Left $ ChannelException $
          "recvCob: " <> decodeMsg msg
      (#const RECV_SUCCESS) -> do
        msg <- peekFinishedRequestMsg ptr
        catchAndLog $ recvCob (Right (Response msg mempty))
      _ ->
        catchAndLog $ recvCob $ Left $ ChannelException $
          "recvCollector: unexpected status: " <> showt statusi
  where
    catchAndLog io =
      io `catchAll` \e -> logError ("recv callback threw: " ++ show e)
    -- Error text from C++ is not guaranteed to be valid UTF-8; substitute
    -- U+FFFD instead of throwing a UnicodeException when the text is forced.
    decodeMsg = decodeUtf8With (\_ _ -> Just '\xFFFD')

-- We need the send callback to touch the request message so that it doesn't get
-- garbage collected before the request is sent
touchReq :: ByteString -> IO ()
touchReq (PS fptr _ _) = touchForeignPtr fptr

-- | The pieces we need to set up a callback from C to Haskell:
--
--   * an empty 'MVar' the collector thread blocks on,
--   * a 'StablePtr' to it that the C++ side uses to signal completion,
--   * a @FinishedRequest@-sized buffer for the C++ side to fill in. Its
--     contents are not initialised on the Haskell side.
newCallbackMVar :: IO (MVar (), StablePtr PrimMVar, ForeignPtr CFinishedRequest)
newCallbackMVar = do
  mvar <- newEmptyMVar
  sp <- newStablePtrPrimMVar mvar
  ptr <- mallocForeignPtrBytes (#const sizeof(FinishedRequest))
  return (mvar, sp, ptr)

-- | Haskell-side tag for the C @FinishedRequest@ struct from cpp/HsChannel.h.
data CFinishedRequest

-- | Pack the message bytes into a ByteString that will call free when
-- it is garbage collected.
--
-- This takes ownership of the @buffer@ pointer, so it must be called at most
-- once per @FinishedRequest@, and only when the C++ side has populated
-- @buffer@ (with malloc'd memory) and @len@.
peekFinishedRequestMsg :: Ptr CFinishedRequest -> IO ByteString
peekFinishedRequestMsg ptr = do
  buf <- (#peek FinishedRequest, buffer) ptr
  msgLen <- (#peek FinishedRequest, len) ptr :: IO CSize
  unsafePackMallocCStringLen (buf, fromIntegral msgLen)

--------------------------------------------------------------------------------

foreign import ccall unsafe "newWrapper"
  c_newWrapper :: Ptr CppRequestChannelPtr -> IO (Ptr CppWrappedChannel)

foreign import ccall unsafe "deleteWrapper"
  c_deleteWrapper :: Ptr CppWrappedChannel -> IO ()

foreign import ccall unsafe "getInnerRequestChannel"
  c_getInnerRequestChannel
    :: Ptr CppWrappedChannel -> IO (Ptr CppRequestChannelPtr)

-- This is implemented using runInEventBaseThread(), which is
-- non-blocking, so we can make this call unsafe.
foreign import ccall unsafe "sendReq"
  c_sendReq
    :: Ptr CppWrappedChannel
    -> CString               -- request bytes
    -> CSize                 -- request length
    -> CInt                  -- capability of the calling thread
    -> StablePtr PrimMVar    -- signalled when the send completes
    -> StablePtr PrimMVar    -- signalled when the response is available
    -> Ptr CFinishedRequest  -- send result slot
    -> Ptr CFinishedRequest  -- receive result slot
    -- RPC Options
    -> CString
    -> CSize
    -> IO ()

-- This is implemented using runInEventBaseThread(), which is
-- non-blocking, so we can make this call unsafe.
foreign import ccall unsafe "sendOnewayReq"
  c_sendOnewayReq
    :: Ptr CppWrappedChannel
    -> CString               -- request bytes
    -> CSize                 -- request length
    -> CInt                  -- capability of the calling thread
    -> StablePtr PrimMVar    -- signalled when the send completes
    -> Ptr CFinishedRequest  -- send result slot
    -- RPC Options
    -> CString
    -> CSize
    -> IO ()