stream/clients/python/bookkeeper/common/grpc_helpers.py (137 lines of code) (raw):
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for :mod:`grpc`."""
import collections
import grpc
import six
from bookkeeper.common import exceptions
from bookkeeper.common import general_helpers
# The list of gRPC Callable interfaces that return iterators.
_STREAM_WRAP_CLASSES = (
grpc.UnaryStreamMultiCallable,
grpc.StreamStreamMultiCallable,
)
def _patch_callable_name(callable_):
"""Fix-up gRPC callable attributes.
gRPC callable lack the ``__name__`` attribute which causes
:func:`functools.wraps` to error. This adds the attribute if needed.
"""
if not hasattr(callable_, '__name__'):
callable_.__name__ = callable_.__class__.__name__
def _wrap_unary_errors(callable_):
"""Map errors for Unary-Unary and Stream-Unary gRPC callables."""
_patch_callable_name(callable_)
@six.wraps(callable_)
def error_remapped_callable(*args, **kwargs):
try:
return callable_(*args, **kwargs)
except grpc.RpcError as exc:
six.raise_from(exceptions.from_grpc_error(exc), exc)
return error_remapped_callable
class _StreamingResponseIterator(grpc.Call):
def __init__(self, wrapped):
self._wrapped = wrapped
def __iter__(self):
"""This iterator is also an iterable that returns itself."""
return self
def next(self):
"""Get the next response from the stream.
Returns:
protobuf.Message: A single response from the stream.
"""
try:
return six.next(self._wrapped)
except grpc.RpcError as exc:
six.raise_from(exceptions.from_grpc_error(exc), exc)
# Alias needed for Python 2/3 support.
__next__ = next
# grpc.Call & grpc.RpcContext interface
def add_callback(self, callback):
return self._wrapped.add_callback(callback)
def cancel(self):
return self._wrapped.cancel()
def code(self):
return self._wrapped.code()
def details(self):
return self._wrapped.details()
def initial_metadata(self):
return self._wrapped.initial_metadata()
def is_active(self):
return self._wrapped.is_active()
def time_remaining(self):
return self._wrapped.time_remaining()
def trailing_metadata(self):
return self._wrapped.trailing_metadata()
def _wrap_stream_errors(callable_):
"""Wrap errors for Unary-Stream and Stream-Stream gRPC callables.
The callables that return iterators require a bit more logic to re-map
errors when iterating. This wraps both the initial invocation and the
iterator of the return value to re-map errors.
"""
_patch_callable_name(callable_)
@general_helpers.wraps(callable_)
def error_remapped_callable(*args, **kwargs):
try:
result = callable_(*args, **kwargs)
return _StreamingResponseIterator(result)
except grpc.RpcError as exc:
six.raise_from(exceptions.from_grpc_error(exc), exc)
return error_remapped_callable
def wrap_errors(callable_):
"""Wrap a gRPC callable and map :class:`grpc.RpcErrors` to friendly error
classes.
Errors raised by the gRPC callable are mapped to the appropriate
:class:`bookkeeper.common.exceptions.BKGrpcAPICallError` subclasses.
The original `grpc.RpcError` (which is usually also a `grpc.Call`) is
available from the ``response`` property on the mapped exception. This
is useful for extracting metadata from the original error.
Args:
callable_ (Callable): A gRPC callable.
Returns:
Callable: The wrapped gRPC callable.
"""
if isinstance(callable_, _STREAM_WRAP_CLASSES):
return _wrap_stream_errors(callable_)
else:
return _wrap_unary_errors(callable_)
def create_channel(target,
**kwargs):
"""Create a secure channel with credentials.
Args:
target (str): The target service address in the format 'hostname:port'.
kwargs: Additional key-word args passed to
:func:`grpc_gcp.secure_channel` or :func:`grpc.secure_channel`.
Returns:
grpc.Channel: The created channel.
"""
return grpc.secure_channel(target, None, **kwargs)
_MethodCall = collections.namedtuple(
'_MethodCall', ('request', 'timeout', 'metadata', 'credentials'))
_ChannelRequest = collections.namedtuple(
'_ChannelRequest', ('method', 'request'))
class _CallableStub(object):
"""Stub for the grpc.*MultiCallable interfaces."""
def __init__(self, method, channel):
self._method = method
self._channel = channel
self.response = None
"""Union[protobuf.Message, Callable[protobuf.Message], exception]:
The response to give when invoking this callable. If this is a
callable, it will be invoked with the request protobuf. If it's an
exception, the exception will be raised when this is invoked.
"""
self.responses = None
"""Iterator[
Union[protobuf.Message, Callable[protobuf.Message], exception]]:
An iterator of responses. If specified, self.response will be populated
on each invocation by calling ``next(self.responses)``."""
self.requests = []
"""List[protobuf.Message]: All requests sent to this callable."""
self.calls = []
"""List[Tuple]: All invocations of this callable. Each tuple is the
request, timeout, metadata, and credentials."""
def __call__(self, request, timeout=None, metadata=None, credentials=None):
self._channel.requests.append(
_ChannelRequest(self._method, request))
self.calls.append(
_MethodCall(request, timeout, metadata, credentials))
self.requests.append(request)
response = self.response
if self.responses is not None:
if response is None:
response = next(self.responses)
else:
raise ValueError(
'{method}.response and {method}.responses are mutually '
'exclusive.'.format(method=self._method))
if callable(response):
return response(request)
if isinstance(response, Exception):
raise response
if response is not None:
return response
raise ValueError(
'Method stub for "{}" has no response.'.format(self._method))
def _simplify_method_name(method):
"""Simplifies a gRPC method name.
When gRPC invokes the channel to create a callable, it gives a full
method name like "/org.apache.bookkeeper.Table/Put". This
returns just the name of the method, in this case "Put".
Args:
method (str): The name of the method.
Returns:
str: The simplified name of the method.
"""
return method.rsplit('/', 1).pop()
class ChannelStub(grpc.Channel):
"""A testing stub for the grpc.Channel interface.
This can be used to test any client that eventually uses a gRPC channel
to communicate. By passing in a channel stub, you can configure which
responses are returned and track which requests are made.
For example:
.. code-block:: python
channel_stub = grpc_helpers.ChannelStub()
client = FooClient(channel=channel_stub)
channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
foo = client.get_foo(labels=['baz'])
assert foo.name == 'bar'
assert channel_stub.GetFoo.requests[0].labels = ['baz']
Each method on the stub can be accessed and configured on the channel.
Here's some examples of various configurations:
.. code-block:: python
# Return a basic response:
channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
assert client.get_foo().name == 'bar'
# Raise an exception:
channel_stub.GetFoo.response = NotFound('...')
with pytest.raises(NotFound):
client.get_foo()
# Use a sequence of responses:
channel_stub.GetFoo.responses = iter([
foo_pb2.Foo(name='bar'),
foo_pb2.Foo(name='baz'),
])
assert client.get_foo().name == 'bar'
assert client.get_foo().name == 'baz'
# Use a callable
def on_get_foo(request):
return foo_pb2.Foo(name='bar' + request.id)
channel_stub.GetFoo.response = on_get_foo
assert client.get_foo(id='123').name == 'bar123'
"""
def __init__(self, responses=[]):
self.requests = []
"""Sequence[Tuple[str, protobuf.Message]]: A list of all requests made
on this channel in order. The tuple is of method name, request
message."""
self._method_stubs = {}
def _stub_for_method(self, method):
method = _simplify_method_name(method)
self._method_stubs[method] = _CallableStub(method, self)
return self._method_stubs[method]
def __getattr__(self, key):
try:
return self._method_stubs[key]
except KeyError:
raise AttributeError
def unary_unary(
self, method,
request_serializer=None, response_deserializer=None):
"""grpc.Channel.unary_unary implementation."""
return self._stub_for_method(method)
def unary_stream(
self, method,
request_serializer=None, response_deserializer=None):
"""grpc.Channel.unary_stream implementation."""
return self._stub_for_method(method)
def stream_unary(
self, method,
request_serializer=None, response_deserializer=None):
"""grpc.Channel.stream_unary implementation."""
return self._stub_for_method(method)
def stream_stream(
self, method,
request_serializer=None, response_deserializer=None):
"""grpc.Channel.stream_stream implementation."""
return self._stub_for_method(method)
def subscribe(self, callback, try_to_connect=False):
"""grpc.Channel.subscribe implementation."""
pass
def unsubscribe(self, callback):
"""grpc.Channel.unsubscribe implementation."""
pass
def close(self):
"""grpc.Channel.close implementation."""
pass