# coding=utf-8
# ----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License file in the project root for license information.
# ----------------------------------------------------------------------------------------------

from typing import TYPE_CHECKING, Iterable, Optional, Dict

from azure.cli.core.azclierror import (
    AzureResponseError,
    FileOperationError,
    ForbiddenError,
    InvalidArgumentValueError,
    ValidationError,
)
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from knack.log import get_logger
from rich.console import Console

from ....util.az_client import (
    get_registry_mgmt_client,
    get_storage_mgmt_client,
    parse_resource_id,
    wait_for_terminal_state,
)
from ....util.common import should_continue_prompt
from ....util.queryable import Queryable
from ..common import SchemaFormat, SchemaType
from ..permissions import ROLE_DEF_FORMAT_STR, PermissionManager, PrincipalType

# Module-level logger used for the info/debug messages emitted below.
logger = get_logger(__name__)
# Shared rich console; its `status(...)` spinner wraps the longer-running operations.
console = Console()
# The operation-group classes are referenced only in (string) type annotations,
# so they are imported under TYPE_CHECKING and not at runtime.
if TYPE_CHECKING:
    from ....vendor.clients.deviceregistrymgmt.operations import (
        SchemaRegistriesOperations,
        SchemasOperations,
        SchemaVersionsOperations,
    )


# Role definition GUID for 'Storage Blob Data Contributor'. SchemaRegistries.create
# formats it into ROLE_DEF_FORMAT_STR when no custom role id is supplied.
STORAGE_BLOB_DATA_CONTRIBUTOR_ROLE_ID = "ba92f5b4-2d11-453d-a403-e96b0029c9fe"


def get_user_msg_warn_ra(prefix: str, principal_id: str, scope: str):
    """Build the user-facing message describing a role assignment that must be done manually.

    Args:
        prefix: Leading text (for example the failure reason) placed before the guidance.
        principal_id: Principal id of the schema registry managed identity.
        scope: Scope the role has to be granted against.

    Returns:
        The prefix, a blank line, the required-role guidance and a closing call to action.
    """
    heading = f"{prefix}\n\n"
    requirement = (
        f"The schema registry MSI principal '{principal_id}' needs\n"
        "'Storage Blob Data Contributor' or equivalent role against scope:\n"
        f"'{scope}'\n\n"
    )
    call_to_action = "Please handle this step before continuing."
    return heading + requirement + call_to_action


class SchemaRegistries(Queryable):
    """Create, show, list and delete schema registry resources.

    All calls go through the ``schema_registries`` operation group of the registry
    management client, which is created for ``self.default_subscription_id``.
    """

    def __init__(self, cmd):
        super().__init__(cmd=cmd)
        self.registry_mgmt_client = get_registry_mgmt_client(
            subscription_id=self.default_subscription_id,
        )
        self.ops: "SchemaRegistriesOperations" = self.registry_mgmt_client.schema_registries

    def create(
        self,
        name: str,
        resource_group_name: str,
        registry_namespace: str,
        storage_account_resource_id: str,
        storage_container_name: str,
        location: Optional[str] = None,
        description: Optional[str] = None,
        display_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        custom_role_id: Optional[str] = None,
        **kwargs,
    ) -> dict:
        """Create (or replace) a schema registry backed by a storage blob container.

        The method registers the ADR provider, validates the storage account, fetches
        the blob container (creating it when it does not exist), creates the registry
        with a system-assigned identity and finally assigns that identity a role on
        the container.

        Args:
            name: Schema registry name.
            resource_group_name: Resource group of the schema registry.
            registry_namespace: Value stored as the registry's ``namespace`` property.
            storage_account_resource_id: Resource id of the backing storage account.
            storage_container_name: Blob container to use; created if not found.
            location: Registry location. Falls back to the resource group's location.
            description: Optional registry description.
            display_name: Optional registry display name.
            tags: Optional resource tags; only sent when non-empty.
            custom_role_id: Role definition id to assign instead of the default
                'Storage Blob Data Contributor' role.
            **kwargs: Forwarded to ``wait_for_terminal_state``.

        Returns:
            The registry resource produced by waiting on the create poller.

        Raises:
            ValidationError: The storage account does not have hierarchical
                namespace enabled.
            AzureResponseError: The role assignment failed. The registry has already
                been created at that point; the message tells the user which role
                to assign manually.
        """
        from ..rp_namespace import ADR_PROVIDER, register_providers

        with console.status("Working...") as c:
            # Register the schema (ADR) provider
            register_providers(self.default_subscription_id, ADR_PROVIDER)

            if not location:
                location = self.get_resource_group(name=resource_group_name)["location"]

            storage_id_container = parse_resource_id(storage_account_resource_id)

            # The storage account may live in a different subscription than the registry,
            # so the storage client is built from the parsed resource id.
            self.storage_mgmt_client = get_storage_mgmt_client(
                subscription_id=storage_id_container.subscription_id,
            )
            storage_account: dict = self.storage_mgmt_client.storage_accounts.get_properties(
                resource_group_name=storage_id_container.resource_group_name,
                account_name=storage_id_container.resource_name,
            )
            storage_properties: dict = storage_account["properties"]
            if not storage_properties.get("isHnsEnabled", False):
                raise ValidationError(
                    "Schema registry requires the storage account to have hierarchical namespace enabled."
                )

            # Reuse the container when present, otherwise create it with default settings.
            try:
                blob_container = self.storage_mgmt_client.blob_containers.get(
                    resource_group_name=storage_id_container.resource_group_name,
                    account_name=storage_id_container.resource_name,
                    container_name=storage_container_name,
                )
            except ResourceNotFoundError:
                blob_container = self.storage_mgmt_client.blob_containers.create(
                    resource_group_name=storage_id_container.resource_group_name,
                    account_name=storage_id_container.resource_name,
                    container_name=storage_container_name,
                    blob_container={},
                )

            blob_container_url = f"{storage_properties['primaryEndpoints']['blob']}{blob_container['name']}"
            resource = {
                "location": location,
                "identity": {
                    "type": "SystemAssigned",
                },
                "properties": {
                    "namespace": registry_namespace,
                    "storageAccountContainerUrl": blob_container_url,
                    "description": description,
                    "displayName": display_name,
                },
            }
            if tags:
                resource["tags"] = tags

            poller = self.ops.begin_create_or_replace(
                resource_group_name=resource_group_name, schema_registry_name=name, resource=resource
            )
            result = wait_for_terminal_state(poller, **kwargs)

            # Grant the registry's system-assigned identity access to the container.
            principal_id = result["identity"]["principalId"]
            container_scope = blob_container["id"]
            target_role_def = custom_role_id or ROLE_DEF_FORMAT_STR.format(
                subscription_id=storage_id_container.subscription_id, role_id=STORAGE_BLOB_DATA_CONTRIBUTOR_ROLE_ID
            )
            permission_manager = PermissionManager(storage_id_container.subscription_id)
            try:
                permission_manager.apply_role_assignment(
                    scope=container_scope,
                    principal_id=principal_id,
                    role_def_id=target_role_def,
                    principal_type=PrincipalType.SERVICE_PRINCIPAL.value
                )
            except Exception as e:
                # Stop the spinner before the error is surfaced, then report any
                # failure as actionable guidance while keeping the original cause.
                c.stop()
                raise AzureResponseError(
                    get_user_msg_warn_ra(
                        prefix=f"Role assignment failed with:\n{str(e)}",
                        principal_id=principal_id,
                        scope=container_scope,
                    )
                ) from e

            return result

    def show(self, name: str, resource_group_name: str) -> dict:
        """Return the schema registry `name` in `resource_group_name`."""
        return self.ops.get(resource_group_name=resource_group_name, schema_registry_name=name)

    def list(self, resource_group_name: Optional[str] = None) -> Iterable[dict]:
        """List schema registries in a resource group, or in the subscription when none is given."""
        if resource_group_name:
            return self.ops.list_by_resource_group(resource_group_name=resource_group_name)
        return self.ops.list_by_subscription()

    def delete(self, name: str, resource_group_name: str, confirm_yes: Optional[bool] = None, **kwargs):
        """Delete a schema registry after the confirmation prompt is accepted.

        Args:
            name: Schema registry name.
            resource_group_name: Resource group of the schema registry.
            confirm_yes: Passed to ``should_continue_prompt``; the delete is skipped
                when that helper returns a falsy value.
            **kwargs: Forwarded to ``wait_for_terminal_state``.

        Raises:
            HttpResponseError: The delete failed with a status code other than 200.
        """
        if not should_continue_prompt(confirm_yes=confirm_yes):
            return

        with console.status("Working..."):
            try:
                poller = self.ops.begin_delete(resource_group_name=resource_group_name, schema_registry_name=name)
                wait_for_terminal_state(poller, **kwargs)
            except HttpResponseError as e:
                # An HttpResponseError carrying status 200 is swallowed on purpose;
                # presumably the service can report a finished delete this way - TODO confirm.
                if e.status_code != 200:
                    raise


class Schemas(Queryable):
    """Operations on schemas and schema versions inside a schema registry.

    Schema calls go through the ``schemas`` operation group and version calls through
    the ``schema_versions`` operation group of the registry management client.
    """

    def __init__(self, cmd):
        super().__init__(cmd=cmd)
        self.registry_mgmt_client = get_registry_mgmt_client(
            subscription_id=self.default_subscription_id,
        )
        self.ops: "SchemasOperations" = self.registry_mgmt_client.schemas
        self.version_ops: "SchemaVersionsOperations" = self.registry_mgmt_client.schema_versions

    def create(
        self,
        name: str,
        schema_registry_name: str,
        resource_group_name: str,
        schema_type: str,
        schema_format: str,
        schema_version_content: str,
        schema_version: int = 1,
        description: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version_description: Optional[str] = None
    ) -> dict:
        """Create (or replace) a schema and add its first version.

        Args:
            name: Schema name.
            schema_registry_name: Registry that owns the schema.
            resource_group_name: Resource group of the registry.
            schema_type: Member name of ``SchemaType``; its ``full_value`` is sent.
            schema_format: Member name of ``SchemaFormat``; its ``full_value`` is sent.
            schema_version_content: Version content, or a path to a file holding it.
            schema_version: Version number to create alongside the schema.
            description: Optional schema description.
            display_name: Optional schema display name.
            schema_version_description: Optional description for the version.

        Returns:
            The schema resource returned by the service (not the version).
        """
        with console.status("Working...") as c:
            schema_type = SchemaType[schema_type].full_value
            schema_format = SchemaFormat[schema_format].full_value
            resource = {
                "properties": {
                    "format": schema_format,
                    "schemaType": schema_type,
                    "description": description,
                    "displayName": display_name,
                },
            }
            schema = self.ops.create_or_replace(
                resource_group_name=resource_group_name,
                schema_registry_name=schema_registry_name,
                schema_name=name,
                resource=resource
            )
            logger.info(f"Created schema {name}.")
            # TODO: maybe add in an exception catch for auth errors
            # The active status object is handed over so add_version reuses it
            # instead of opening a second one.
            self.add_version(
                name=schema_version,
                schema_version_content=schema_version_content,
                schema_name=name,
                schema_registry_name=schema_registry_name,
                resource_group_name=resource_group_name,
                description=schema_version_description,
                current_console=c
            )
            logger.info(f"Added version {schema_version} to schema {name}.")
            return schema

    def show(self, name: str, schema_registry_name: str, resource_group_name: str) -> dict:
        """Return the schema `name` from the given registry."""
        return self.ops.get(
            resource_group_name=resource_group_name,
            schema_registry_name=schema_registry_name,
            schema_name=name
        )

    def list(self, schema_registry_name: str, resource_group_name: str) -> Iterable[dict]:
        """List the schemas of the given registry."""
        return self.ops.list_by_schema_registry(
            resource_group_name=resource_group_name, schema_registry_name=schema_registry_name
        )

    def delete(
        self,
        name: str,
        schema_registry_name: str,
        resource_group_name: str,
        confirm_yes: Optional[bool] = None,
    ):
        """Delete a schema after the confirmation prompt is accepted.

        Returns None without calling the service when ``should_continue_prompt``
        returns a falsy value; otherwise returns the service's delete result.
        """
        if not should_continue_prompt(confirm_yes=confirm_yes):
            return

        with console.status("Working..."):
            return self.ops.delete(
                resource_group_name=resource_group_name,
                schema_registry_name=schema_registry_name,
                schema_name=name
            )

    def add_version(
        self,
        name: int,
        schema_name: str,
        schema_registry_name: str,
        resource_group_name: str,
        schema_version_content: str,
        description: Optional[str] = None,
        current_console: Optional[Console] = None,
    ) -> dict:
        """Create (or replace) a version of a schema.

        Args:
            name: Version number; negative values are rejected.
            schema_name: Schema the version belongs to.
            schema_registry_name: Registry that owns the schema.
            resource_group_name: Resource group of the registry.
            schema_version_content: Either a file path whose content is read, or the
                content itself when reading it as a file fails.
            description: Optional version description.
            current_console: Context manager to enter instead of opening a new
                "Working..." status. NOTE(review): annotated as ``Console`` but
                ``create`` passes the object yielded by ``console.status(...)``.

        Returns:
            The schema version resource returned by the service.

        Raises:
            InvalidArgumentValueError: `name` is negative.
            ForbiddenError: The service answered with HTTP 412.
            HttpResponseError: Any other service failure.
        """
        from ....util import read_file_content

        # NOTE(review): 0 passes this check although the message says "positive" -
        # confirm whether version 0 is valid before tightening either side.
        if name < 0:
            raise InvalidArgumentValueError("Version must be a positive number")

        # A FileOperationError from the reader is taken to mean the argument is the
        # inline content rather than a path, so the original value is kept.
        try:
            logger.debug("Processing schema content.")
            schema_version_content = read_file_content(schema_version_content)
        except FileOperationError:
            logger.debug("Given schema content is not a file.")

        resource = {
            "properties": {
                "schemaContent": schema_version_content,
                "description": description,
            },
        }
        try:
            with current_console or console.status("Working..."):
                return self.version_ops.create_or_replace(
                    resource_group_name=resource_group_name,
                    schema_registry_name=schema_registry_name,
                    schema_name=schema_name,
                    schema_version_name=name,
                    resource=resource
                )
        except HttpResponseError as e:
            if e.status_code == 412:
                raise ForbiddenError(
                    "Schema versions require public network access to be enabled in the associated storage account."
                ) from e
            raise

    def show_version(
        self,
        name: int,
        schema_name: str,
        schema_registry_name: str,
        resource_group_name: str,
    ) -> dict:
        """Return version `name` of the given schema."""
        # service verifies hash during create already
        return self.version_ops.get(
            resource_group_name=resource_group_name,
            schema_registry_name=schema_registry_name,
            schema_name=schema_name,
            schema_version_name=name,
        )

    def list_versions(
        self, schema_name: str, schema_registry_name: str, resource_group_name: str
    ) -> Iterable[dict]:
        """List the versions of the given schema."""
        return self.version_ops.list_by_schema(
            resource_group_name=resource_group_name,
            schema_registry_name=schema_registry_name,
            schema_name=schema_name
        )

    def remove_version(
        self,
        name: int,
        schema_name: str,
        schema_registry_name: str,
        resource_group_name: str,
    ):
        """Delete version `name` of the given schema and return the service's result."""
        with console.status("Working..."):
            return self.version_ops.delete(
                resource_group_name=resource_group_name,
                schema_registry_name=schema_registry_name,
                schema_name=schema_name,
                schema_version_name=name,
            )

    def list_dataflow_friendly_versions(
        self,
        schema_registry_name: str,
        resource_group_name: str,
        schema_name: Optional[str] = None,
        schema_version: Optional[int] = None,
        latest: bool = False
    ) -> dict:
        """Map schema names to their versions rendered as ``aio-sr://`` references.

        Args:
            schema_registry_name: Registry to inspect.
            resource_group_name: Resource group of the registry.
            schema_name: Restrict the result to this schema. When omitted, every
                schema of the registry is included.
            schema_version: Restrict the result to this version of `schema_name`.
                The version is not looked up; the reference is built directly.
            latest: Only keep the highest version of each schema.

        Returns:
            ``{schema_name: OrderedDict({"<version>": "aio-sr://<schema>:<version>"})}``
            with versions in ascending order. A schema without versions maps to an
            empty ``OrderedDict``.

        Raises:
            InvalidArgumentValueError: `schema_version` was given without `schema_name`.
        """
        from collections import OrderedDict

        # note temporary until dataflow create is added.
        # Compare against None so that version 0 still counts as "provided".
        has_version = schema_version is not None
        versions_map = {}
        with console.status("Fetching version info..."):
            # get all the versions first
            if schema_name and has_version:
                versions_map[schema_name] = [int(schema_version)]
            elif schema_name:
                versions_map.update(
                    self._get_schema_version_dict(
                        schema_name=schema_name,
                        schema_registry_name=schema_registry_name,
                        resource_group_name=resource_group_name,
                        latest=latest
                    )
                )
            elif has_version:
                # TODO: maybe do the weird
                raise InvalidArgumentValueError(
                    "Please provide the schema name if schema versions is used."
                )
            else:
                schema_list = self.list(
                    schema_registry_name=schema_registry_name, resource_group_name=resource_group_name
                )
                for schema in schema_list:
                    versions_map.update(
                        self._get_schema_version_dict(
                            schema_name=schema["name"],
                            schema_registry_name=schema_registry_name,
                            resource_group_name=resource_group_name,
                            latest=latest
                        )
                    )

            ref_format = "aio-sr://{schema}:{version}"
            # change to ordered dict for order, azure cli does not like the int keys at that level
            # A new mapping is built rather than rewriting versions_map while iterating it.
            return {
                current_schema: OrderedDict(
                    (str(ver), ref_format.format(schema=current_schema, version=ver)) for ver in versions_list
                )
                for current_schema, versions_list in versions_map.items()
            }

    def _get_schema_version_dict(
        self, schema_name: str, schema_registry_name: str, resource_group_name: str, latest: bool = False
    ) -> dict:
        """Return ``{schema_name: [versions...]}`` with integer versions in ascending order.

        With `latest` set only the highest version is kept. A schema without any
        versions yields an empty list instead of failing.
        """
        versions = sorted(
            int(ver["name"])
            for ver in self.list_versions(
                schema_name=schema_name,
                schema_registry_name=schema_registry_name,
                resource_group_name=resource_group_name
            )
        )
        # Slicing keeps the empty case empty, where max() would raise ValueError.
        if latest:
            versions = versions[-1:]
        return {schema_name: versions}
