uamqp/authentication/cbs_auth_async.py (151 lines of code) (raw):

#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#--------------------------------------------------------------------------
# pylint: disable=super-init-not-called,no-self-use

import asyncio
import datetime
import logging

from uamqp import c_uamqp, compat, constants, errors
from uamqp.async_ops import SessionAsync
from uamqp.constants import TransportType
from uamqp.async_ops.utils import get_dict_with_loop_if_needed

from .cbs_auth import CBSAuthMixin, SASTokenAuth, JWTTokenAuth, TokenRetryPolicy
from .common import _SASL


_logger = logging.getLogger(__name__)


def is_coroutine(get_token):
    """Validate that a token callback is a coroutine function.

    If ``get_token`` exposes a ``func`` attribute (as ``functools.partial``
    objects do), that wrapped function is checked instead of ``get_token``
    itself.

    :param get_token: The token callback to validate.
    :returns: True when the callback is a coroutine function.
    :rtype: bool
    :raises: ValueError if the callback is not a coroutine function.
    """
    try:
        if asyncio.iscoroutinefunction(get_token.func):
            return True
    except AttributeError:
        if asyncio.iscoroutinefunction(get_token):
            return True
    raise ValueError("get_token must be a coroutine function")


class CBSAsyncAuthMixin(CBSAuthMixin):
    """Mixin to handle sending and refreshing CBS auth tokens asynchronously."""

    @property
    def loop(self):
        """The event loop stored in the internal kwargs, or None if none was stored.

        ``_internal_kwargs`` is assigned in :meth:`create_authenticator_async`,
        so this property presumably is only usable after that has run.
        """
        return self._internal_kwargs.get("loop")

    async def create_authenticator_async(self, connection, debug=False, loop=None, **kwargs):
        """Create the async AMQP session and the CBS channel with which to negotiate the token.

        :param connection: The underlying AMQP connection on which to create the session.
        :type connection: ~uamqp.async_ops.connection_async.ConnectionAsync
        :param debug: Whether to emit network trace logging events for the
         CBS session. Default is `False`. Logging events are set at INFO level.
        :type debug: bool
        :param loop: Optional event loop. It is turned into internal kwargs via
         ``get_dict_with_loop_if_needed``, and those kwargs are passed to the
         session and to the retry backoff sleep.
        :rtype: uamqp.c_uamqp.CBSTokenAuth
        :raises: ~uamqp.errors.AMQPConnectionError if the CBS token auth channel
         cannot be created on the session.
        """
        self._internal_kwargs = get_dict_with_loop_if_needed(loop)
        self._connection = connection
        kwargs.update(self._internal_kwargs)
        self._session = SessionAsync(connection, **kwargs)
        try:
            self._cbs_auth = c_uamqp.CBSTokenAuth(
                self.audience,
                self.token_type,
                self.token,
                int(self.expires_at),
                self._session._session,  # pylint: disable=protected-access
                self.timeout,
                self._connection.container_id,
                self._refresh_window
            )
            self._cbs_auth.set_trace(debug)
        except ValueError:
            # Don't leak the freshly created session when the CBS channel fails.
            await self._session.destroy_async()
            raise errors.AMQPConnectionError(
                "Unable to open authentication session on connection {}.\n"
                "Please confirm target hostname exists: {}".format(
                    connection.container_id, connection.hostname)) from None
        return self._cbs_auth

    async def close_authenticator_async(self):
        """Close the CBS auth channel and session asynchronously."""
        _logger.info("Shutting down CBS session on connection: %r.", self._connection.container_id)
        try:
            self._cbs_auth.destroy()
            _logger.info("Auth closed, destroying session on connection: %r.", self._connection.container_id)
            await self._session.destroy_async()
        finally:
            _logger.info("Finished shutting down CBS session on connection: %r.", self._connection.container_id)

    async def handle_token_async(self):
        """Check the status of the current token and request a new one if needed.

        This coroutine is called periodically. If the token request fails, it
        is retried according to the retry policy. If the token will expire
        soon, a refresh is attempted.

        This function returns a tuple of two booleans. The first represents
        whether the token authentication has not completed within its given
        timeout window. The second indicates whether the token negotiation is
        still in progress. If the connection lock (or another awaited step)
        times out, ``(None, None)`` is returned instead.

        :raises: ~uamqp.errors.AuthenticationException if the token authentication fails.
        :raises: ~uamqp.errors.TokenExpired if the token has expired and cannot be refreshed.
        :rtype: tuple[bool, bool]
        """
        # pylint: disable=protected-access
        timeout = False
        in_progress = False
        lock_acquired = False
        try:
            await self._connection.lock_async()
            # Track acquisition so ``finally`` only releases a lock this
            # coroutine actually holds. With an ownerless lock (e.g. asyncio.Lock),
            # releasing after a timed-out acquisition could unlock a lock held
            # by another task.
            lock_acquired = True
            if self._connection._closing or self._connection._error:
                return timeout, in_progress
            auth_status = self._cbs_auth.get_status()
            # An unrecognised status value raises ValueError here, which is
            # converted to AuthenticationException below.
            auth_status = constants.CBSAuthStatus(auth_status)
            if auth_status == constants.CBSAuthStatus.Error:
                if self.retries >= self._retry_policy.retries:  # pylint: disable=no-member
                    _logger.warning("Authentication Put-Token failed. Retries exhausted.")
                    raise errors.TokenAuthFailure(*self._cbs_auth.get_failure_info())
                error_code, error_description = self._cbs_auth.get_failure_info()
                _logger.info("Authentication status: %r, description: %r", error_code, error_description)
                _logger.info("Authentication Put-Token failed. Retrying.")
                self.retries += 1  # pylint: disable=no-member
                # NOTE(review): the backoff sleep happens while the connection lock is held.
                await asyncio.sleep(self._retry_policy.backoff, **self._internal_kwargs)
                self._cbs_auth.authenticate()
                in_progress = True
            elif auth_status == constants.CBSAuthStatus.Failure:
                raise errors.AuthenticationException("Failed to open CBS authentication link.")
            elif auth_status == constants.CBSAuthStatus.Expired:
                raise errors.TokenExpired("CBS Authentication Expired.")
            elif auth_status == constants.CBSAuthStatus.Timeout:
                timeout = True
            elif auth_status == constants.CBSAuthStatus.InProgress:
                in_progress = True
            elif auth_status == constants.CBSAuthStatus.RefreshRequired:
                _logger.info("Token on connection %r will expire soon - attempting to refresh.",
                             self._connection.container_id)
                await self.update_token()
                # Only push the token to the CBS channel if it actually changed;
                # otherwise keep polling until a new token is produced.
                if self.token != self._prev_token:
                    self._cbs_auth.refresh(self.token, int(self.expires_at))
                else:
                    _logger.info(
                        "The newly acquired token on connection %r is the same as the previous one,"
                        " will keep attempting to refresh",
                        self._connection.container_id
                    )
            elif auth_status == constants.CBSAuthStatus.Idle:
                self._cbs_auth.authenticate()
                in_progress = True
            elif auth_status != constants.CBSAuthStatus.Ok:
                raise ValueError("Invalid auth state.")
        except asyncio.TimeoutError:
            _logger.debug("CBS auth timed out while waiting for lock acquisition.")
            return None, None
        except ValueError as e:
            raise errors.AuthenticationException(
                "Token authentication failed: {}".format(e)) from e
        finally:
            if lock_acquired:
                self._connection.release_async()
        return timeout, in_progress


class SASTokenAsync(SASTokenAuth, CBSAsyncAuthMixin):
    """Asynchronous CBS authentication using SAS tokens.

    :param audience: The token audience field. For SAS tokens
     this is usually the URI.
    :type audience: str or bytes
    :param uri: The AMQP endpoint URI. This must be provided as
     a decoded string.
    :type uri: str
    :param token: The SAS token.
    :type token: str or bytes.
    :param expires_in: The total remaining seconds until the token
     expires.
    :type expires_in: ~datetime.timedelta
    :param expires_at: The timestamp at which the SAS token will expire
     formatted as seconds since epoch.
    :type expires_at: float
    :param username: The SAS token username, also referred to as the key
     name or policy name. This can optionally be encoded into the URI.
    :type username: str
    :param password: The SAS token password, also referred to as the key.
     This can optionally be encoded into the URI.
    :type password: str
    :param port: The TLS port - default for AMQP is 5671.
    :type port: int
    :param timeout: The timeout in seconds in which to negotiate the token.
     The default value is 10 seconds.
    :type timeout: float
    :param retry_policy: The retry policy for the PUT token request. The default
     retry policy has 3 retries.
    :type retry_policy: ~uamqp.authentication.cbs_auth.TokenRetryPolicy
    :param verify: The path to a user-defined certificate.
    :type verify: str
    :param token_type: The type field of the token request.
     Default value is `b"servicebus.windows.net:sastoken"`.
    :type token_type: bytes
    :param http_proxy: HTTP proxy configuration. This should be a dictionary with
     the following keys present: 'proxy_hostname' and 'proxy_port'. Additional optional
     keys are 'username' and 'password'.
    :type http_proxy: dict
    :param transport_type: The transport protocol type - default is ~uamqp.TransportType.Amqp.
     ~uamqp.TransportType.AmqpOverWebsocket is applied when http_proxy is set or the
     transport type is explicitly requested.
    :type transport_type: ~uamqp.TransportType
    :param encoding: The encoding to use if hostname is provided as a str.
     Default is 'UTF-8'.
    :type encoding: str
    :keyword int refresh_window: The time in seconds before the token expiration
     time to start the process of token refresh.
     Default value is 10% of the remaining seconds until the token expires.
    """

    async def update_token(self):  # pylint: disable=useless-super-delegation
        """Refresh the SAS token.

        Wraps the synchronous ``SASTokenAuth.update_token`` in a coroutine so
        that ``handle_token_async`` can ``await self.update_token()`` for both
        SAS and JWT credentials.
        """
        super(SASTokenAsync, self).update_token()


class JWTTokenAsync(JWTTokenAuth, CBSAsyncAuthMixin):
    """Asynchronous CBS authentication using JWT tokens.

    :param audience: The token audience field. For JWT tokens
     this is usually the URI.
    :type audience: str or bytes
    :param uri: The AMQP endpoint URI. This must be provided as
     a decoded string.
    :type uri: str
    :param get_token: The callback function used for getting and refreshing
     tokens. It should return a valid jwt token each time it is called.
    :type get_token: coroutine function
    :param expires_in: The total remaining seconds until the token
     expires - default for JWT token generated by AAD is 3600s (1 hour).
    :type expires_in: ~datetime.timedelta
    :param expires_at: The timestamp at which the JWT token will expire
     formatted as seconds since epoch.
    :type expires_at: float
    :param port: The TLS port - default for AMQP is 5671.
    :type port: int
    :param timeout: The timeout in seconds in which to negotiate the token.
     The default value is 10 seconds.
    :type timeout: float
    :param retry_policy: The retry policy for the PUT token request. The default
     retry policy has 3 retries.
    :type retry_policy: ~uamqp.authentication.cbs_auth.TokenRetryPolicy
    :param verify: The path to a user-defined certificate.
    :type verify: str
    :param token_type: The type field of the token request.
     Default value is `b"jwt"`.
    :type token_type: bytes
    :param http_proxy: HTTP proxy configuration. This should be a dictionary with
     the following keys present: 'proxy_hostname' and 'proxy_port'. Additional optional
     keys are 'username' and 'password'.
    :type http_proxy: dict
    :param transport_type: The transport protocol type - default is ~uamqp.TransportType.Amqp.
     ~uamqp.TransportType.AmqpOverWebsocket is applied when http_proxy is set or the
     transport type is explicitly requested.
    :type transport_type: ~uamqp.TransportType
    :param encoding: The encoding to use if hostname is provided as a str.
     Default is 'UTF-8'.
    :type encoding: str
    :keyword int refresh_window: The time in seconds before the token expiration
     time to start the process of token refresh.
     Default value is 10% of the remaining seconds until the token expires.
    :keyword str custom_endpoint_hostname: If given, used as the connection
     hostname instead of the hostname parsed from `uri`.
    """

    def __init__(self, audience, uri,
                 get_token,
                 expires_in=datetime.timedelta(seconds=constants.AUTH_EXPIRATION_SECS),
                 expires_at=None,
                 port=None,
                 timeout=10,
                 retry_policy=TokenRetryPolicy(),
                 verify=None,
                 token_type=b"jwt",
                 http_proxy=None,
                 transport_type=TransportType.Amqp,
                 encoding='UTF-8',
                 **kwargs):  # pylint: disable=no-member
        self._retry_policy = retry_policy
        self._encoding = encoding
        # NOTE(review): 0 presumably lets the C layer apply its default
        # (10% of the remaining lifetime, per the class docstring) - confirm.
        self._refresh_window = kwargs.pop("refresh_window", 0)
        # Previous token, compared against the new one on refresh so that an
        # unchanged token is not pushed to the CBS channel.
        self._prev_token = None
        self.uri = uri
        parsed = compat.urlparse(uri)  # pylint: disable=no-member

        self.cert_file = verify
        self.hostname = (kwargs.get("custom_endpoint_hostname") or parsed.hostname).encode(self._encoding)

        # Raises ValueError unless get_token is a coroutine function.
        is_coroutine(get_token)

        self.get_token = get_token
        self.audience = self._encode(audience)
        self.token_type = self._encode(token_type)
        # No token yet; the first one is fetched in create_authenticator_async.
        self.token = None
        self.expires_at, self.expires_in = self._set_expiry(expires_at, expires_in)
        self.timeout = timeout
        self.retries = 0
        self.sasl = _SASL()
        self.set_io(self.hostname, port, http_proxy, transport_type)

    async def create_authenticator_async(self, connection, debug=False, loop=None, **kwargs):
        """Fetch an initial token, then create the CBS session and channel.

        :param connection: The underlying AMQP connection on which to create the session.
        :type connection: ~uamqp.async_ops.connection_async.ConnectionAsync
        :param debug: Whether to emit network trace logging events for the CBS session.
        :type debug: bool
        :param loop: Optional event loop, forwarded to the base implementation.
        :rtype: uamqp.c_uamqp.CBSTokenAuth
        """
        await self.update_token()
        # Forward ``loop`` explicitly; previously it was silently dropped.
        return await super(JWTTokenAsync, self).create_authenticator_async(
            connection, debug=debug, loop=loop, **kwargs)

    async def update_token(self):
        """Await ``get_token`` and store the returned token and its expiry.

        The object returned by ``get_token`` must expose ``token`` and
        ``expires_on`` attributes. The current token is saved to
        ``_prev_token`` first, so a refresh can tell whether the token changed.
        """
        access_token = await self.get_token()
        self.expires_at = access_token.expires_on
        self._prev_token = self.token
        self.token = self._encode(access_token.token)