#  Copyright 2023 Alibaba, Inc. or its affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import locale
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import oss2
from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_credentials.models import Config as CredentialConfig
from alibabacloud_sts20150401.client import Client
from alibabacloud_sts20150401.models import (
    GetCallerIdentityResponseBody as CallerIdentity,
)
from alibabacloud_tea_openapi import models as open_api_models
from oss2.models import BucketInfo, SimplifiedBucketInfo
from prompt_toolkit import Application
from prompt_toolkit.key_binding import KeyBindings, merge_key_bindings
from prompt_toolkit.key_binding.defaults import load_key_bindings
from prompt_toolkit.layout import HSplit, Layout
from prompt_toolkit.shortcuts import confirm as prompt_confirm
from prompt_toolkit.widgets import Label, RadioList

from ...api.base import ServiceName
from ...api.client_factory import ClientFactory
from ...api.workspace import WorkspaceAPI, WorkspaceConfigKeys
from ...common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT, Network
from ...common.logging import get_logger
from ...common.oss_utils import CredentialProviderWrapper, OssUriObj
from ...common.utils import is_domain_connectable, make_list_resource_iterator
from ...libs.alibabacloud_pai_dsw20220101.client import Client as DswClient
from ...libs.alibabacloud_pai_dsw20220101.models import GetInstanceRequest
from ...session import Session

# Module-level logger named after this module.
logger = get_logger(__name__)

# Language code of the process default locale, resolved once at import time;
# the encoding part of the returned pair is discarded.
# NOTE(review): locale.getdefaultlocale() is reported as deprecated in recent
# Python releases -- confirm a replacement before upgrading the interpreter.
locale_code, _ = locale.getdefaultlocale()

# Bucket name rule encoded by the pattern: 3-63 characters made of lowercase
# letters, digits and hyphens, starting and ending with a letter or digit.
OSS_NAME_PATTERN = re.compile(pattern="^[a-z0-9][a-z0-9-]{1,61}[a-z0-9]$")
# Locale code for which the Chinese variant of a message is preferred.
ZH_CN_LOCAL = "zh_CN"

# RoleARN pattern for AssumedRole CallerIdentity; group 1 captures the role name.
ASSUMED_ROLE_ARN_PATTERN = re.compile(r"acs:ram::\d+:assumed-role/([^/]+)/.*")

# DSW Notebook default role name, kept in lowercase because it is compared
# against the lower-cased role name extracted from the caller ARN.
PAI_DSW_DEFAULT_ROLE_NAME = "aliyunpaidswdefaultrole"


# RAM role names passed to the workspace API when listing product
# authorizations.
DEFAULT_PRODUCT_RAM_ROLE_NAMES = [
    "AliyunODPSPAIDefaultRole",
    "AliyunPAIAccessingOSSRole",
    "AliyunPAIDLCAccessingOSSRole",
    "AliyunPAIDLCDefaultRole",
]


class WorkspaceRoles(object):
    """Role identifiers that a member can hold inside a PAI workspace."""

    AlgoDeveloper = "PAI.AlgoDeveloper"
    WorkspaceAdmin = "PAI.WorkspaceAdmin"
    WorkspaceOwner = "PAI.WorkspaceOwner"
    LabelManager = "PAI.LabelManager"

    @classmethod
    def recommend_roles(cls):
        """Return the role identifiers suggested to the user.

        Returns:
            A new list with the AlgoDeveloper, WorkspaceAdmin and
            WorkspaceOwner identifiers, in that order. LabelManager is
            not part of the recommendation.
        """
        recommended = [cls.AlgoDeveloper]
        recommended.append(cls.WorkspaceAdmin)
        recommended.append(cls.WorkspaceOwner)
        return recommended


class CallerIdentityType(object):
    """Identity type values compared against a caller identity's type.

    Each constant is the literal string this module compares with the
    ``identity_type`` field of the caller identity.
    """

    # Document: https://help.aliyun.com/document_detail/371868.html

    # - Account: an Alibaba Cloud account
    Account = "Account"
    # - RamUser: a RAM user (note the upper-case "RAM" in the value)
    RamUser = "RAMUser"
    # - AssumedRoleUser: a RAM role
    AssumedRoleUser = "AssumedRoleUser"


class UserProfile(object):
    """Credential, identity and region context of the current caller.

    On construction the profile picks the network to use and fetches the
    caller identity through STS. The remaining methods are thin helpers that
    build OSS / DSW / workspace clients from that context.
    """

    # Lazily created by ``_get_credential_client`` and cached on the instance.
    _credential_client = None

    def __init__(
        self,
        credential_config: CredentialConfig,
        region_id: str,
    ):
        """Initialize the profile and resolve the caller identity.

        Args:
            credential_config: Configuration used to build the credential
                client and the OSS credential provider.
            region_id: Region used for every client and endpoint built here.
        """
        self.region_id = region_id
        self.credential_config = credential_config

        if DEFAULT_NETWORK_TYPE:
            # An explicitly configured network type always wins.
            self.network = Network.from_string(DEFAULT_NETWORK_TYPE)
        else:
            # Otherwise probe the regional PAI VPC endpoint and fall back to
            # the public network when it is not reachable.
            self.network = (
                Network.VPC
                if is_domain_connectable(PAI_VPC_ENDPOINT.format(self.region_id))
                else Network.PUBLIC
            )
        # The attribute name (including its spelling) is kept as-is so any
        # existing reader of it keeps working.
        self._caller_identify = self._get_caller_identity()

    def _get_credential_client(self):
        """Return the credential client, creating it on first use."""
        if self._credential_client:
            return self._credential_client
        self._credential_client = CredentialClient(self.credential_config)
        return self._credential_client

    def get_access_key_id(self):
        """Return the access key ID reported by the credential client."""
        return self._get_credential_client().get_access_key_id()

    def get_access_key_secret(self):
        """Return the access key secret reported by the credential client."""
        return self._get_credential_client().get_access_key_secret()

    def get_security_token(self):
        """Return the security token reported by the credential client."""
        return self._get_credential_client().get_security_token()

    def _get_caller_identity(self) -> CallerIdentity:
        """Call STS ``GetCallerIdentity`` and return the response body."""
        # No explicit network is passed for the public network; otherwise the
        # lower-cased network name is forwarded to the client config.
        network = (
            None if self.network == Network.PUBLIC else self.network.value.lower()
        )
        sts_client = Client(
            config=open_api_models.Config(
                credential=self._get_credential_client(),
                region_id=self.region_id,
                network=network,
            )
        )
        return sts_client.get_caller_identity().body

    def is_dsw_default_role(self) -> bool:
        """Return True if the caller is the assumed DSW default role.

        The role name is taken from the caller ARN and compared
        case-insensitively with ``PAI_DSW_DEFAULT_ROLE_NAME``.
        """
        if self._caller_identify.identity_type != CallerIdentityType.AssumedRoleUser:
            return False
        # Guard against a missing ARN so the regex never receives None.
        m = ASSUMED_ROLE_ARN_PATTERN.match(self._caller_identify.arn or "")
        # Always return a real bool; ``m and ...`` would hand None to callers
        # when the ARN does not match.
        return m is not None and m.group(1).lower() == PAI_DSW_DEFAULT_ROLE_NAME

    def get_acs_dsw_client(self) -> DswClient:
        """Create a DSW client for the profile's region and network."""
        return ClientFactory.create_client(
            service_name=ServiceName.PAI_DSW,
            credential_client=self._get_credential_client(),
            region_id=self.region_id,
            network=self.network,
        )

    def get_instance_info(self, instance_id: str) -> Dict[str, Any]:
        """Fetch a DSW instance and return the response body as a dict.

        Args:
            instance_id: ID of the DSW instance to look up.
        """
        dsw_client = self.get_acs_dsw_client()
        return dsw_client.get_instance(
            instance_id, request=GetInstanceRequest()
        ).body.to_map()

    def get_credential(self):
        """Return the access key ID of the current credential.

        Despite its name, this method has always returned the access key ID;
        that behavior is kept for existing callers.
        """
        # Go through the lazy accessor instead of touching the cached
        # attribute directly, so this never dereferences an unset client.
        return self._get_credential_client().get_access_key_id()

    @property
    def is_ram_user(self) -> bool:
        """Whether the caller identity type is a RAM user."""
        return self._caller_identify.identity_type == CallerIdentityType.RamUser

    @property
    def is_account(self) -> bool:
        """Whether the caller identity type is an Alibaba Cloud account."""
        return self._caller_identify.identity_type == CallerIdentityType.Account

    @property
    def account_id(self):
        """Return Alibaba Cloud account ID of the current user profile"""
        return self._caller_identify.account_id

    @property
    def user_id(self):
        """Return the Alibaba Cloud user ID of the current user profile"""
        return self._caller_identify.user_id

    @property
    def identify_type(self):
        """Return the identity type string of the caller identity."""
        return self._caller_identify.identity_type

    def get_default_oss_endpoint(self):
        """Return the public OSS endpoint URL for the profile's region."""
        return "https://oss-{}.aliyuncs.com".format(self.region_id)

    def list_oss_buckets(self, prefix: str = "") -> List[SimplifiedBucketInfo]:
        """List the caller's OSS buckets located in the profile's region.

        Args:
            prefix: Only buckets whose name starts with this prefix are
                requested from the service.

        Returns:
            Buckets whose ``location`` contains ``region_id``, collected
            across all result pages.
        """
        buckets: List[SimplifiedBucketInfo] = []
        service = oss2.Service(
            auth=oss2.ProviderAuth(
                credentials_provider=CredentialProviderWrapper(
                    config=self.credential_config,
                ),
            ),
            endpoint=self.get_default_oss_endpoint(),
        )

        marker = ""
        while True:
            res = service.list_buckets(prefix=prefix, marker=marker)
            # The listing is not limited to one region, so filter by location.
            buckets.extend(b for b in res.buckets if self.region_id in b.location)
            if not res.is_truncated:
                break
            # Continue from where the previous page stopped.
            marker = res.next_marker

        return buckets

    def get_bucket_info(self, bucket_name) -> BucketInfo:
        """Return the bucket info of ``bucket_name`` via the default endpoint."""
        auth = oss2.ProviderAuth(
            credentials_provider=CredentialProviderWrapper(
                config=self.credential_config,
            ),
        )
        bucket = oss2.Bucket(
            auth, self.get_default_oss_endpoint(), bucket_name=bucket_name
        )
        bucket_info = bucket.get_bucket_info()
        return bucket_info

    def create_oss_bucket(self, bucket_name):
        """Create an OSS bucket named ``bucket_name`` via the default endpoint."""
        bucket = oss2.Bucket(
            bucket_name=bucket_name,
            auth=oss2.ProviderAuth(
                credentials_provider=CredentialProviderWrapper(
                    config=self.credential_config,
                ),
            ),
            endpoint=self.get_default_oss_endpoint(),
        )
        bucket.create_bucket()

    def get_production_authorizations(self):
        """Return the authorization details of the default product RAM roles."""
        workspace_api = self.get_workspace_api()
        res = workspace_api.list_product_authorizations(
            ram_role_names=DEFAULT_PRODUCT_RAM_ROLE_NAMES
        )
        return res["AuthorizationDetails"]

    def get_workspace_api(self) -> WorkspaceAPI:
        """Create a workspace API wrapper for the profile's region and network."""
        acs_ws_client = ClientFactory.create_client(
            service_name=ServiceName.PAI_WORKSPACE,
            credential_client=self._get_credential_client(),
            region_id=self.region_id,
            network=self.network,
        )

        return WorkspaceAPI(
            acs_client=acs_ws_client,
        )

    def get_default_oss_storage_uri(
        self, workspace_id: str
    ) -> Tuple[Optional[str], Optional[str]]:
        """Return the default OSS storage URI and endpoint of a workspace.

        Args:
            workspace_id: ID of the workspace whose default storage is read.

        Returns:
            A ``(uri, endpoint)`` pair where ``uri`` has the form
            ``oss://<bucket>/``. ``uri`` is None when no bucket name is
            reported for the workspace.
        """
        bucket_name, endpoint = Session.get_default_oss_storage(
            workspace_id=workspace_id,
            cred=self._get_credential_client(),
            region_id=self.region_id,
            network=self.network,
        )
        if not bucket_name:
            # Formatting a missing bucket would yield the truthy, bogus URI
            # "oss://None/"; report "not configured" as None instead, as the
            # Optional return annotation promises.
            return None, endpoint
        return "oss://{}/".format(bucket_name), endpoint

    def set_default_oss_storage(
        self, workspace_id, bucket_name: str, intranet_endpoint: str
    ):
        """Store ``oss://<bucket>.<endpoint>/`` as the workspace default storage.

        Args:
            workspace_id: ID of the workspace to update.
            bucket_name: Name of the bucket to use as default storage.
            intranet_endpoint: Endpoint host appended to the bucket name in
                the stored URI.
        """
        workspace_api = self.get_workspace_api()
        oss_uri = "oss://{}.{}/".format(bucket_name, intranet_endpoint)
        configs = {WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI: oss_uri}
        workspace_api.update_configs(workspace_id, configs=configs)

    def get_roles_in_workspace(
        self, workspace_id, user_id: Optional[str] = None
    ) -> List[str]:
        """Return the roles a user holds in the given workspace.

        Args:
            workspace_id: ID of the workspace whose members are scanned.
            user_id: User to look up; defaults to the profile's own user ID.

        Returns:
            The ``Roles`` of the first member whose ``UserId`` matches, or an
            empty list when the user is not listed as a member.
        """
        workspace_api = self.get_workspace_api()
        user_id = user_id or self.user_id
        member_info = next(
            (
                mem
                for mem in make_list_resource_iterator(
                    workspace_api.list_members,
                    workspace_id=workspace_id,
                )
                if mem["UserId"] == user_id
            ),
            None,
        )

        # If user has PAIFullAccess policy, 'member_info' may be None.
        return member_info["Roles"] if member_info else []

    def has_permission_edit_config(self, workspace_id: str) -> bool:
        """Return True if the current user has permission to edit workspace config.

        Only members with the role of WorkspaceAdmin or WorkspaceOwner can edit
        workspace config.

        """
        roles = self.get_roles_in_workspace(workspace_id)
        editor_roles = (WorkspaceRoles.WorkspaceAdmin, WorkspaceRoles.WorkspaceOwner)
        return any(r in roles for r in editor_roles)


def localized_text(en_text: str, cn_text: str = None):
    """Pick the message variant matching the process locale.

    Args:
        en_text: English text, also used as the fallback.
        cn_text: Chinese text; may be omitted or empty.

    Returns:
        ``cn_text`` when the detected locale code equals ``ZH_CN_LOCAL`` and
        ``cn_text`` is non-empty, otherwise ``en_text``.
    """
    prefers_chinese = locale_code == ZH_CN_LOCAL
    if prefers_chinese and cn_text:
        return cn_text
    return en_text


def mask_secret(secret, mask_count=4):
    """Hide the middle of ``secret`` behind asterisks.

    The first and last ``mask_count`` characters stay visible and every
    character in between is replaced with ``*``. The result always has the
    same length as ``secret``.

    Args:
        secret: The sensitive string to mask.
        mask_count: Number of characters left visible at each end.

    Returns:
        The masked string. When the secret is too short to hide anything
        between the two visible ends (``len(secret) <= 2 * mask_count``), or
        when ``mask_count`` is not positive, every character is masked.
    """
    total = len(secret)
    # With a short secret the head and tail slices overlap, which would both
    # reveal the whole secret and produce output longer than the input;
    # ``secret[-0:]`` likewise yields the full string. Mask everything instead.
    if mask_count <= 0 or total <= mask_count * 2:
        return "*" * total
    hidden = total - mask_count * 2
    return secret[:mask_count] + "*" * hidden + secret[-mask_count:]


def mask_and_trim(secret, max_size=20):
    """Mask ``secret`` and shorten the result for display.

    Args:
        secret: The sensitive string to mask.
        max_size: Number of masked characters kept before truncation.

    Returns:
        The output of ``mask_secret(secret, mask_count=8)``; when that is
        longer than ``max_size`` it is cut to ``max_size`` characters and
        ``"..."`` is appended (so the result may be ``max_size + 3`` long).
    """
    text = mask_secret(secret, mask_count=8)
    return text if len(text) <= max_size else text[:max_size] + "..."


def radio_list_prompt(
    title: str = "",
    values=None,
    cancel_value=None,
    style=None,
    async_: bool = False,
    **kwargs,
):
    """Run a small inline application that lets the user pick one option.

    Args:
        title: Text shown in a label above the option list.
        values: Options handed to the ``RadioList`` widget.
        cancel_value: Result produced when the user presses Ctrl-C.
        style: Style forwarded to the application.
        async_: When true, return the result of ``Application.run_async()``
            instead of running the application synchronously.
        **kwargs: Extra keyword arguments forwarded to ``Application``.

    Returns:
        The result of ``Application.run()``, or of
        ``Application.run_async()`` when ``async_`` is true.
    """
    choices = RadioList(values)
    # The widget's own Enter handler only selects; drop it so that the
    # handler registered below can select and submit in one key press.
    choices.control.key_bindings.remove("enter")

    custom_bindings = KeyBindings()

    @custom_bindings.add("enter")
    def _submit(event):
        """Select the highlighted entry and exit with it as the result."""
        choices._handle_enter()
        event.app.exit(result=choices.current_value)

    @custom_bindings.add("c-c")
    def _cancel(event):
        """Exit the application with ``cancel_value`` as the result."""
        event.app.exit(result=cancel_value)

    # Title label stacked on top of the option list.
    root = HSplit([Label(title), choices])
    app = Application(
        layout=Layout(root),
        key_bindings=merge_key_bindings([load_key_bindings(), custom_bindings]),
        mouse_support=True,
        style=style,
        **kwargs,
    )
    runner = app.run_async if async_ else app.run
    return runner()


def confirm(message: str = "Confirm?", suffix: str = " (y/n, default: y)"):
    """Ask the user a yes/no question, treating a bare Enter as "yes".

    Args:
        message: Question shown to the user.
        suffix: Hint appended after the question.

    Returns:
        True when the prompt yields a blank string (the Enter key);
        otherwise whatever the underlying prompt returned.
    """
    answer = prompt_confirm(message, suffix)
    # Input enter key returns an empty string, we assume enter is 'YES'.
    if isinstance(answer, str) and answer.strip() == "":
        return True
    return answer


def not_empty(text: str) -> bool:
    """Return True if ``text`` has at least one non-whitespace character."""
    stripped = text.strip()
    return len(stripped) > 0


def print_highlight(msg: str):
    """Print ``msg`` to stdout wrapped by ``ColorEscape.green``."""
    highlighted = ColorEscape.green(msg)
    print(highlighted)


def print_warning(msg: str):
    """Print ``msg`` to stdout wrapped by ``ColorEscape.red``."""
    warning = ColorEscape.red(msg)
    print(warning)


def validate_bucket_name(name: str) -> bool:
    """Return True if ``name`` satisfies ``OSS_NAME_PATTERN`` in full.

    Args:
        name: Candidate OSS bucket name.

    Returns:
        True only when the entire string matches the pattern.
    """
    # ``match`` with a ``$`` anchor also accepts a trailing newline
    # (e.g. "bucket\n"); ``fullmatch`` requires the whole string to match.
    return OSS_NAME_PATTERN.fullmatch(name) is not None


class ColorEscape(object):
    """
    Helpers that surround a message with ANSI color escape sequences.
    """

    # Foreground color codes.
    _black = "\u001b[30m"
    _red = "\u001b[31m"
    _green = "\u001b[32m"
    _yellow = "\u001b[33m"
    _blue = "\u001b[34m"
    _magenta = "\u001b[35m"
    _cyan = "\u001b[36m"
    _white = "\u001b[37m"
    _default = "\u001b[39m"

    # Sequence appended after the message to reset attributes.
    _reset = "\u001b[0m"

    @classmethod
    def green(cls, msg: str) -> str:
        """Return ``msg`` colored green (or unchanged if colors are off)."""
        return cls._format(msg, code=cls._green)

    @classmethod
    def red(cls, msg: str) -> str:
        """Return ``msg`` colored red (or unchanged if colors are off)."""
        return cls._format(msg, code=cls._red)

    @classmethod
    def _format(cls, msg: str, code: str) -> str:
        """Wrap ``msg`` with ``code`` and the reset sequence.

        The message is returned untouched when the ``NO_COLOR`` environment
        variable is set to a non-empty value.
        """
        # See https://no-color.org/
        colors_disabled = bool(os.environ.get("NO_COLOR"))
        if not colors_disabled:
            return "".join((code, msg, cls._reset))
        return msg
