fbnet/command_runner/thrift_client.py (64 lines of code) (raw):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import asyncio
from thrift.server.TAsyncioServer import ThriftClientProtocolFactory
from .base_service import ServiceObj
class AsyncioThriftClient(ServiceObj):
    """
    Helper that opens an asyncio thrift client connection to a host/port.

    The object can be used in three equivalent ways:

        client = await AsyncioThriftClient(client_class, host, port).open()
        client = await AsyncioThriftClient(client_class, host, port)
        async with AsyncioThriftClient(client_class, host, port) as client:
            ...

    In every case the value handed back is the thrift client created by the
    protocol; its ``close`` attribute is rebound to this object's ``close``.
    """

    _TIMEOUT = 60  # By default timeout after 60s

    def __init__(
        self, client_class, host, port, service=None, timeout=None, open_timeout=None
    ):
        """
        Args:
            client_class: thrift client class handed to
                ThriftClientProtocolFactory.
            host: host to connect to.
            port: port to connect to.
            service: optional service object passed to ServiceObj; when set,
                per-endpoint counters are registered on its stats manager.
            timeout: timeout for thrift calls; falls back to ``_TIMEOUT``
                when falsy.
            open_timeout: seconds to wait for the connection to be
                established; ``None`` waits without a limit.
        """
        super().__init__(service)

        self._client_class = client_class
        self._host = host
        self._port = port
        self._connected = False
        self._timeout = timeout
        self._open_timeout = open_timeout

        # Populated by open().
        self._protocol = None
        self._transport = None
        self._client = None

        if self.service:
            self._register_counter("connected")
            self._register_counter("lookup.failed")

    def _format_counter(self, counter):
        """Return the fully qualified counter name for this host/port."""
        return "thrift_client.{}.{}.{}".format(self._host, self._port, counter)

    def _inc_counter(self, counter):
        """Increment *counter* for this endpoint; no-op without a service."""
        if self.service:
            c = self._format_counter(counter)
            self.inc_counter(c)

    def _register_counter(self, counter):
        """Register *counter* for this endpoint with the service stats manager."""
        c = self._format_counter(counter)
        self.service.stats_mgr.register_counter(c)

    async def _lookup_service(self):
        """Return the (host, port) pair to connect to."""
        return self._host, self._port

    async def _get_timeouts(self):
        """Set the timeout for thrift calls"""
        # NOTE(review): the empty-string key presumably acts as the default
        # timeout for all methods in ThriftClientProtocolFactory -- confirm.
        return {"": self._timeout or self._TIMEOUT}

    async def open(self):
        """
        Connect to the service and return the thrift client.

        Raises:
            asyncio.TimeoutError: if the connection is not established
                within ``open_timeout`` seconds.
        """
        host, port = await self._lookup_service()
        timeouts = await self._get_timeouts()

        conn_fut = self.loop.create_connection(
            ThriftClientProtocolFactory(self._client_class, timeouts=timeouts),
            host=host,
            port=port,
        )
        # The explicit ``loop`` argument of asyncio.wait_for was removed in
        # Python 3.10 (passing it raises TypeError); wait_for uses the
        # running loop, which is the loop this coroutine executes on.
        (transport, protocol) = await asyncio.wait_for(conn_fut, self._open_timeout)

        self._inc_counter("connected")
        self._protocol = protocol
        self._transport = transport
        self._client = protocol.client
        # hookup the close method to the client
        self._client.close = self.close
        self._connected = True
        return self._client

    def close(self):
        """Close the protocol and transport, if any, and mark as disconnected."""
        try:
            if self._protocol:
                self._protocol.close()
        finally:
            # Always release the transport and reset the flag, even when
            # closing the protocol raised.
            if self._transport:
                self._transport.close()
            self._connected = False

    def __await__(self):
        """Allow ``await obj`` as a shorthand for ``await obj.open()``."""
        return self.open().__await__()

    async def __aenter__(self):
        """Open the connection and return the thrift client."""
        await self.open()
        return self._client

    async def __aexit__(self, exc_type, exc, tb):
        """Close the connection when leaving the ``async with`` block."""
        self.close()