# coding=utf-8
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

from typing import List
from azure.cli.core.azclierror import (
    AzureResponseError,
    ClientRequestError,
    CLIInternalError,
    RequiredArgumentMissingError,
    ResourceNotFoundError,
)
from knack.log import get_logger
from azext_iot.central.models.devicetwin import DeviceTwin
from azext_iot.central.models.edge import EdgeModule
from azext_iot.constants import CENTRAL_ENDPOINT
from azext_iot.central import services as central_services
from azext_iot.central.models.enum import DeviceStatus, ApiVersion
from azext_iot.central.models.ga_2022_07_31 import (DeviceGa, RelationshipGa)
from azext_iot.dps.services import global_service as dps_global_service


# Module-level logger obtained from knack, named after this module.
logger = get_logger(__name__)
# Model name constant for this provider module; not referenced within this file.
MODEL = "Device"


class CentralDeviceProvider:
    """
    Provider for IoT Central device APIs.

    Thin wrapper around ``azext_iot.central.services.device`` that injects the
    app id, API version and optional token given at construction time. It keeps
    small per-instance caches (keyed by device id) of devices, device credentials
    and registration info to avoid repeating requests within one command.
    """

    def __init__(self, cmd, app_id: str, api_version: str, token=None):
        """
        Provider for device APIs

        Args:
            cmd: command passed into az
            app_id: name of app (used for forming request URL)
            api_version: API version (appended to request URL)
            token: (OPTIONAL) authorization token to fetch device details from IoTC.
                MUST INCLUDE type (e.g. 'SharedAccessToken ...', 'Bearer ...')
                Useful in scenarios where user doesn't own the app
                therefore AAD token won't work, but a SAS token generated by owner will
        """
        self._cmd = cmd
        self._app_id = app_id
        self._api_version = api_version
        self._token = token
        # Per-instance caches, keyed by device id.
        self._devices = {}
        self._device_templates = {}
        self._device_credentials = {}
        self._device_registration_info = {}

    def get_device(
        self,
        device_id,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> DeviceGa:
        """
        Get a device, served from the local cache when possible.

        Args:
            device_id: id of the device to fetch.
            central_dns_suffix: DNS suffix of the IoT Central endpoint.

        Returns:
            The device returned by the device service (or the cached copy).

        Raises:
            ResourceNotFoundError: if the service returns no device.
        """
        device = self._devices.get(device_id)
        if not device:
            device = central_services.device.get_device(
                cmd=self._cmd,
                app_id=self._app_id,
                device_id=device_id,
                token=self._token,
                central_dns_suffix=central_dns_suffix,
                api_version=self._api_version,
            )
            if not device:
                raise ResourceNotFoundError(
                    "No device found with id: '{}'.".format(device_id)
                )
            # Only cache real results: caching a miss would leave the id in
            # self._devices and make create_device report "already exists".
            self._devices[device_id] = device

        return device

    def list_devices(
        self,
        filter=None,  # name shadows the builtin but is part of the public signature
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> List[DeviceGa]:
        """
        List devices in the app and add each of them to the device cache.

        Args:
            filter: optional filter forwarded unchanged to the device service.
            central_dns_suffix: DNS suffix of the IoT Central endpoint.

        Returns:
            The devices returned by the device service.
        """
        devices = central_services.device.list_devices(
            cmd=self._cmd,
            app_id=self._app_id,
            token=self._token,
            filter=filter,
            central_dns_suffix=central_dns_suffix,
            api_version=self._api_version,
        )

        # add to cache
        self._devices.update({device.id: device for device in devices})

        return devices

    def create_device(
        self,
        device_id,
        device_name=None,
        template=None,
        simulated=False,
        organizations=None,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> DeviceGa:
        """
        Create a device and cache it.

        Args:
            device_id: id of the new device (required).
            device_name: optional display name.
            template: optional device template id.
            simulated: whether the device is simulated.
            organizations: optional organizations forwarded to the service.
            central_dns_suffix: DNS suffix of the IoT Central endpoint.

        Returns:
            The created device.

        Raises:
            RequiredArgumentMissingError: if device_id is empty.
            ClientRequestError: if this provider already has the device cached.
            AzureResponseError: if the service returns no device.
        """
        if not device_id:
            raise RequiredArgumentMissingError("Device id must be specified.")

        if device_id in self._devices:
            raise ClientRequestError("Device already exists.")

        device = central_services.device.create_device(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            device_name=device_name,
            template=template,
            simulated=simulated,
            organizations=organizations,
            token=self._token,
            central_dns_suffix=central_dns_suffix,
            api_version=self._api_version,
        )

        if not device:
            raise AzureResponseError(
                "Failed to create device with id: '{}'.".format(device_id)
            )

        # add to cache
        self._devices[device.id] = device

        return self._devices[device.id]

    def update_device(
        self,
        device_id,
        device_name=None,
        template=None,
        simulated=None,
        enabled=None,
        organizations=None,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> DeviceGa:
        """
        Update an existing device and refresh the caches for it.

        Args:
            device_id: id of the device to update (required).
            device_name: optional new display name.
            template: optional new device template id.
            simulated: optional new simulated flag.
            enabled: optional new enabled flag.
            organizations: optional organizations forwarded to the service.
            central_dns_suffix: DNS suffix of the IoT Central endpoint.

        Returns:
            The updated device.

        Raises:
            RequiredArgumentMissingError: if device_id is empty.
            ResourceNotFoundError: if the service returns no device.
        """
        if not device_id:
            raise RequiredArgumentMissingError("Device id must be specified.")

        # NOTE: no "already exists" guard here. Updating a device that this
        # provider has already fetched (and cached) is the normal case.
        device = central_services.device.update_device(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            device_name=device_name,
            template=template,
            simulated=simulated,
            enabled=enabled,
            organizations=organizations,
            token=self._token,
            central_dns_suffix=central_dns_suffix,
            api_version=self._api_version,
        )

        if not device:
            raise ResourceNotFoundError(
                "No device found with id: '{}'.".format(device_id)
            )

        # Refresh the device cache; registration info was derived from the old
        # device state, so drop it and let it be recomputed on demand.
        self._devices[device.id] = device
        self._device_registration_info.pop(device_id, None)

        return self._devices[device.id]

    def delete_device(
        self,
        device_id,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> dict:
        """
        Delete a device and evict everything cached for it.

        Args:
            device_id: id of the device to delete (required).
            central_dns_suffix: DNS suffix of the IoT Central endpoint.

        Returns:
            Whatever the device service's delete call returns.

        Raises:
            RequiredArgumentMissingError: if device_id is empty.
        """
        if not device_id:
            raise RequiredArgumentMissingError("Device id must be specified.")

        result = central_services.device.delete_device(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            token=self._token,
            central_dns_suffix=central_dns_suffix,
            api_version=self._api_version,
        )

        # remove from cache (default None so a cache miss does not raise KeyError)
        self._devices.pop(device_id, None)
        self._device_credentials.pop(device_id, None)
        self._device_registration_info.pop(device_id, None)

        return result

    def list_relationships(
        self,
        device_id,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> List[RelationshipGa]:
        """
        List relationships of a device; returns [] when the service returns None.
        """
        relationships = central_services.device.list_relationships(
            self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

        if relationships is None:
            return []

        return relationships

    def add_relationship(
        self,
        device_id,
        target_id,
        rel_id,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> dict:
        """
        Create relationship ``rel_id`` from ``device_id`` to ``target_id``.

        Raises:
            ResourceNotFoundError: if the service returns no relationship.
        """
        relationship = central_services.device.create_relationship(
            self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            rel_id=rel_id,
            target_id=target_id,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

        if not relationship:
            raise ResourceNotFoundError(
                "No relationship found with id: '{}'.".format(rel_id)
            )

        return relationship

    def update_relationship(
        self,
        device_id,
        target_id,
        rel_id,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> dict:
        """
        Update relationship ``rel_id`` of ``device_id`` to point at ``target_id``.

        Raises:
            ResourceNotFoundError: if the service returns no relationship.
        """
        relationship = central_services.device.update_relationship(
            self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            rel_id=rel_id,
            target_id=target_id,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

        if not relationship:
            raise ResourceNotFoundError(
                "No relationship found with id: '{}'.".format(rel_id)
            )

        return relationship

    def delete_relationship(
        self,
        device_id,
        rel_id,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> dict:
        """
        Delete relationship ``rel_id`` of ``device_id``; returns the service result.
        """
        return central_services.device.delete_relationship(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            rel_id=rel_id,
            token=self._token,
            central_dns_suffix=central_dns_suffix,
            api_version=self._api_version,
        )

    def get_device_credentials(
        self,
        device_id,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> dict:
        """
        Get device credentials, served from the local cache when possible.

        Raises:
            CLIInternalError: if the service returns no credentials.
        """
        credentials = self._device_credentials.get(device_id)

        if not credentials:
            credentials = central_services.device.get_device_credentials(
                cmd=self._cmd,
                app_id=self._app_id,
                device_id=device_id,
                token=self._token,
                central_dns_suffix=central_dns_suffix,
                api_version=self._api_version,
            )

        if not credentials:
            raise CLIInternalError(
                "Could not find device credentials for device '{}'.".format(device_id)
            )

        # add to cache
        self._device_credentials[device_id] = credentials

        return credentials

    def get_device_registration_info(
        self,
        device_id,
        device_status: DeviceStatus,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> dict:
        """
        Build (and cache) registration info for a device.

        DPS registration state is only queried when the device is provisioned.
        The query uses the device's idScope and primary symmetric key.

        Args:
            device_id: id of the device.
            device_status: not read by this method. The status is taken from the
                fetched device instead; the parameter is kept for signature
                compatibility.
            central_dns_suffix: DNS suffix of the IoT Central endpoint.

        Returns:
            dict with keys "@device_id", "dps_state" and "device_registration_info".
        """
        info = self._device_registration_info.get(device_id)
        if info:
            return info

        dps_state = {}
        device = self.get_device(device_id, central_dns_suffix)
        if device._device_status == DeviceStatus.provisioned:
            credentials = self.get_device_credentials(
                device_id=device_id,
                central_dns_suffix=central_dns_suffix,
            )
            id_scope = credentials["idScope"]
            key = credentials["symmetricKey"]["primaryKey"]
            dps_state = dps_global_service.get_registration_state(
                id_scope=id_scope, key=key, device_id=device_id
            )
        dps_state = self._dps_populate_essential_info(dps_state, device._device_status)

        info = {
            "@device_id": device_id,
            "dps_state": dps_state,
            "device_registration_info": device.get_registration_info(),
        }

        self._device_registration_info[device_id] = info

        return info

    def get_device_registration_summary(self, central_dns_suffix=CENTRAL_ENDPOINT):
        """Return the app-wide device registration summary from the device service."""
        return central_services.device.get_device_registration_summary(
            cmd=self._cmd,
            app_id=self._app_id,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def run_command(
        self,
        device_id: str,
        interface_id: str,
        component_name: str,
        module_name: str,
        command_name: str,
        payload: dict,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """
        Run a command on a device, optionally scoped to a module/component.

        If ``interface_id`` is a component of the device's template (or of one
        of its modules), it is used as the component name instead of
        ``component_name``.
        """
        component_name = self._resolve_component_name(
            device_id=device_id,
            interface_id=interface_id,
            component_name=component_name,
            central_dns_suffix=central_dns_suffix,
        )
        return central_services.device.run_command(
            cmd=self._cmd,
            app_id=self._app_id,
            token=self._token,
            device_id=device_id,
            component_name=component_name,
            module_name=module_name,
            command_name=command_name,
            payload=payload,
            central_dns_suffix=central_dns_suffix,
            api_version=self._api_version,
        )

    def get_command_history(
        self,
        device_id: str,
        interface_id: str,
        component_name: str,
        module_name: str,
        command_name: str,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """
        Get the history of a device command, optionally scoped to a module/component.

        ``interface_id`` is resolved to a component name exactly as in run_command.
        """
        component_name = self._resolve_component_name(
            device_id=device_id,
            interface_id=interface_id,
            component_name=component_name,
            central_dns_suffix=central_dns_suffix,
        )
        return central_services.device.get_command_history(
            cmd=self._cmd,
            app_id=self._app_id,
            token=self._token,
            device_id=device_id,
            component_name=component_name,
            module_name=module_name,
            command_name=command_name,
            central_dns_suffix=central_dns_suffix,
            api_version=self._api_version,
        )

    def get_module_command_history(
        self,
        device_id: str,
        module_name: str,
        component_name: str,
        command_name: str,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Get the history of a module command via the device service."""
        return central_services.device.get_module_command_history(
            cmd=self._cmd,
            app_id=self._app_id,
            token=self._token,
            device_id=device_id,
            module_name=module_name,
            component_name=component_name,
            command_name=command_name,
            central_dns_suffix=central_dns_suffix,
            api_version=self._api_version,
        )

    def list_device_modules(
        self,
        device_id,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> List[EdgeModule]:
        """List edge modules of a device; returns [] when the service returns nothing."""
        modules = central_services.device.list_device_modules(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

        if not modules:
            return []

        return modules

    def restart_device_module(
        self,
        device_id,
        module_id,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> int:
        """
        Restart a module on a device.

        Returns:
            The status returned by the service, which is always 200 on return.

        Raises:
            ResourceNotFoundError: if the service status is anything other than 200.
        """
        status = central_services.device.restart_device_module(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            module_id=module_id,
            token=self._token,
            central_dns_suffix=central_dns_suffix,
        )

        # A None/falsy status also fails this check.
        if status != 200:
            raise ResourceNotFoundError(
                "No module found for device {} with id: '{}'.".format(
                    device_id, module_id
                )
            )

        return status

    def get_device_twin(
        self,
        device_id,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> DeviceTwin:
        """
        Get the twin of a device.

        Raises:
            ResourceNotFoundError: if the service returns no twin.
        """
        twin = central_services.device.get_device_twin(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            token=self._token,
            central_dns_suffix=central_dns_suffix,
        )

        if not twin:
            raise ResourceNotFoundError(
                "No twin found for device with id: '{}'.".format(device_id)
            )

        return twin

    def run_manual_failover(
        self,
        device_id: str,
        ttl_minutes: int = None,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Trigger a manual failover for a device, with an optional TTL in minutes."""
        return central_services.device.run_manual_failover(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            ttl_minutes=ttl_minutes,
            token=self._token,
            central_dns_suffix=central_dns_suffix,
        )

    def run_manual_failback(
        self,
        device_id: str,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Trigger a manual failback for a device."""
        return central_services.device.run_manual_failback(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            token=self._token,
            central_dns_suffix=central_dns_suffix,
        )

    def purge_c2d_messages(
        self,
        device_id: str,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Purge queued cloud-to-device messages for a device."""
        return central_services.device.purge_c2d_messages(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            token=self._token,
            central_dns_suffix=central_dns_suffix,
        )

    def get_device_attestation(
        self,
        device_id: str,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Get the attestation of a device."""
        return central_services.device.get_device_attestation(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def delete_device_attestation(
        self,
        device_id: str,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Delete the attestation of a device."""
        return central_services.device.delete_device_attestation(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def update_device_attestation(
        self,
        device_id: str,
        payload,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Update the attestation of a device with ``payload``."""
        return central_services.device.update_device_attestation(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            payload=payload,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def create_device_attestation(
        self,
        device_id: str,
        payload,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Create the attestation of a device from ``payload``."""
        return central_services.device.create_device_attestation(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            payload=payload,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def list_modules(
        self,
        device_id: str,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """List modules of a device via the device service."""
        return central_services.device.list_modules(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def list_device_components(
        self,
        device_id: str,
        module_name: str = None,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """List components of a device, or of one of its modules when module_name is given."""
        return central_services.device.list_device_components(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            module_name=module_name,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def get_device_properties(
        self,
        device_id: str,
        component_name: str = None,
        module_name: str = None,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Get device properties, optionally scoped to a component and/or module."""
        return central_services.device.get_device_properties_or_telemetry_value(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            module_name=module_name,
            component_name=component_name,
            telemetry_name=None,  # None selects properties rather than telemetry
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def replace_device_properties(
        self,
        device_id: str,
        payload: str,
        component_name: str = None,
        module_name: str = None,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Replace device properties with ``payload``, optionally scoped to a component/module."""
        return central_services.device.replace_properties(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            module_name=module_name,
            component_name=component_name,
            payload=payload,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def update_device_properties(
        self,
        device_id: str,
        payload: str,
        component_name: str = None,
        module_name: str = None,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Update device properties with ``payload``, optionally scoped to a component/module."""
        return central_services.device.update_properties(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            module_name=module_name,
            component_name=component_name,
            payload=payload,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def get_telemetry_value(
        self,
        device_id: str,
        component_name: str,
        module_name: str,
        telemetry_name: str,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Get the value of telemetry ``telemetry_name`` for a device/component/module."""
        return central_services.device.get_device_properties_or_telemetry_value(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            module_name=module_name,
            component_name=component_name,
            telemetry_name=telemetry_name,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def replace_device_component_properties(
        self,
        device_id: str,
        component_name: str,
        payload: str,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Replace properties of a top-level (non-module) device component."""
        return central_services.device.replace_properties(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            module_name=None,
            component_name=component_name,
            payload=payload,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def update_device_component_properties(
        self,
        device_id: str,
        component_name: str,
        payload: str,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ):
        """Update properties of a top-level (non-module) device component."""
        return central_services.device.update_properties(
            cmd=self._cmd,
            app_id=self._app_id,
            device_id=device_id,
            module_name=None,
            component_name=component_name,
            payload=payload,
            token=self._token,
            api_version=self._api_version,
            central_dns_suffix=central_dns_suffix,
        )

    def _dps_populate_essential_info(self, dps_info, device_status: DeviceStatus):
        """
        Reduce a raw DPS registration state to {"status", "error"}.

        "status" comes from ``dps_info``; a None payload is treated as empty.
        "error" is a human-readable message chosen by ``device_status``, and is
        None for statuses not listed below.
        """
        error = {
            DeviceStatus.provisioned: "None.",
            DeviceStatus.registered: "Device is not yet provisioned.",
            DeviceStatus.blocked: "Device is blocked from connecting to IoT Central application."
            " Unblock the device in IoT Central and retry. Learn more: https://aka.ms/iotcentral-docs-dps-SAS",
            DeviceStatus.unassociated: "Device does not have a valid template associated with it.",
        }

        return {
            "status": (dps_info or {}).get("status"),
            "error": error.get(device_status),
        }

    def _resolve_component_name(
        self,
        device_id: str,
        interface_id: str,
        component_name: str,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> str:
        """Return interface_id if it is a component of the device's template, else component_name."""
        if interface_id and self._is_interface_id_component(
            device_id=device_id,
            interface_id=interface_id,
            central_dns_suffix=central_dns_suffix,
        ):
            return interface_id
        return component_name

    def _is_interface_id_component(
        self,
        device_id: str,
        interface_id: str,
        central_dns_suffix=CENTRAL_ENDPOINT,
    ) -> bool:
        """
        Check whether ``interface_id`` names a component of the device's template.

        The check covers the template's own components and the components of
        each of its modules. The template id is read from ``instance_of`` for
        the preview API and from ``template`` otherwise.
        """
        current_device = self.get_device(device_id, central_dns_suffix)

        template = central_services.device_template.get_device_template(
            cmd=self._cmd,
            app_id=self._app_id,
            device_template_id=current_device.instance_of
            if self._api_version == ApiVersion.preview.value
            else current_device.template,
            token=self._token,
            central_dns_suffix=central_dns_suffix,
            api_version=self._api_version,
        )

        if interface_id in template.components:
            return True

        return any(interface_id in module.components for module in template.modules)
