#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#--------------------------------------------------------------------------

# pylint: disable=super-init-not-called,no-self-use

import datetime
import logging
import time

from uamqp import Session, c_uamqp, compat, constants, errors, utils
from uamqp.constants import TransportType

from .common import _SASL, AMQPAuth

_logger = logging.getLogger(__name__)


class TokenRetryPolicy(object):
    """Retry policy for sending authentication tokens
    for CBS authentication.

    :param retries: The number of retry attempts for a failed
     PUT token request. The default is 3. This is exclusive of
     the initial attempt.
    :type retries: int
    :param backoff: The time in miliseconds to wait between
     retry attempts.
    :type backoff: int
    """

    def __init__(self, retries=3, backoff=0):
        self.retries = retries
        self.backoff = float(backoff)/1000


class CBSAuthMixin(object):
    """Mixin to handle sending and refreshing CBS auth tokens."""

    def update_token(self):
        """Update a token that is about to expire. This is specific
        to a particular token type, and therefore must be implemented
        in a child class.
        """
        raise errors.TokenExpired(
            "Unable to refresh token - no refresh logic implemented.")

    def create_authenticator(self, connection, debug=False, **kwargs):
        """Create the AMQP session and the CBS channel with which
        to negotiate the token.

        :param connection: The underlying AMQP connection on which
         to create the session.
        :type connection: ~uamqp.connection.Connection
        :param debug: Whether to emit network trace logging events for the
         CBS session. Default is `False`. Logging events are set at INFO level.
        :type debug: bool
        :rtype: uamqp.c_uamqp.CBSTokenAuth
        """
        self._connection = connection
        self._session = Session(connection, **kwargs)

        try:
            self._cbs_auth = c_uamqp.CBSTokenAuth(
                self.audience,
                self.token_type,
                self.token,
                int(self.expires_at),
                self._session._session,  # pylint: disable=protected-access
                self.timeout,
                self._connection.container_id,
                self._refresh_window
            )
            self._cbs_auth.set_trace(debug)
        except ValueError:
            self._session.destroy()
            raise errors.AMQPConnectionError(
                "Unable to open authentication session on connection {}.\n"
                "Please confirm target hostname exists: {}".format(connection.container_id, connection.hostname))
        return self._cbs_auth

    def close_authenticator(self):
        """Close the CBS auth channel and session."""
        _logger.info("Shutting down CBS session on connection: %r.", self._connection.container_id)
        try:
            _logger.debug("Unlocked CBS to close on connection: %r.", self._connection.container_id)
            self._cbs_auth.destroy()
            _logger.info("Auth closed, destroying session on connection: %r.", self._connection.container_id)
            self._session.destroy()
        finally:
            _logger.info("Finished shutting down CBS session on connection: %r.", self._connection.container_id)

    def handle_token(self):
        """This function is called periodically to check the status of the current
        token if there is one, and request a new one if needed.
        If the token request fails, it will be retried according to the retry policy.
        A token refresh will be attempted if the token will expire soon.

        This function will return a tuple of two booleans. The first represents whether
        the token authentication has not completed within it's given timeout window. The
        second indicates whether the token negotiation is still in progress.

        :raises: ~uamqp.errors.AuthenticationException if the token authentication fails.
        :raises: ~uamqp.errors.TokenExpired if the token has expired and cannot be
         refreshed.
        :rtype: tuple[bool, bool]
        """
        # pylint: disable=protected-access
        timeout = False
        in_progress = False
        try:
            self._connection.lock()
            if self._connection._closing or self._connection._error:
                return timeout, in_progress
            auth_status = self._cbs_auth.get_status()
            auth_status = constants.CBSAuthStatus(auth_status)
            if auth_status == constants.CBSAuthStatus.Error:
                if self.retries >= self._retry_policy.retries:  # pylint: disable=no-member
                    _logger.warning("Authentication Put-Token failed. Retries exhausted.")
                    raise errors.TokenAuthFailure(*self._cbs_auth.get_failure_info())
                error_code, error_description = self._cbs_auth.get_failure_info()
                _logger.info("Authentication status: %r, description: %r", error_code, error_description)
                _logger.info("Authentication Put-Token failed. Retrying.")
                self.retries += 1  # pylint: disable=no-member
                time.sleep(self._retry_policy.backoff)
                self._cbs_auth.authenticate()
                in_progress = True
            elif auth_status == constants.CBSAuthStatus.Failure:
                raise errors.AuthenticationException("Failed to open CBS authentication link.")
            elif auth_status == constants.CBSAuthStatus.Expired:
                raise errors.TokenExpired("CBS Authentication Expired.")
            elif auth_status == constants.CBSAuthStatus.Timeout:
                timeout = True
            elif auth_status == constants.CBSAuthStatus.InProgress:
                in_progress = True
            elif auth_status == constants.CBSAuthStatus.RefreshRequired:
                _logger.info("Token on connection %r will expire soon - attempting to refresh.",
                             self._connection.container_id)
                self.update_token()
                if self.token != self._prev_token:
                    self._cbs_auth.refresh(self.token, int(self.expires_at))
                else:
                    _logger.info(
                        "The newly acquired token on connection %r is the same as the previous one,"
                        " will keep attempting to refresh",
                        self._connection.container_id
                    )
            elif auth_status == constants.CBSAuthStatus.Idle:
                self._cbs_auth.authenticate()
                in_progress = True
            elif auth_status != constants.CBSAuthStatus.Ok:
                raise errors.AuthenticationException("Invalid auth state.")
        except compat.TimeoutException:
            _logger.debug("CBS auth timed out while waiting for lock acquisition.")
            return None, None
        except ValueError as e:
            raise errors.AuthenticationException(
                "Token authentication failed: {}".format(e))
        finally:
            self._connection.release()
        return timeout, in_progress

    def _set_expiry(self, expires_at, expires_in):
        if not expires_at and not expires_in:
            raise ValueError("Must specify either 'expires_at' or 'expires_in'.")
        if not expires_at:
            expires_at = time.time() + expires_in.seconds
        else:
            expires_in_seconds = expires_at - time.time()
            if expires_in_seconds < 1:
                raise ValueError("Token has already expired.")
            expires_in = datetime.timedelta(seconds=expires_in_seconds)
        return expires_at, expires_in


class SASTokenAuth(AMQPAuth, CBSAuthMixin):
    """CBS authentication using SAS tokens.

    :param audience: The token audience field. For SAS tokens
     this is usually the URI.
    :type audience: str or bytes
    :param uri: The AMQP endpoint URI. This must be provided as
     a decoded string.
    :type uri: str
    :param token: The SAS token.
    :type token: str or bytes.
    :param expires_in: The total remaining seconds until the token
     expires.
    :type expires_in: ~datetime.timedelta
    :param expires_at: The timestamp at which the SAS token will expire
     formatted as seconds since epoch.
    :type expires_at: float
    :param username: The SAS token username, also referred to as the key
     name or policy name. This can optionally be encoded into the URI.
    :type username: str
    :param password: The SAS token password, also referred to as the key.
     This can optionally be encoded into the URI.
    :type password: str
    :param port: The TLS port - default for AMQP is 5671.
    :type port: int
    :param timeout: The timeout in seconds in which to negotiate the token.
     The default value is 10 seconds.
    :type timeout: float
    :param retry_policy: The retry policy for the PUT token request. The default
     retry policy has 3 retries.
    :type retry_policy: ~uamqp.authentication.cbs_auth.TokenRetryPolicy
    :param verify: The path to a user-defined certificate.
    :type verify: str
    :param token_type: The type field of the token request.
     Default value is `b"servicebus.windows.net:sastoken"`.
    :type token_type: bytes
    :param http_proxy: HTTP proxy configuration. This should be a dictionary with
     the following keys present: 'proxy_hostname' and 'proxy_port'. Additional optional
     keys are 'username' and 'password'.
    :type http_proxy: dict
    :param transport_type: The transport protocol type - default is ~uamqp.TransportType.Amqp.
     ~uamqp.TransportType.AmqpOverWebsocket is applied when http_proxy is set or the
     transport type is explicitly requested.
    :type transport_type: ~uamqp.TransportType
    :param encoding: The encoding to use if hostname is provided as a str.
     Default is 'UTF-8'.
    :type encoding: str
    :keyword int refresh_window: The time in seconds before the token expiration
     time to start the process of token refresh.
     Default value is 10% of the remaining seconds until the token expires.
    """

    def __init__(self, audience, uri, token,
                 expires_in=None,
                 expires_at=None,
                 username=None,
                 password=None,
                 port=None,
                 timeout=10,
                 retry_policy=TokenRetryPolicy(),
                 verify=None,
                 token_type=b"servicebus.windows.net:sastoken",
                 http_proxy=None,
                 transport_type=TransportType.Amqp,
                 encoding='UTF-8',
                 **kwargs):  # pylint: disable=no-member
        self._retry_policy = retry_policy
        self._encoding = encoding
        self._refresh_window = kwargs.pop("refresh_window", 0)
        self._prev_token = None
        self.uri = uri
        parsed = compat.urlparse(uri)  # pylint: disable=no-member

        self.cert_file = verify
        self.hostname = (kwargs.get("custom_endpoint_hostname") or parsed.hostname).encode(self._encoding)
        self.username = compat.unquote_plus(parsed.username) if parsed.username else None  # pylint: disable=no-member
        self.password = compat.unquote_plus(parsed.password) if parsed.password else None  # pylint: disable=no-member

        self.username = username or self.username
        self.password = password or self.password
        self.audience = self._encode(audience)
        self.token_type = self._encode(token_type)
        self.token = self._encode(token)
        self.expires_at, self.expires_in = self._set_expiry(expires_at, expires_in)
        self.timeout = timeout
        self.retries = 0
        self.sasl = _SASL()
        self.set_io(self.hostname, port, http_proxy, transport_type)

    def update_token(self):
        """If a username and password are present - attempt to use them to
        request a fresh SAS token.
        """
        if not self.username or not self.password:
            raise errors.TokenExpired("Unable to refresh token - no username or password.")
        encoded_uri = compat.quote_plus(self.uri).encode(self._encoding)  # pylint: disable=no-member
        encoded_key = compat.quote_plus(self.username).encode(self._encoding)  # pylint: disable=no-member
        self.expires_at = time.time() + self.expires_in.seconds
        self._prev_token = self.token
        self.token = utils.create_sas_token(
            encoded_key,
            self.password.encode(self._encoding),
            encoded_uri,
            self.expires_in)

    @classmethod
    def from_shared_access_key(
            cls,
            uri,
            key_name,
            shared_access_key,
            expiry=None,
            port=None,
            timeout=10,
            retry_policy=TokenRetryPolicy(),
            verify=None,
            http_proxy=None,
            transport_type=TransportType.Amqp,
            encoding='UTF-8',
            **kwargs):
        """Attempt to create a CBS token session using a Shared Access Key such
        as is used to connect to Azure services.

        :param uri: The AMQP endpoint URI. This must be provided as
         a decoded string.
        :type uri: str
        :param key_name: The SAS token username, also referred to as the key
         name or policy name.
        :type key_name: str
        :param shared_access_key: The SAS token password, also referred to as the key.
        :type shared_access_key: str
        :param expiry: The lifetime in seconds for the generated token. Default is 1 hour.
        :type expiry: int
        :param port: The TLS port - default for AMQP is 5671.
        :type port: int
        :param timeout: The timeout in seconds in which to negotiate the token.
         The default value is 10 seconds.
        :type timeout: float
        :param retry_policy: The retry policy for the PUT token request. The default
         retry policy has 3 retries.
        :type retry_policy: ~uamqp.authentication.cbs_auth.TokenRetryPolicy
        :param verify: The path to a user-defined certificate.
        :type verify: str
        :param http_proxy: HTTP proxy configuration. This should be a dictionary with
         the following keys present: 'proxy_hostname' and 'proxy_port'. Additional optional
         keys are 'username' and 'password'.
        :type http_proxy: dict
        :param transport_type: The transport protocol type - default is ~uamqp.TransportType.Amqp.
         ~uamqp.TransportType.AmqpOverWebsocket is applied when http_proxy is set or the
         transport type is explicitly requested.
        :type transport_type: ~uamqp.TransportType
        :param encoding: The encoding to use if hostname is provided as a str.
         Default is 'UTF-8'.
        :type encoding: str
        :keyword int refresh_window: The time in seconds before the token expiration
        time to start the process of token refresh.
        Default value is 10% of the remaining seconds until the token expires.
        """
        expires_in = datetime.timedelta(seconds=expiry or constants.AUTH_EXPIRATION_SECS)
        encoded_uri = compat.quote_plus(uri).encode(encoding)  # pylint: disable=no-member
        encoded_key = compat.quote_plus(key_name).encode(encoding)  # pylint: disable=no-member
        expires_at = time.time() + expires_in.seconds
        token = utils.create_sas_token(
            encoded_key,
            shared_access_key.encode(encoding),
            encoded_uri,
            expires_in)
        return cls(
            uri, uri, token,
            expires_in=expires_in,
            expires_at=expires_at,
            username=key_name,
            password=shared_access_key,
            port=port,
            timeout=timeout,
            retry_policy=retry_policy,
            verify=verify,
            http_proxy=http_proxy,
            transport_type=transport_type,
            encoding=encoding,
            custom_endpoint_hostname=kwargs.pop("custom_endpoint_hostname", None))


class JWTTokenAuth(AMQPAuth, CBSAuthMixin):
    """CBS authentication using JWT tokens.

    :param audience: The token audience field. For JWT tokens
     this is usually the URI.
    :type audience: str or bytes
    :param uri: The AMQP endpoint URI. This must be provided as
     a decoded string.
    :type uri: str
    :param get_token: The callback function used for getting and refreshing
     tokens. It should return a valid jwt token each time it is called.
    :type get_token: callable object
    :param expires_in: The total remaining seconds until the token
     expires - default for JWT token generated by AAD is 3600s (1 hour).
    :type expires_in: ~datetime.timedelta
    :param expires_at: The timestamp at which the JWT token will expire
     formatted as seconds since epoch.
    :type expires_at: float
    :param port: The TLS port - default for AMQP is 5671.
    :type port: int
    :param timeout: The timeout in seconds in which to negotiate the token.
     The default value is 10 seconds.
    :type timeout: float
    :param retry_policy: The retry policy for the PUT token request. The default
     retry policy has 3 retries.
    :type retry_policy: ~uamqp.authentication.cbs_auth.TokenRetryPolicy
    :param verify: The path to a user-defined certificate.
    :type verify: str
    :param token_type: The type field of the token request.
     Default value is `b"jwt"`.
    :type token_type: bytes
    :param http_proxy: HTTP proxy configuration. This should be a dictionary with
     the following keys present: 'proxy_hostname' and 'proxy_port'. Additional optional
     keys are 'username' and 'password'.
    :type http_proxy: dict
    :param transport_type: The transport protocol type - default is ~uamqp.TransportType.Amqp.
     ~uamqp.TransportType.AmqpOverWebsocket is applied when http_proxy is set or the
     transport type is explicitly requested.
    :type transport_type: ~uamqp.TransportType
    :param encoding: The encoding to use if hostname is provided as a str.
     Default is 'UTF-8'.
    :type encoding: str
    :keyword int refresh_window: The time in seconds before the token expiration
     time to start the process of token refresh.
     Default value is 10% of the remaining seconds until the token expires.
    """

    def __init__(self, audience, uri,
                 get_token,
                 expires_in=datetime.timedelta(seconds=constants.AUTH_EXPIRATION_SECS),
                 expires_at=None,
                 port=None,
                 timeout=10,
                 retry_policy=TokenRetryPolicy(),
                 verify=None,
                 token_type=b"jwt",
                 http_proxy=None,
                 transport_type=TransportType.Amqp,
                 encoding='UTF-8',
                 **kwargs):  # pylint: disable=no-member
        self._retry_policy = retry_policy
        self._encoding = encoding
        self._refresh_window = kwargs.pop("refresh_window", 0)
        self._prev_token = None
        self.uri = uri
        parsed = compat.urlparse(uri)  # pylint: disable=no-member

        self.cert_file = verify
        self.hostname = (kwargs.get("custom_endpoint_hostname") or parsed.hostname).encode(self._encoding)

        if not get_token or not callable(get_token):
            raise ValueError("get_token must be a callable object.")

        self.get_token = get_token
        self.audience = self._encode(audience)
        self.token_type = self._encode(token_type)
        self.token = None
        self.expires_at, self.expires_in = self._set_expiry(expires_at, expires_in)
        self.timeout = timeout
        self.retries = 0
        self.sasl = _SASL()
        self.set_io(self.hostname, port, http_proxy, transport_type)

    def create_authenticator(self, connection, debug=False, **kwargs):
        self.update_token()
        return super(JWTTokenAuth, self).create_authenticator(connection, debug, **kwargs)

    def update_token(self):
        access_token = self.get_token()
        self.expires_at = access_token.expires_on
        self._prev_token = self.token
        self.token = self._encode(access_token.token)
