#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import functools
import inspect
import random
import re
import sys
import typing
from dataclasses import dataclass, field
from functools import wraps
from itertools import islice
from uuid import uuid4

from fb303_asyncio.FacebookBase import FacebookBase
from fbnet.command_runner_asyncio.CommandRunner import constants, ttypes
from fbnet.command_runner_asyncio.CommandRunner.Command import Iface as FcrIface

from .command_session import CommandSession, CapturedTimeMS
from .counters import Counters
from .exceptions import ensure_thrift_exception, convert_to_fcr_exception
from .global_namespace import GlobalNamespace
from .options import Option
from .utils import input_fields_validator


@dataclass(frozen=True)
class DeviceResult:
    """
    Class for passing information received after sending commands to device
    Includes CommandResult(s), captured communication/processing time, and any raised exceptions
    (remains None for exceptions that are returned instead)
    """

    # Either a list of per-command results (the shape _run_commands builds)
    # or a single optional CommandResult.
    device_response: typing.Union[
        typing.List[typing.Optional[ttypes.CommandResult]],
        typing.Optional[ttypes.CommandResult],
    ] = None
    # A fresh CapturedTimeMS per instance: a single class-level default object
    # would be shared (aliased) by every DeviceResult created without one, and
    # dataclasses reject unhashable instance defaults on Python 3.11+.
    captured_time_ms: CapturedTimeMS = field(default_factory=CapturedTimeMS)
    # Exception for the caller to raise; None when there was no error or when
    # the error was converted into a CommandResult instead.
    exception: typing.Optional[Exception] = None


def _append_debug_info_to_exception(fn):
    @wraps(fn)
    async def wrapper(self, *args, **kwargs):
        try:
            return await fn(self, *args, **kwargs)
        except Exception as ex:
            # Retrieve request uuid from the global namespace
            uuid = GlobalNamespace.get_request_uuid()
            # Exception defined in command runner thrift spec has attribute 'message'
            if hasattr(ex, "message"):
                ex.message = await self.add_debug_info_to_error_message(  # noqa
                    error_msg=ex.message, uuid=uuid  # noqa
                )
                raise ex
            else:
                # Don't pass in ex.message to debug info to avoid duplicate message
                debug_info = await self.add_debug_info_to_error_message(
                    error_msg="", uuid=uuid
                )
                # Append message as new arg instead of constructing new exception
                # to account for exceptions having different required args
                ex.args = ex.args + (debug_info,)
                raise ex.with_traceback(sys.exc_info()[2])

    return wrapper


def _ensure_uuid(fn):
    """Make sure the 'uuid' parameter for both input and return is non-empty"""

    @wraps(fn)
    async def wrapper(*args, **kwargs):
        uuid = ""
        callargs = inspect.getcallargs(fn, *args, **kwargs)
        if "uuid" in callargs:
            uuid = callargs["uuid"] or uuid4().hex[:8]
            callargs["uuid"] = uuid

        # Note: this won't work for functions that specify positional-only or
        # kwarg-only parameters.
        result = await fn(**callargs)

        # Set UUID on the resulting struct -- if it supports it: Map of, or
        # raw, CommandResult and Session
        if isinstance(result, (ttypes.CommandResult, ttypes.Session)):
            result.uuid = uuid
        elif isinstance(result, dict):
            for val in result.values():
                # If size of result is large, we can refactor to only check the
                # first element. (time taken: ~100ns * N)
                if not isinstance(val, ttypes.CommandResult):
                    break
                val.uuid = uuid

        return result

    return wrapper


class CommandHandler(Counters, FacebookBase, FcrIface):
    """
    Command implementation for api defined in thrift for command runner

    Single-device calls (run, open/run/close session and their raw variants)
    execute on this instance. bulk_run executes locally while the request is
    small and the instance is not loaded; otherwise it splits the request into
    chunks and sends them to other instances through their bulk_run_local()
    api (see _bulk_run_remote).
    """

    _COUNTER_PREFIX = "fbnet.command_runner"
    _LB_THRESHOLD = 100

    REMOTE_CALL_OVERHEAD = Option(
        "--remote_call_overhead",
        help="Overhead for running commands remotely (for bulk calls)",
        type=int,
        default=20,
    )

    LB_THRESHOLD = Option(
        "--lb_threshold",
        help="""Load Balance threshold for bulk_run calls. If number of
        devices is greater than this threshold, the requests are broken and
        sent to other instances using bulk_run_local() api""",
        type=int,
        default=100,
    )

    BULK_SESSION_LIMIT = Option(
        "--bulk_session_limit",
        help="""session limit above which we reject the bulk run local
        calls""",
        type=int,
        default=200,
    )

    BULK_RETRY_LIMIT = Option(
        "--bulk_retry_limit",
        help="""number of times to retry bulk call on the remote instances""",
        type=int,
        default=5,
    )

    BULK_RUN_JITTER = Option(
        "--bulk_run_jitter",
        help="""A random delay added for bulk commands to stagger the calls to
        distribute the load.""",
        type=int,
        default=5,
    )

    BULK_RETRY_DELAY_MIN = Option(
        "--bulk_retry_delay_min",
        help="""number of seconds to wait before retrying""",
        type=int,
        default=5,
    )

    BULK_RETRY_DELAY_MAX = Option(
        "--bulk_retry_delay_max",
        help="""number of seconds to wait before retrying""",
        type=int,
        default=10,
    )

    # Number of device sessions currently held by _bulk_run_local. Stored on
    # the class (see _set_bulk_session_count), not per instance.
    _bulk_session_count = 0

    def __init__(self, service, name=None):
        Counters.__init__(self, service, name)
        FacebookBase.__init__(self, service.app_name)
        FcrIface.__init__(self)

        self.service.register_stats_mgr(self)

    def cleanup(self):
        pass

    @classmethod
    def register_counters(cls, stats_mgr):
        """Register every counter this handler increments."""
        stats_mgr.register_counter("bulk_run.remote")
        stats_mgr.register_counter("bulk_run.remote.overload_error")
        stats_mgr.register_counter("bulk_run.local")
        stats_mgr.register_counter("bulk_run.local.overload_error")

    def getCounters(self):
        """
        Return all counters keyed with the _COUNTER_PREFIX namespace; callable
        counter values are evaluated at read time.
        """
        ret = {}
        for key, value in self.counters.items():
            if not key.startswith(self._COUNTER_PREFIX):
                key = self._COUNTER_PREFIX + "." + key
            ret[key] = value() if callable(value) else value
        return ret

    @classmethod
    def _set_bulk_session_count(cls, new_count: int) -> None:
        """Method to set the class variable _bulk_session_count"""
        cls._bulk_session_count = new_count

    async def add_debug_info_to_error_message(self, uuid, error_msg=None):
        """Return error_msg (if any) followed by a thrift_uuid debug suffix."""
        if not error_msg:
            return f"(DebugInfo: thrift_uuid={uuid})"
        return f"{error_msg} (DebugInfo: thrift_uuid={uuid})"

    @ensure_thrift_exception
    @input_fields_validator
    @_append_debug_info_to_exception
    @_ensure_uuid
    async def run(
        self, command, device, timeout, open_timeout, client_ip, client_port, uuid
    ) -> ttypes.CommandResult:
        """
        Run a single command on a device and return its CommandResult.

        Raises the exception captured by _run_commands, if any.
        """
        result = await self._run_commands(
            [command], device, timeout, open_timeout, client_ip, client_port, uuid
        )

        GlobalNamespace.set_api_captured_time_ms(
            GlobalNamespace.get_api_captured_time_ms() + result.captured_time_ms
        )

        # Raise exception if needed here since _run_commands does not raise
        # and exceptions are not returned
        if isinstance(result.exception, Exception):
            raise result.exception

        cmd_result = result.device_response
        if isinstance(cmd_result, list):
            cmd_result = result.device_response[0]  # pyre-ignore checked is list

        return cmd_result

    def _bulk_failure(self, device_to_commands, message):
        """
        Build a bulk_run-shaped result where every command of every device
        failed with `message` as output and FAILURE_STATUS as status.
        """

        def command_failures(cmds):
            return [
                ttypes.CommandResult(
                    output=message, status=constants.FAILURE_STATUS, command=cmd
                )
                for cmd in cmds
            ]

        return {
            self._get_result_key(dev): command_failures(cmds)
            for dev, cmds in device_to_commands.items()
        }

    @ensure_thrift_exception
    @input_fields_validator
    @_append_debug_info_to_exception
    @_ensure_uuid
    async def bulk_run_v2(
        self, request: ttypes.BulkRunCommandRequest
    ) -> ttypes.BulkRunCommandResponse:
        """Struct-based wrapper around bulk_run."""
        device_to_commands = {}
        for device_commands in request.device_commands_list:
            device_details = device_commands.device
            # NOTE(review): relies on the Device struct being hashable, as
            # bulk_run's device_to_commands map already does.
            device_to_commands[device_details] = device_commands.commands

        result = await self.bulk_run(
            device_to_commands,
            request.timeout,
            request.open_timeout,
            request.client_ip,
            request.client_port,
            request.uuid,
        )
        response = ttypes.BulkRunCommandResponse()
        response.device_to_result = result
        return response

    @ensure_thrift_exception
    @input_fields_validator
    @_append_debug_info_to_exception
    @_ensure_uuid
    async def bulk_run(
        self, device_to_commands, timeout, open_timeout, client_ip, client_port, uuid
    ) -> typing.Dict[str, typing.List[ttypes.CommandResult]]:
        """
        Run commands on many devices, returning {result_key: [CommandResult]}.

        Requests with fewer than LB_THRESHOLD devices run locally while the
        bulk session count is below BULK_SESSION_LIMIT. Otherwise the request
        is split into chunks of LB_THRESHOLD devices, each sent through
        _bulk_run_remote. A chunk is retried at most BULK_RETRY_LIMIT times
        (after a random BULK_RETRY_DELAY_MIN..MAX second delay) when
        _remote_task_should_retry allows it. Chunks that still fail are
        reported via _bulk_failure instead of raising.
        """
        if (len(device_to_commands) < self.LB_THRESHOLD) and (
            self._bulk_session_count < self.BULK_SESSION_LIMIT
        ):
            # Run these command locally.
            self.incrementCounter("bulk_run.local")
            return await self._bulk_run_local(
                device_to_commands, timeout, open_timeout, client_ip, client_port, uuid
            )

        async def _remote_task(chunk):
            # Run the chunk of commands on remote instance
            self.incrementCounter("bulk_run.remote")
            retry_count = 0
            while True:
                try:
                    return await self._bulk_run_remote(
                        chunk, timeout, open_timeout, client_ip, client_port, uuid
                    )
                except Exception as e:
                    if isinstance(e, ttypes.InstanceOverloaded):
                        # Instance we ran the call on was overloaded. We can retry
                        # the command again, hopefully on a different instance
                        self.incrementCounter("bulk_run.remote.overload_error")
                        self.logger.error("Instance Overloaded: %d: %s", retry_count, e)

                    # Strict '<' so exactly BULK_RETRY_LIMIT retries happen
                    # (retry_count counts retries already performed).
                    if (
                        self._remote_task_should_retry(ex=e)
                        and retry_count < self.BULK_RETRY_LIMIT
                    ):
                        # Stagger the retries
                        delay = random.uniform(
                            self.BULK_RETRY_DELAY_MIN, self.BULK_RETRY_DELAY_MAX
                        )
                        await asyncio.sleep(delay)
                        retry_count += 1
                    else:
                        # Append message as new arg instead of constructing new exception
                        # to account for exceptions having different required args
                        e.args = e.args + ("bulk_run_remote failed",)
                        return self._bulk_failure(chunk, str(e))

        # Split the request into chunks and run them on remote hosts
        tasks = [
            _remote_task(chunk)
            for chunk in self._chunked_dict(device_to_commands, self.LB_THRESHOLD)
        ]

        all_results = {}
        # No 'loop=' argument: it was removed from asyncio in Python 3.10 and
        # the running loop is used anyway.
        for task in asyncio.as_completed(tasks):
            result = await task
            all_results.update(result)

        return all_results

    @ensure_thrift_exception
    @_ensure_uuid
    async def bulk_run_local(
        self, device_to_commands, timeout, open_timeout, client_ip, client_port, uuid
    ):
        """Thrift entry point used by other instances to run a chunk locally."""
        return await self._bulk_run_local(
            device_to_commands, timeout, open_timeout, client_ip, client_port, uuid
        )

    @_ensure_uuid
    async def _bulk_run_local(
        self, device_to_commands, timeout, open_timeout, client_ip, client_port, uuid
    ) -> typing.Dict[str, typing.List[ttypes.CommandResult]]:
        """
        Run every device's commands concurrently on this instance.

        Raises ttypes.InstanceOverloaded when the new devices would push the
        bulk session count above BULK_SESSION_LIMIT. Per-device errors are
        reported as CommandResults rather than raised.
        """
        devices = sorted(device_to_commands.keys(), key=lambda d: d.hostname)

        session_count = self._bulk_session_count
        if session_count + len(device_to_commands) > self.BULK_SESSION_LIMIT:
            self.incrementCounter("bulk_run.local.overload_error")
            self.logger.error("Too many session open: %d", session_count)
            raise ttypes.InstanceOverloaded(
                message="Too many session open: %d" % session_count
            )

        self._set_bulk_session_count(self._bulk_session_count + len(devices))

        async def _run_one_device(device) -> DeviceResult:
            if not device_to_commands[device]:
                return DeviceResult(
                    device_response=[], captured_time_ms=CapturedTimeMS()
                )
            # Instead of running all commands at once, stagger the commands to
            # distribute the load
            delay = random.uniform(0, self.BULK_RUN_JITTER)
            await asyncio.sleep(delay)
            return await self._run_commands(
                device_to_commands[device],
                device,
                timeout,
                open_timeout,
                client_ip,
                client_port,
                uuid,
                return_exceptions=True,
            )

        captured_time_ms_list = []
        try:
            commands = [_run_one_device(device) for device in devices]

            # Run commands in parallel. With return_exceptions=True a task that
            # raises yields its exception object in place of a DeviceResult.
            raw_results: typing.List[
                typing.Union[DeviceResult, BaseException]
            ] = await asyncio.gather(*commands, return_exceptions=True)

            # Convert raised exceptions into failed results so one device
            # cannot abort the whole bulk call.
            results: typing.List[DeviceResult] = [
                self._exception_to_device_result(device, device_to_commands[device], res)
                if isinstance(res, BaseException)
                else res
                for device, res in zip(devices, raw_results)
            ]

            # pyre-ignore: _run_commands always returns list of CommandResults
            # when return_exceptions is True
            cmd_results: typing.List[typing.List[ttypes.CommandResult]] = [
                result.device_response for result in results
            ]

            # Get the captured communication/processing time to save later in finally block
            captured_time_ms_list = [result.captured_time_ms for result in results]
        finally:
            self._set_bulk_session_count(self._bulk_session_count - len(devices))

            GlobalNamespace.set_api_captured_time_ms(
                functools.reduce(
                    CapturedTimeMS.__add__,
                    [GlobalNamespace.get_api_captured_time_ms()]
                    + captured_time_ms_list,
                )
            )

        return {
            self._get_result_key(dev): res for dev, res in zip(devices, cmd_results)
        }

    def _exception_to_device_result(self, device, commands, exc) -> DeviceResult:
        """
        Turn an exception raised by a _bulk_run_local per-device task into a
        failed DeviceResult. The result mirrors what _run_commands builds with
        return_exceptions=True: empty output, with repr(exc) as the status.
        """
        self.logger.error(
            "bulk_run_local failed for device %s: %r", device.hostname, exc
        )
        return DeviceResult(
            device_response=[
                ttypes.CommandResult(
                    output="",
                    status="%r" % (exc,),
                    command=commands[0] if commands else "",
                )
            ],
            captured_time_ms=CapturedTimeMS(),
        )

    @ensure_thrift_exception
    @input_fields_validator
    @_append_debug_info_to_exception
    @_ensure_uuid
    async def open_session(
        self, device, open_timeout, idle_timeout, client_ip, client_port, uuid
    ):
        """Open a (non-raw) session on a device and return its ttypes.Session."""
        return await self._open_session(
            device,
            open_timeout,
            idle_timeout,
            client_ip,
            client_port,
            uuid,
            raw_session=False,
        )

    @ensure_thrift_exception
    @input_fields_validator
    @_append_debug_info_to_exception
    @_ensure_uuid
    async def run_session(
        self, session, command, timeout, client_ip, client_port, uuid
    ):
        """Run one command on a previously opened session."""
        return await self._run_session(
            session, command, timeout, client_ip, client_port, uuid
        )

    @ensure_thrift_exception
    @input_fields_validator
    @_append_debug_info_to_exception
    @_ensure_uuid
    async def close_session(self, session, client_ip, client_port, uuid):
        """Close a previously opened session; returns None."""
        closed_session = None
        try:
            closed_session = CommandSession.get(session.id, client_ip, client_port)
            # Reset captured time field so we don't include
            # the time from a previous API call
            closed_session.captured_time_ms.reset_time()
            await closed_session.close()
        except Exception as e:
            # Append message as new arg instead of constructing new exception
            # to account for exceptions having different required args
            e.args = e.args + ("close_session failed",)
            raise e
        finally:
            if closed_session:
                GlobalNamespace.set_api_captured_time_ms(
                    GlobalNamespace.get_api_captured_time_ms()
                    + closed_session.captured_time_ms
                )

    @ensure_thrift_exception
    @_append_debug_info_to_exception
    @_ensure_uuid
    async def open_raw_session(
        self, device, open_timeout, idle_timeout, client_ip, client_port, uuid
    ):
        """Open a raw session (raw_session=True) on a device."""
        return await self._open_session(
            device,
            open_timeout,
            idle_timeout,
            client_ip,
            client_port,
            uuid,
            raw_session=True,
        )

    @ensure_thrift_exception
    @_append_debug_info_to_exception
    @_ensure_uuid
    async def run_raw_session(
        self, tsession, command, timeout, prompt_regex, client_ip, client_port, uuid
    ):
        """
        Run a command on a raw session, using prompt_regex (compiled as a
        multiline bytes pattern) to detect the end of output.

        Raises ttypes.SessionException if prompt_regex is empty.
        """
        if not prompt_regex:
            raise ttypes.SessionException(message="prompt_regex not specified")

        prompt_re = re.compile(prompt_regex.encode("utf8"), re.M)

        return await self._run_session(
            tsession, command, timeout, client_ip, client_port, uuid, prompt_re
        )

    @ensure_thrift_exception
    @_append_debug_info_to_exception
    @_ensure_uuid
    async def close_raw_session(self, tsession, client_ip, client_port, uuid):
        """Close a raw session (delegates to close_session)."""
        return await self.close_session(tsession, client_ip, client_port, uuid)

    async def _open_session(
        self,
        device,
        open_timeout,
        idle_timeout,
        client_ip,
        client_port,
        uuid,
        raw_session=False,
    ):
        """
        Look up the device, set up a session and return a ttypes.Session
        describing it. Session timing is added to the API captured time.
        """
        options = self._get_options(
            device,
            client_ip,
            client_port,
            open_timeout,
            idle_timeout,
            raw_session=raw_session,
        )

        session = None
        try:
            devinfo = await self._lookup_device(device)
            session = await devinfo.setup_session(
                self.service, device, options, loop=self.loop
            )

            return ttypes.Session(
                id=session.id, name=session.hostname, hostname=device.hostname
            )
        except Exception as e:
            # Append message as new arg instead of constructing new exception
            # to account for exceptions having different required args
            e.args = e.args + ("open_session failed",)
            raise e
        finally:
            if session:
                GlobalNamespace.set_api_captured_time_ms(
                    GlobalNamespace.get_api_captured_time_ms()
                    + session.captured_time_ms
                )

    async def _run_session(
        self, tsession, command, timeout, client_ip, client_port, uuid, prompt_re=None
    ):
        """
        Run one command on an existing session identified by tsession.id and
        add the session's captured time to the API captured time.
        """
        session = None
        captured_time_ms = CapturedTimeMS()
        try:
            session = CommandSession.get(tsession.id, client_ip, client_port)
            # Reset captured time field so we don't include
            # the time from a previous API call
            session.captured_time_ms.reset_time()
            return await self._run_command(session, command, timeout, uuid, prompt_re)
        except Exception as e:
            # Append message as new arg instead of constructing new exception
            # to account for exceptions having different required args
            e.args = e.args + ("run_session failed",)
            raise e
        finally:
            if session:
                captured_time_ms = session.captured_time_ms

            GlobalNamespace.set_api_captured_time_ms(
                GlobalNamespace.get_api_captured_time_ms() + captured_time_ms
            )

    def _get_result_key(self, device):
        # TODO: just returning the hostname for now. Some additional processing
        # may be required e.g. using shortnames, adding console info, etc
        return device.hostname

    async def _run_command(self, session, command, timeout, uuid, prompt_re=None):
        """
        Send one command (utf8-encoded) on `session` and build its result.
        Undecodable output bytes are dropped. The status is the session's
        exit_status, or SUCCESS_STATUS when that is falsy.
        """
        self.logger.info(f"[request_id={uuid}]: Run command with session {session.id}")
        output = await session.run_command(
            command.encode("utf8"), timeout=timeout, prompt_re=prompt_re
        )
        return session.build_result(
            output=output.decode("utf8", errors="ignore"),
            status=session.exit_status or constants.SUCCESS_STATUS,
            command=command,
        )

    async def _run_commands(
        self,
        commands,
        device,
        timeout,
        open_timeout,
        client_ip,
        client_port,
        uuid,
        return_exceptions=False,
    ) -> DeviceResult:
        """
        Open a session on `device` and run `commands` sequentially.

        This never raises for errors inside the session. With
        return_exceptions=True an error becomes a single CommandResult whose
        status is the repr of a SessionException. Otherwise the error is
        stored in DeviceResult.exception for the caller to raise.
        """
        if not commands:
            # Nothing to run: don't open a session (and avoid commands[0]).
            return DeviceResult(device_response=[], captured_time_ms=CapturedTimeMS())

        options = self._get_options(
            device, client_ip, client_port, open_timeout, timeout
        )

        # NOTE(review): the base _get_options already sets command_prompts;
        # kept in case a subclass override does not.
        if device.command_prompts:
            options["command_prompts"] = {
                c.encode(): p.encode() for c, p in device.command_prompts.items()
            }

        # Tracks the command in flight so errors can be attributed to it.
        command = commands[0]
        devinfo = None
        session = None
        results = []
        captured_time_ms = CapturedTimeMS()
        exc = None
        try:
            devinfo = await self._lookup_device(device)

            async with devinfo.create_session(
                self.service, device, options, loop=self.loop
            ) as session:

                for command in commands:
                    result = await self._run_command(session, command, timeout, uuid)
                    results.append(result)

        except Exception as e:
            await self._record_error(e, command, uuid, options, devinfo, session)
            if return_exceptions:
                if not isinstance(e, ttypes.SessionException):
                    e = convert_to_fcr_exception(e)
                    e = ttypes.SessionException(message=f"{e!s}", code=e._CODE)

                e.message = await self.add_debug_info_to_error_message(  # noqa
                    error_msg=e.message, uuid=uuid  # noqa
                )
                results = [
                    ttypes.CommandResult(output="", status="%r" % e, command=command)
                ]
            else:
                # save exc from the original place so we have full stacktrace
                # when we raise in API level later
                exc = e

        if session:
            captured_time_ms = session.captured_time_ms

        return DeviceResult(
            device_response=results,
            captured_time_ms=captured_time_ms,
            exception=exc,
        )

    def _chunked_dict(self, data, chunk_size):
        """split the dict into smaller dicts of at most chunk_size items"""
        items = iter(data.items())  # get an iterator for items
        for _ in range(0, len(data), chunk_size):
            yield dict(islice(items, chunk_size))

    async def _bulk_run_remote(
        self, device_to_commands, timeout, open_timeout, client_ip, client_port, uuid
    ):
        """
        Send a chunk to another instance's bulk_run_local().

        The client call timeout is open_timeout + timeout. The remote side
        gets timeout - REMOTE_CALL_OVERHEAD.

        Raises ValueError when that remote timeout would be 10 or less.
        """

        # Determine a timeout for remote call.
        call_timeout = open_timeout + timeout
        remote_timeout = timeout - self.REMOTE_CALL_OVERHEAD

        # Make sure we have a sane timeout value. Explicit raise (not assert)
        # so the check still runs under `python -O`.
        if remote_timeout <= 10:
            raise ValueError("timeout: '%d' value too low for bulk_run" % timeout)

        fcr_client = await self._get_fcr_client(timeout=call_timeout)
        async with fcr_client as client:
            result = await client.bulk_run_local(
                device_to_commands,
                remote_timeout,
                open_timeout,
                client_ip,
                client_port,
                uuid,
            )
            return result

    async def _lookup_device(self, device):
        return await self.service.device_db.get(device)

    async def _get_fcr_client(self, timeout):
        return await self.service.get_fcr_client(timeout=timeout)

    def _get_options(
        self,
        device,
        client_ip,
        client_port,
        open_timeout,
        idle_timeout,
        raw_session=False,
    ):
        """Build the session options dict for `device`."""
        options = {
            "username": self._get_device_username(device),
            "password": self._get_device_password(device),
            "console": device.console,
            "command_prompts": {},
            "client_ip": client_ip,
            "client_port": client_port,
            "mgmt_ip": device.mgmt_ip or False,
            "open_timeout": open_timeout,
            "idle_timeout": idle_timeout,
            "ip_address": device.ip_address,
            "session_type": device.session_type,
            "device": device,
            "raw_session": raw_session,
            "clear_command": device.clear_command,
        }

        if device.command_prompts:
            options["command_prompts"] = {
                c.encode(): p.encode() for c, p in device.command_prompts.items()
            }

        return options

    def _get_device_username(self, device):
        return device.username

    def _get_device_password(self, device):
        """Return the decrypted password, or None when no username is set."""
        # If the username is specified then password must also be specified.
        if device.username:
            return self._decrypt(device.password)
        return None

    def _decrypt(self, data):
        """Identity in the base class; subclasses may override."""
        return data

    async def _record_error(
        self, error, command, uuid, options, devinfo, session, **kwargs
    ):
        """
        Subclass can override this method to export the interested error messages
        to proper data store
        """
        pass

    def _remote_task_should_retry(self, ex: Exception) -> bool:
        """
        The function that decides whether a remote task (fan out request) should or should not
        retry. The base implementation retries only on ttypes.InstanceOverloaded.
        """

        if not ex:
            return False

        if isinstance(ex, ttypes.InstanceOverloaded):
            return True

        return False