#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#--------------------------------------------------------------------------

# pylint: disable=super-init-not-called,no-self-use

import asyncio
import datetime
import logging

from uamqp import c_uamqp, compat, constants, errors
from uamqp.async_ops import SessionAsync
from uamqp.constants import TransportType
from uamqp.async_ops.utils import get_dict_with_loop_if_needed

from .cbs_auth import CBSAuthMixin, SASTokenAuth, JWTTokenAuth, TokenRetryPolicy

from .common import _SASL


_logger = logging.getLogger(__name__)


def is_coroutine(get_token):
    """Validate that ``get_token`` is a coroutine function.

    Callables exposing a ``func`` attribute (for example ``functools.partial``
    objects) are judged by that wrapped callable; anything else is judged
    directly.

    :param get_token: The token callback to validate.
    :returns: True when the check passes.
    :raises ValueError: If the (possibly wrapped) callable is not a coroutine
     function.
    """
    missing = object()
    wrapped = getattr(get_token, "func", missing)
    candidate = get_token if wrapped is missing else wrapped
    if not asyncio.iscoroutinefunction(candidate):
        raise ValueError("get_token must be a coroutine function")
    return True


class CBSAsyncAuthMixin(CBSAuthMixin):
    """Mixin to handle sending and refreshing CBS auth tokens asynchronously.

    The methods below read these attributes from the concrete class:
    ``audience``, ``token_type``, ``token``, ``expires_at``, ``timeout``,
    ``retries``, ``_retry_policy``, ``_refresh_window`` and ``_prev_token``.
    They also rely on an awaitable ``update_token()``.
    """

    @property
    def loop(self):
        """The ``"loop"`` entry of the internal kwargs, or None if absent.

        ``_internal_kwargs`` is assigned in ``create_authenticator_async``.
        Reading this property before that call may raise AttributeError
        unless the attribute is set elsewhere.
        """
        return self._internal_kwargs.get("loop")

    async def create_authenticator_async(self, connection, debug=False, loop=None, **kwargs):
        """Create the async AMQP session and the CBS channel with which
        to negotiate the token.

        :param connection: The underlying AMQP connection on which
         to create the session.
        :type connection: ~uamqp.async_ops.connection_async.ConnectionAsync
        :param debug: Whether to emit network trace logging events for the
         CBS session. Default is `False`. Logging events are set at INFO level.
        :type debug: bool
        :param loop: Optional event loop. It is passed through
         ``get_dict_with_loop_if_needed``. The resulting kwargs are forwarded
         to ``SessionAsync`` and reused for the retry backoff sleep in
         ``handle_token_async``.
        :param kwargs: Extra keyword arguments forwarded to ``SessionAsync``.
        :raises: ~uamqp.errors.AMQPConnectionError if the native CBS auth object
         cannot be created. The session is destroyed before raising.
        :rtype: uamqp.c_uamqp.CBSTokenAuth
        """
        self._internal_kwargs = get_dict_with_loop_if_needed(loop)
        self._connection = connection
        kwargs.update(self._internal_kwargs)
        self._session = SessionAsync(connection, **kwargs)

        try:
            self._cbs_auth = c_uamqp.CBSTokenAuth(
                self.audience,
                self.token_type,
                self.token,
                int(self.expires_at),
                self._session._session,  # pylint: disable=protected-access
                self.timeout,
                self._connection.container_id,
                self._refresh_window
            )
            self._cbs_auth.set_trace(debug)
        except ValueError:
            # Don't leak the half-built session when the CBS channel fails.
            await self._session.destroy_async()
            raise errors.AMQPConnectionError(
                "Unable to open authentication session on connection {}.\n"
                "Please confirm target hostname exists: {}".format(
                    connection.container_id, connection.hostname)) from None
        return self._cbs_auth

    async def close_authenticator_async(self):
        """Close the CBS auth channel, then destroy its session asynchronously.

        The final log line is emitted even if destroying either object raises.
        """
        _logger.info("Shutting down CBS session on connection: %r.", self._connection.container_id)
        try:
            self._cbs_auth.destroy()
            _logger.info("Auth closed, destroying session on connection: %r.", self._connection.container_id)
            await self._session.destroy_async()
        finally:
            _logger.info("Finished shutting down CBS session on connection: %r.", self._connection.container_id)

    async def handle_token_async(self):
        """Check the status of the current token and act on it.

        This coroutine is called periodically. While holding the connection
        lock, it reads the native CBS auth status. Depending on that status it
        will do one of the following:

        - start authentication;
        - retry a failed put-token request, after sleeping for the retry
          policy's backoff;
        - refresh a token that will expire soon.

        Returns a tuple of two booleans. The first is True when token
        authentication has not completed within its timeout window. The second
        is True while token negotiation is still in progress.

        Two other return values are possible:

        - ``(False, False)`` when the connection is closing or has errored;
          the token is not inspected in that case.
        - ``(None, None)`` when an ``asyncio.TimeoutError`` occurs, for
          example while waiting for the connection lock.

        :raises: ~uamqp.errors.TokenAuthFailure if the put-token request keeps
         failing after the retry policy's retries are exhausted.
        :raises: ~uamqp.errors.AuthenticationException if the CBS link fails to
         open, or if a ValueError occurs while handling the token (including an
         unrecognised auth status).
        :raises: ~uamqp.errors.TokenExpired if the token has expired and cannot be
         refreshed.
        :rtype: tuple[bool, bool]
        """
        # pylint: disable=protected-access
        timeout = False
        in_progress = False
        # Only release the connection lock if this coroutine actually acquired
        # it. If the acquisition times out and we release anyway, we may free a
        # lock that another coroutine currently holds.
        lock_acquired = False
        try:
            await self._connection.lock_async()
            lock_acquired = True
            if self._connection._closing or self._connection._error:
                return timeout, in_progress
            auth_status = self._cbs_auth.get_status()
            # An unknown status value raises ValueError here. The handler
            # below converts it to AuthenticationException.
            auth_status = constants.CBSAuthStatus(auth_status)
            if auth_status == constants.CBSAuthStatus.Error:
                if self.retries >= self._retry_policy.retries:  # pylint: disable=no-member
                    _logger.warning("Authentication Put-Token failed. Retries exhausted.")
                    raise errors.TokenAuthFailure(*self._cbs_auth.get_failure_info())
                error_code, error_description = self._cbs_auth.get_failure_info()
                _logger.info("Authentication status: %r, description: %r", error_code, error_description)
                _logger.info("Authentication Put-Token failed. Retrying.")
                self.retries += 1  # pylint: disable=no-member
                # NOTE(review): this backoff sleep happens while the
                # connection lock is held.
                await asyncio.sleep(self._retry_policy.backoff, **self._internal_kwargs)
                self._cbs_auth.authenticate()
                in_progress = True
            elif auth_status == constants.CBSAuthStatus.Failure:
                raise errors.AuthenticationException("Failed to open CBS authentication link.")
            elif auth_status == constants.CBSAuthStatus.Expired:
                raise errors.TokenExpired("CBS Authentication Expired.")
            elif auth_status == constants.CBSAuthStatus.Timeout:
                timeout = True
            elif auth_status == constants.CBSAuthStatus.InProgress:
                in_progress = True
            elif auth_status == constants.CBSAuthStatus.RefreshRequired:
                _logger.info("Token on connection %r will expire soon - attempting to refresh.",
                             self._connection.container_id)
                await self.update_token()
                # Only push the token to the native layer when it actually
                # changed. Otherwise this branch is retried on a later call.
                if self.token != self._prev_token:
                    self._cbs_auth.refresh(self.token, int(self.expires_at))
                else:
                    _logger.info(
                        "The newly acquired token on connection %r is the same as the previous one,"
                        " will keep attempting to refresh",
                        self._connection.container_id
                    )
            elif auth_status == constants.CBSAuthStatus.Idle:
                self._cbs_auth.authenticate()
                in_progress = True
            elif auth_status != constants.CBSAuthStatus.Ok:
                raise ValueError("Invalid auth state.")
        except asyncio.TimeoutError:
            _logger.debug("CBS auth timed out while waiting for lock acquisition.")
            return None, None
        except ValueError as e:
            raise errors.AuthenticationException(
                "Token authentication failed: {}".format(e)) from e
        finally:
            if lock_acquired:
                self._connection.release_async()
        return timeout, in_progress


class SASTokenAsync(SASTokenAuth, CBSAsyncAuthMixin):
    """CBS authentication with Shared Access Signature (SAS) tokens, for use
    with the async client stack.

    :param audience: The audience field of the token; for SAS tokens this is
     typically the URI.
    :type audience: str or bytes
    :param uri: The AMQP endpoint URI, supplied as a decoded string.
    :type uri: str
    :param token: The SAS token.
    :type token: str or bytes.
    :param expires_in: Number of seconds remaining until the token expires.
    :type expires_in: ~datetime.timedelta
    :param expires_at: Expiry of the SAS token as a timestamp in seconds
     since epoch.
    :type expires_at: float
    :param username: The SAS token username, also called the key name or
     policy name. It may alternatively be encoded in the URI.
    :type username: str
    :param password: The SAS token password, also called the key. It may
     alternatively be encoded in the URI.
    :type password: str
    :param port: The TLS port; AMQP defaults to 5671.
    :type port: int
    :param timeout: Seconds allowed for negotiating the token. Defaults to
     10 seconds.
    :type timeout: float
    :param retry_policy: Retry policy applied to the PUT token request. The
     default policy retries 3 times.
    :type retry_policy: ~uamqp.authentication.cbs_auth.TokenRetryPolicy
    :param verify: Path to a user-defined certificate.
    :type verify: str
    :param token_type: The type field sent in the token request. Defaults to
     `b"servicebus.windows.net:sastoken"`.
    :type token_type: bytes
    :param http_proxy: HTTP proxy configuration as a dictionary. The keys
     'proxy_hostname' and 'proxy_port' are required; 'username' and 'password'
     are optional.
    :type http_proxy: dict
    :param transport_type: The transport protocol type; defaults to
     ~uamqp.TransportType.Amqp. ~uamqp.TransportType.AmqpOverWebsocket is used
     when http_proxy is set or when that transport type is explicitly requested.
    :type transport_type: ~uamqp.TransportType
    :param encoding: Encoding applied when the hostname is given as a str.
     Defaults to 'UTF-8'.
    :type encoding: str
    :keyword int refresh_window: Seconds before token expiry at which the
     refresh process begins. Defaults to 10% of the seconds remaining until
     the token expires.
    """
    async def update_token(self):  # pylint: disable=useless-super-delegation
        """Coroutine wrapper around the inherited synchronous ``update_token``.

        The inherited method is called without being awaited. Wrapping it
        here lets the async mixin ``await self.update_token()`` uniformly.
        """
        super().update_token()


class JWTTokenAsync(JWTTokenAuth, CBSAsyncAuthMixin):
    """Asynchronous CBS authentication using JWT tokens.

    :param audience: The token audience field. For JWT tokens
     this is usually the URI.
    :type audience: str or bytes
    :param uri: The AMQP endpoint URI. This must be provided as
     a decoded string.
    :type uri: str
    :param get_token: The coroutine function used for getting and refreshing
     tokens. It is awaited with no arguments and must return an object that
     exposes ``token`` (the JWT) and ``expires_on`` (its expiry, stored as
     ``expires_at``) attributes. An object whose ``func`` attribute is a
     coroutine function (e.g. a ``functools.partial``) is also accepted.
    :type get_token: coroutine function
    :param expires_in: The total remaining seconds until the token
     expires. Defaults to ``constants.AUTH_EXPIRATION_SECS`` seconds.
    :type expires_in: ~datetime.timedelta
    :param expires_at: The timestamp at which the JWT token will expire
     formatted as seconds since epoch.
    :type expires_at: float
    :param port: The TLS port - default for AMQP is 5671.
    :type port: int
    :param timeout: The timeout in seconds in which to negotiate the token.
     The default value is 10 seconds.
    :type timeout: float
    :param retry_policy: The retry policy for the PUT token request. The default
     retry policy has 3 retries.
    :type retry_policy: ~uamqp.authentication.cbs_auth.TokenRetryPolicy
    :param verify: The path to a user-defined certificate.
    :type verify: str
    :param token_type: The type field of the token request.
     Default value is `b"jwt"`.
    :type token_type: bytes
    :param http_proxy: HTTP proxy configuration. This should be a dictionary with
     the following keys present: 'proxy_hostname' and 'proxy_port'. Additional optional
     keys are 'username' and 'password'.
    :type http_proxy: dict
    :param transport_type: The transport protocol type - default is ~uamqp.TransportType.Amqp.
     ~uamqp.TransportType.AmqpOverWebsocket is applied when http_proxy is set or the
     transport type is explicitly requested.
    :type transport_type: ~uamqp.TransportType
    :param encoding: The encoding to use if hostname is provided as a str.
     Default is 'UTF-8'.
    :type encoding: str
    :keyword int refresh_window: The time in seconds before the token expiration
     time to start the process of token refresh. When omitted, 0 is passed to
     the native CBS layer, which presumably falls back to 10% of the remaining
     seconds until the token expires (TODO confirm).
    :keyword str custom_endpoint_hostname: If provided, used as the connection
     hostname instead of the hostname parsed from ``uri``.
    :raises ValueError: If ``get_token`` is not a coroutine function.
    """

    def __init__(self, audience, uri,
                 get_token,
                 expires_in=datetime.timedelta(seconds=constants.AUTH_EXPIRATION_SECS),
                 expires_at=None,
                 port=None,
                 timeout=10,
                 retry_policy=TokenRetryPolicy(),
                 verify=None,
                 token_type=b"jwt",
                 http_proxy=None,
                 transport_type=TransportType.Amqp,
                 encoding='UTF-8',
                 **kwargs):  # pylint: disable=no-member
        # NOTE(review): the default TokenRetryPolicy() is created once and
        # shared by every instance built without an explicit policy. This is
        # safe only while the policy object is never mutated.
        self._retry_policy = retry_policy
        self._encoding = encoding
        # 0 is forwarded to the native CBSTokenAuth when no window is given.
        self._refresh_window = kwargs.pop("refresh_window", 0)
        self._prev_token = None
        self.uri = uri
        parsed = compat.urlparse(uri)  # pylint: disable=no-member

        self.cert_file = verify
        # A custom endpoint hostname, when supplied, overrides the URI's host.
        self.hostname = (kwargs.get("custom_endpoint_hostname") or parsed.hostname).encode(self._encoding)

        # Fail fast: raises ValueError unless get_token is a coroutine function.
        is_coroutine(get_token)

        self.get_token = get_token
        self.audience = self._encode(audience)
        self.token_type = self._encode(token_type)
        # Populated by update_token(), which create_authenticator_async awaits
        # before opening the CBS channel.
        self.token = None
        self.expires_at, self.expires_in = self._set_expiry(expires_at, expires_in)
        self.timeout = timeout
        self.retries = 0
        self.sasl = _SASL()
        self.set_io(self.hostname, port, http_proxy, transport_type)

    async def create_authenticator_async(self, connection, debug=False, loop=None, **kwargs):
        """Fetch an initial token, then open the async CBS session and channel.

        :param connection: The underlying AMQP connection on which to create
         the session.
        :param debug: Whether to emit network trace logging for the CBS session.
        :type debug: bool
        :param loop: Optional event loop, forwarded to the mixin implementation.
        :param kwargs: Extra keyword arguments forwarded to the session.
        :rtype: uamqp.c_uamqp.CBSTokenAuth
        """
        await self.update_token()
        # Forward ``loop`` explicitly so that a caller-supplied event loop
        # reaches the session and the retry backoff sleeps.
        return await super(JWTTokenAsync, self).create_authenticator_async(
            connection, debug=debug, loop=loop, **kwargs)

    async def update_token(self):
        """Await ``get_token`` and store the returned token and its expiry.

        The previous encoded token is kept in ``_prev_token``. This lets a
        refresh detect when ``get_token`` returned the same token again.
        """
        access_token = await self.get_token()
        self.expires_at = access_token.expires_on
        self._prev_token = self.token
        self.token = self._encode(access_token.token)
