#  Copyright 2023 Alibaba, Inc. or its affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from __future__ import absolute_import

import json
import os.path
import posixpath
from datetime import datetime
from typing import Any, Dict, Optional, Tuple, Union

import oss2
from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_credentials.exceptions import CredentialException
from alibabacloud_credentials.models import Config as CredentialConfig
from alibabacloud_credentials.utils import auth_constant
from Tea.exceptions import TeaException

from .api.api_container import ResourceAPIsContainerMixin
from .api.base import ServiceName
from .api.client_factory import ClientFactory
from .api.workspace import WorkspaceAPI, WorkspaceConfigKeys
from .common.consts import DEFAULT_CONFIG_PATH, PAI_VPC_ENDPOINT, Network
from .common.logging import get_logger
from .common.oss_utils import CredentialProviderWrapper, OssUriObj
from .common.utils import is_domain_connectable, make_list_resource_iterator
from .libs.alibabacloud_pai_dsw20220101.models import GetInstanceRequest

logger = get_logger(__name__)

# Name of the environment variable that points at the PAI config file.
# If it is not provided, DEFAULT_CONFIG_PATH is used as the config path.
ENV_PAI_CONFIG_PATH = "PAI_CONFIG_PATH"

# Region IDs treated as "inner" regions (see ``Session.is_inner``).
INNER_REGION_IDS = ["center"]

# Session attributes that are read from / written to the config file.
_DEFAULT_CONFIG_KEYS = ["region_id", "oss_bucket_name", "workspace_id", "oss_endpoint"]

# Process-wide default session. It is set by ``setup_default_session`` or built
# lazily by ``get_default_session``.
_default_session = None


def setup_default_session(
    access_key_id: Optional[str] = None,
    access_key_secret: Optional[str] = None,
    security_token: Optional[str] = None,
    region_id: Optional[str] = None,
    credential_config: Optional[CredentialConfig] = None,
    oss_bucket_name: Optional[str] = None,
    oss_endpoint: Optional[str] = None,
    workspace_id: Optional[Union[str, int]] = None,
    network: Optional[Union[str, Network]] = None,
    **kwargs,
) -> "Session":
    """Set up the default session used in the program.

    The function constructs a session used for communicating with PAI services
    and sets it as the global default instance. Any argument that is not
    provided is inherited from the current default session, if there is one.

    Args:
        access_key_id (str, optional): The access key ID used to access the
            Alibaba Cloud. Only used together with ``access_key_secret``.
        access_key_secret (str, optional): The access key secret used to access
            the Alibaba Cloud. Only used together with ``access_key_id``.
        security_token (str, optional): The security token used to access the
            Alibaba Cloud. Only used together with an access key pair.
        credential_config (:class:`alibabacloud_credentials.models.Config`, optional):
            The credential config used to access the Alibaba Cloud. Mutually
            exclusive with an access key pair.
        region_id (str): The ID of the Alibaba Cloud region where the service
            is located.
        workspace_id (Union[str, int], optional): ID of the workspace used in the
            default session.
        oss_bucket_name (str, optional): The name of the OSS bucket used in the
            session.
        oss_endpoint (str, optional): The endpoint for the OSS bucket.
        network (Union[str, Network], optional): The network to use for the connection.
            Supported values are "VPC" and "PUBLIC". If provided, this value will be used as-is.
            Otherwise, the code will first check for an environment variable PAI_NETWORK_TYPE.
            If that is not set and the VPC endpoint is available, it will be used.
            As a last resort, if all else fails, the PUBLIC endpoint will be used.
        **kwargs: Additional keyword arguments forwarded to :class:`Session`.

    Returns:
        :class:`pai.session.Session`: Initialized default session.

    Raises:
        ValueError: If both an access key pair and ``credential_config`` are
            provided.
    """
    has_access_key = bool(access_key_id and access_key_secret)
    if has_access_key and credential_config:
        raise ValueError("Please provide either access_key or credential_config.")

    if has_access_key:
        # Build an explicit credential from the access key pair; a security token
        # turns it into an STS credential.
        if security_token:
            credential_config = CredentialConfig(
                access_key_id=access_key_id,
                access_key_secret=access_key_secret,
                security_token=security_token,
                type=auth_constant.STS,
            )
        else:
            credential_config = CredentialConfig(
                access_key_id=access_key_id,
                access_key_secret=access_key_secret,
                type=auth_constant.ACCESS_KEY,
            )
    elif access_key_id or access_key_secret or security_token:
        # Incomplete explicit credentials are ignored; warn so the fallback to
        # another credential source does not happen silently.
        logger.warning(
            "Ignoring incomplete access key credential: both access_key_id and "
            "access_key_secret must be provided."
        )

    # Inherit unspecified settings from the current default session.
    default_session = get_default_session()
    if default_session:
        region_id = region_id or default_session.region_id
        workspace_id = workspace_id or default_session.workspace_id
        oss_bucket_name = oss_bucket_name or default_session.oss_bucket_name
        oss_endpoint = oss_endpoint or default_session.oss_endpoint
        credential_config = credential_config or default_session.credential_config
        network = network or default_session.network

    session = Session(
        region_id=region_id,
        credential_config=credential_config,
        oss_bucket_name=oss_bucket_name,
        oss_endpoint=oss_endpoint,
        workspace_id=workspace_id,
        network=network,
        **kwargs,
    )

    global _default_session
    _default_session = session
    return session


def get_default_session() -> Optional["Session"]:
    """Get the default session used by the program.

    If no global default session has been set yet, the function tries to
    initialize one. It first uses the local config file. If no config is found,
    it falls back to the environment (``_init_default_session_from_env``). A
    successfully created session is stored as the global default session.

    Returns:
        Optional[:class:`pai.session.Session`]: The default session, or None if
        it could not be initialized from either source.
    """
    global _default_session
    if not _default_session:
        config = load_default_config_file()
        _default_session = (
            Session(**config) if config else _init_default_session_from_env()
        )
    return _default_session


def _init_default_session_from_env() -> Optional["Session"]:
    """Build a session from the runtime environment (e.g. a DSW instance).

    Inputs used:
      * credentials from the default credential provider chain;
      * region from env var ``REGION``, falling back to legacy ``dsw_region``;
      * DSW instance ID from env var ``DSW_INSTANCE_ID``;
      * workspace ID from ``PAI_WORKSPACE_ID`` / ``PAI_AI_WORKSPACE_ID``, or, if
        neither is set, looked up through the DSW instance.

    Returns:
        Optional[Session]: The session, or None when the credential, region,
        DSW instance ID or workspace ID cannot be determined.
    """
    cred_client = Session._get_default_credential_client()
    if not cred_client:
        logger.debug("Not found credential from default credential provider chain.")
        return None

    # ``REGION`` wins over the legacy DSW variable ``dsw_region``.
    region = os.getenv("REGION", os.getenv("dsw_region"))
    if not region:
        logger.debug(
            "No region id found(env var: REGION or dsw_region), skip init default session"
        )
        return None

    instance_id = os.getenv("DSW_INSTANCE_ID")
    if not instance_id:
        logger.debug(
            "No dsw instance id (env var: DSW_INSTANCE_ID) found, skip init default session"
        )
        return None

    # ``PAI_WORKSPACE_ID`` wins over ``PAI_AI_WORKSPACE_ID``.
    ws_id = os.getenv("PAI_WORKSPACE_ID", os.getenv("PAI_AI_WORKSPACE_ID"))

    # Use the VPC endpoint when it is reachable, otherwise the public one.
    vpc_reachable = is_domain_connectable(PAI_VPC_ENDPOINT.format(region), timeout=1)
    network = Network.VPC if vpc_reachable else Network.PUBLIC

    if not ws_id:
        logger.debug("Getting workspace id by dsw instance id: %s", instance_id)
        ws_id = Session._get_workspace_id_by_dsw_instance_id(
            dsw_instance_id=instance_id,
            cred=cred_client,
            region_id=region,
            network=network,
        )
        if not ws_id:
            logger.warning(
                "Failed to get workspace id by dsw instance id: %s", instance_id
            )
            return None

    oss_bucket_name, oss_endpoint = Session.get_default_oss_storage(
        ws_id, cred_client, region, network
    )
    if not oss_bucket_name:
        logger.warning(
            "Default OSS storage is not configured for the workspace: %s", ws_id
        )

    return Session(
        region_id=region,
        workspace_id=ws_id,
        credential_config=None,
        oss_bucket_name=oss_bucket_name,
        oss_endpoint=oss_endpoint,
        network=network,
    )


def load_default_config_file() -> Optional[Dict[str, Any]]:
    """Read the local PAI config file and return :class:`Session` kwargs.

    The config file path comes from the ``PAI_CONFIG_PATH`` environment variable
    (``ENV_PAI_CONFIG_PATH``) when set, otherwise ``DEFAULT_CONFIG_PATH``. Only
    keys listed in ``_DEFAULT_CONFIG_KEYS`` are kept. For backward
    compatibility, a legacy ``access_key_id``/``access_key_secret`` pair stored
    in the file is converted into a ``credential_config`` entry.

    Returns:
        Optional[Dict[str, Any]]: Keyword arguments for :class:`Session`, or None
        if the config file does not exist.
    """
    config_path = os.environ.get(ENV_PAI_CONFIG_PATH) or DEFAULT_CONFIG_PATH
    if not os.path.exists(config_path):
        return None

    with open(config_path, "r", encoding="utf-8") as f:
        raw_config = json.load(f)

    config = {
        key: value for key, value in raw_config.items() if key in _DEFAULT_CONFIG_KEYS
    }

    # Read the legacy credential keys from the raw file content. They are not in
    # _DEFAULT_CONFIG_KEYS, so looking for them in the filtered dict never matches.
    access_key_id = raw_config.get("access_key_id")
    access_key_secret = raw_config.get("access_key_secret")
    if access_key_id and access_key_secret:
        config["credential_config"] = CredentialConfig(
            access_key_id=access_key_id,
            access_key_secret=access_key_secret,
            type=auth_constant.ACCESS_KEY,
        )

    return config


class Session(ResourceAPIsContainerMixin):
    """A class responsible for communicating with PAI services."""

    def __init__(
        self,
        region_id: str,
        workspace_id: Optional[str] = None,
        credential_config: Optional[CredentialConfig] = None,
        oss_bucket_name: Optional[str] = None,
        oss_endpoint: Optional[str] = None,
        **kwargs,
    ):
        """PAI Session Initializer.

        Args:
            region_id (str): The ID of the Alibaba Cloud region where the service
                is located.
            workspace_id (str, optional): ID of the workspace used in the
                session. Non-string values (e.g. int) are converted to str.
            credential_config (:class:`alibabacloud_credentials.models.Config`, optional):
                The credential config used to access the Alibaba Cloud.
            oss_bucket_name (str, optional): The name of the OSS bucket used in the
                session.
            oss_endpoint (str, optional): The endpoint for the OSS bucket.
            **kwargs: ``header``, ``network`` and ``runtime`` are forwarded to
                :class:`ResourceAPIsContainerMixin`. Any other keyword argument is
                ignored and reported with a warning.

        Raises:
            ValueError: If ``region_id`` is empty.
        """

        if not region_id:
            raise ValueError("Region ID must be provided.")

        self._credential_config = credential_config
        self._region_id = region_id
        # Keep a missing workspace as None rather than the string "None". That way
        # ``workspace_name`` can detect it and ``save_config`` does not persist it.
        self._workspace_id = None if workspace_id is None else str(workspace_id)
        self._oss_bucket_name = oss_bucket_name
        self._oss_endpoint = oss_endpoint

        header = kwargs.pop("header", None)
        network = kwargs.pop("network", None)
        runtime = kwargs.pop("runtime", None)
        if kwargs:
            logger.warning(
                "Unused arguments found in session initialization: %s", kwargs
            )
        super(Session, self).__init__(header=header, network=network, runtime=runtime)

    @property
    def region_id(self) -> str:
        """ID of the region used by the session."""
        return self._region_id

    @property
    def is_inner(self) -> bool:
        """Whether the session region is listed in ``INNER_REGION_IDS``."""
        return self._region_id in INNER_REGION_IDS

    @property
    def oss_bucket_name(self) -> Optional[str]:
        """Name of the OSS bucket used by the session, if configured."""
        return self._oss_bucket_name

    @property
    def oss_endpoint(self) -> Optional[str]:
        """Endpoint of the OSS bucket used by the session, if configured."""
        return self._oss_endpoint

    @property
    def credential_config(self) -> Optional[CredentialConfig]:
        """Credential config given to the session (None means not provided)."""
        return self._credential_config

    @property
    def workspace_name(self):
        """Name of the session's workspace, fetched from the workspace API once.

        Raises:
            ValueError: If the session has no workspace ID.
        """
        # ``getattr`` keeps this safe for instances that never set the cache.
        cached_name = getattr(self, "_workspace_name", None)
        if cached_name:
            return cached_name

        if not self._workspace_id:
            raise ValueError("Workspace id is not set.")
        workspace_api_obj = self.workspace_api.get(workspace_id=self._workspace_id)
        self._workspace_name = workspace_api_obj["WorkspaceName"]
        return self._workspace_name

    @property
    def provider(self) -> str:
        """Account ID of the caller, as returned by STS ``get_caller_identity``."""
        caller_identity = self._acs_sts_client.get_caller_identity().body
        return caller_identity.account_id

    @property
    def workspace_id(self) -> Optional[str]:
        """ID of the workspace used by the session."""
        return self._workspace_id

    @property
    def console_uri(self) -> str:
        """The web console URI for PAI service."""
        if self.is_inner:
            return "https://pai-next.alibaba-inc.com"
        else:
            return "https://pai.console.aliyun.com/console"

    def _init_oss_config(
        self,
    ):
        """Fill in missing OSS bucket name / endpoint from workspace defaults.

        Raises:
            RuntimeError: If no bucket name is set and the workspace has no
                default OSS storage URI.
        """
        if not self._oss_bucket_name:
            # If OSS bucket name is not provided, use the default OSS storage URI
            # that is configured for the workspace.
            default_oss_uri = self.workspace_api.get_default_storage_uri(
                self.workspace_id
            )
            if not default_oss_uri:
                raise RuntimeError(
                    "No default OSS URI is configured for the workspace."
                )
            oss_uri_obj = OssUriObj(default_oss_uri)
            self._oss_bucket_name = oss_uri_obj.bucket_name

        if not self._oss_endpoint:
            self._oss_endpoint = self._get_default_oss_endpoint()

    def _get_oss_auth(self):
        """Build an ``oss2`` auth object backed by the session's credentials."""
        auth = oss2.ProviderAuth(
            credentials_provider=CredentialProviderWrapper(
                config=self._credential_config,
            )
        )
        return auth

    @property
    def oss_bucket(self):
        """A OSS2 bucket instance used by the session."""
        if not self._oss_bucket_name or not self._oss_endpoint:
            self._init_oss_config()
        oss_bucket = oss2.Bucket(
            auth=self._get_oss_auth(),
            endpoint=self._oss_endpoint,
            bucket_name=self._oss_bucket_name,
        )
        return oss_bucket

    def save_config(self, config_path=None):
        """Save the configuration of the session to a local JSON file.

        Only the attributes named in ``_DEFAULT_CONFIG_KEYS`` that have a
        non-None value are written.

        Args:
            config_path (str, optional): Destination path. Defaults to the path
                in the ``PAI_CONFIG_PATH`` environment variable, or
                ``DEFAULT_CONFIG_PATH`` when that variable is unset.
        """
        attrs = {key.lstrip("_"): value for key, value in vars(self).items()}
        config = {
            key: value
            for key, value in attrs.items()
            if key in _DEFAULT_CONFIG_KEYS and value is not None
        }

        config_path = (
            config_path or os.environ.get(ENV_PAI_CONFIG_PATH) or DEFAULT_CONFIG_PATH
        )
        config_dir = os.path.dirname(config_path)
        # A bare file name has no directory part, and os.makedirs("") would raise.
        if config_dir:
            os.makedirs(config_dir, exist_ok=True)
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=4)
        logger.info("Write PAI config succeed: config_path=%s", config_path)

    def patch_oss_endpoint(self, oss_uri: str):
        """Return ``oss_uri`` with an endpoint, adding the session's if missing.

        Args:
            oss_uri (str): An OSS URI, possibly without an endpoint.

        Returns:
            str: ``oss_uri`` unchanged if it already has an endpoint, otherwise
            ``oss://{bucket}.{endpoint}/{key}`` using the session bucket's
            endpoint without its URL scheme.
        """
        oss_uri_obj = OssUriObj(oss_uri)
        if oss_uri_obj.endpoint:
            return oss_uri

        # Patch the endpoint using the current OSS bucket endpoint. Remove the
        # scheme as a prefix: ``str.lstrip`` removes a character set, not a prefix,
        # and would eat into hosts like "http://pai..." (-> "ai...").
        endpoint = self.oss_bucket.endpoint
        for scheme in ("http://", "https://"):
            if endpoint.startswith(scheme):
                endpoint = endpoint[len(scheme) :]
                break
        return "oss://{bucket_name}.{endpoint}/{key}".format(
            bucket_name=oss_uri_obj.bucket_name,
            endpoint=endpoint,
            key=oss_uri_obj.object_key,
        )

    def _get_default_oss_endpoint(self) -> str:
        """Return the internal OSS endpoint if reachable, else the internet one."""

        # OSS Endpoint document:
        # https://help.aliyun.com/document_detail/31837.html
        internet_endpoint = "oss-{}.aliyuncs.com".format(self.region_id)
        internal_endpoint = "oss-{}-internal.aliyuncs.com".format(self.region_id)

        return (
            internal_endpoint
            if is_domain_connectable(internal_endpoint)
            else internet_endpoint
        )

    def get_oss_bucket(
        self, bucket_name: str, endpoint: Optional[str] = None
    ) -> oss2.Bucket:
        """Get a OSS bucket using the credentials of the session.

        Args:
            bucket_name (str): The name of the bucket.
            endpoint (str, optional): Endpoint of the bucket. Falls back to the
                session's OSS endpoint, then to a region default endpoint.

        Returns:
            :class:`oss2.Bucket`: A OSS bucket instance.

        """
        endpoint = endpoint or self._oss_endpoint or self._get_default_oss_endpoint()
        oss_bucket = oss2.Bucket(
            auth=self._get_oss_auth(),
            endpoint=endpoint,
            bucket_name=bucket_name,
        )
        return oss_bucket

    @classmethod
    def get_storage_path_by_category(
        cls, category: str, dir_name: Optional[str] = None
    ) -> str:
        """Get an OSS storage path for the resource.

        Args:
            category (str): The category of the resource.
            dir_name (str, optional): The directory name of the resource. Defaults
                to the current local timestamp (``%Y%m%d_%H%M%S_%f``).

        Returns:
            str: A path of the form ``pai/{category}/{dir_name}/``, always ending
            with "/".

        """
        dir_name = dir_name or datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        storage_path = posixpath.join("pai", category, dir_name).strip()

        if not storage_path.endswith("/"):
            storage_path += "/"
        return storage_path

    def _find_training_instance_spec(
        self, instance_type: str
    ) -> Optional[Dict[str, Any]]:
        """Return the training ECS spec entry for ``instance_type``, or None."""
        instance_generator = make_list_resource_iterator(self.job_api.list_ecs_specs)
        return next(
            (
                item
                for item in instance_generator
                if item["InstanceType"] == instance_type
            ),
            None,
        )

    def _find_inference_instance_spec(
        self, instance_type: str
    ) -> Optional[Dict[str, Any]]:
        """Return the inference machine entry for ``instance_type``, or None."""
        instance_metas = self.service_api.describe_machine()["InstanceMetas"]
        return next(
            (item for item in instance_metas if item["InstanceType"] == instance_type),
            None,
        )

    def is_supported_training_instance(self, instance_type: str) -> bool:
        """Check if the instance type is supported for training."""
        return bool(self._find_training_instance_spec(instance_type))

    def is_gpu_training_instance(self, instance_type: str) -> bool:
        """Check if the instance type is GPU instance for training.

        Raises:
            ValueError: If the instance type is not supported for training.
        """
        machine_spec = self._find_training_instance_spec(instance_type)
        if not machine_spec:
            raise ValueError(
                f"Instance type {instance_type} is not supported for training job. "
                "Please provide a supported instance type."
            )
        return machine_spec["AcceleratorType"] == "GPU"

    def is_supported_inference_instance(self, instance_type: str) -> bool:
        """Check if the instance type is supported for inference."""
        return bool(self._find_inference_instance_spec(instance_type))

    def is_gpu_inference_instance(self, instance_type: str) -> bool:
        """Check if the instance type is GPU instance for inference.

        Raises:
            ValueError: If the instance type is not supported for deploying.
        """
        spec = self._find_inference_instance_spec(instance_type)
        if not spec:
            raise ValueError(
                f"Instance type {instance_type} is not supported for deploying. "
                "Please provide a supported instance type."
            )
        return bool(spec["GPU"])

    @staticmethod
    def get_default_oss_storage(
        workspace_id: str, cred: CredentialClient, region_id: str, network: Network
    ) -> Tuple[Optional[str], Optional[str]]:
        """Look up the workspace's default OSS storage.

        Args:
            workspace_id (str): ID of the workspace to query.
            cred (CredentialClient): Credential client used for the API call.
            region_id (str): Region of the workspace.
            network (Network): Network used to reach the workspace service. It
                also selects the internal (VPC) or internet OSS endpoint.

        Returns:
            Tuple[Optional[str], Optional[str]]: ``(bucket_name, endpoint)``, or
            ``(None, None)`` if no default OSS storage URI is configured.
        """
        acs_ws_client = ClientFactory.create_client(
            service_name=ServiceName.PAI_WORKSPACE,
            credential_client=cred,
            region_id=region_id,
            network=network,
        )
        workspace_api = WorkspaceAPI(
            acs_client=acs_ws_client,
        )
        resp = workspace_api.list_configs(
            workspace_id=workspace_id,
            config_keys=WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI,
        )
        oss_storage_uri = next(
            (
                item["ConfigValue"]
                for item in resp["Configs"]
                if item["ConfigKey"] == WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI
            ),
            None,
        )

        # Default OSS storage uri is not set.
        if not oss_storage_uri:
            return None, None
        uri_obj = OssUriObj(oss_storage_uri)
        if network == Network.VPC:
            endpoint = "oss-{}-internal.aliyuncs.com".format(region_id)
        else:
            endpoint = "oss-{}.aliyuncs.com".format(region_id)
        return uri_obj.bucket_name, endpoint

    @staticmethod
    def _get_default_credential_client() -> Optional[CredentialClient]:
        """Create a credential client from the default chain, or None on failure."""
        try:
            # Initialize the credential client with default credential chain.
            # see: https://help.aliyun.com/zh/sdk/developer-reference/v2-manage-python-access-credentials#3ca299f04bw3c
            return CredentialClient()
        except CredentialException:
            return None

    @staticmethod
    def _get_workspace_id_by_dsw_instance_id(
        dsw_instance_id: str, cred: CredentialClient, region_id: str, network: Network
    ) -> Optional[str]:
        """Get the workspace ID of a DSW instance, or None if the lookup fails."""
        dsw_client = ClientFactory.create_client(
            service_name=ServiceName.PAI_DSW,
            credential_client=cred,
            region_id=region_id,
            network=network,
        )
        try:
            resp = dsw_client.get_instance(
                dsw_instance_id, request=GetInstanceRequest()
            )
            return resp.body.workspace_id
        except TeaException as e:
            logger.warning("Failed to get instance info by dsw instance id: %s", e)
            return None
