import asyncio
import json
import logging
import os
from dataclasses import dataclass, field
from functools import reduce, lru_cache
from pathlib import Path
from typing import Any, Dict, List, Mapping, NewType, Optional, Union, TypeVar, Type
from urllib.parse import urlparse
from urllib.request import urlopen

from iact3.util import yaml, CustomSafeLoader
from alibabacloud_credentials.models import Config
from dataclasses_jsonschema import JsonSchemaMixin, ValidationError

from iact3.generate_params import ParamGenerator, IAC_NAME
from iact3.exceptions import Iact3Exception
from iact3.plugin.base_plugin import CredentialClient
from iact3.plugin.oss import OssPlugin
from iact3.plugin.ros import StackPlugin

LOG = logging.getLogger(__name__)

# Default locations. Relative paths are resolved once, at import time, against
# the current working directory.
GENERAL_CONFIG_FILE = Path(f'~/.{IAC_NAME}.yml').expanduser().resolve()
DEFAULT_PROJECT_ROOT = Path('.').resolve()
DEFAULT_TEMPLATE_PATH = DEFAULT_PROJECT_ROOT
DEFAULT_CONFIG_FILE = f'.{IAC_NAME}.yml'
DEFAULT_OUTPUT_DIRECTORY = f'{IAC_NAME}_outputs'
OVERRIDES = f'.{IAC_NAME}_overrides.yml'
DEFAULT_AUTH_FILE = Path('~/.aliyun/config.json').expanduser().resolve()

# Top-level sections of a config file.
CONFIG_KEYS = ('general', 'project', 'tests')
GENERAL, PROJECT, TESTS = CONFIG_KEYS

# Keys used by the general / project / test sections.
METADATA_KEYS = (
    'auth', 'regions', 'parameters', 'tags', 'oss_config', 'template_config', 'role_name', 'name'
)
AUTH, REGIONS, PARAMETERS, TAGS, OSS_CONFIG, TEMPLATE_CONFIG, ROLE_NAME, NAME = METADATA_KEYS

# Template-source keys handled by TemplateConfig.
TEMPLATE_CONFIG_ITEMS = (
    'template_body', 'template_url', 'template_id', 'template_version', 'template_location'
)
TEMPLATE_BODY, TEMPLATE_URL, TEMPLATE_ID, TEMPLATE_VERSION, TEMPLATE_LOCATION = TEMPLATE_CONFIG_ITEMS

# Keys of the `oss_config` section.
OSS_CONFIG_ITEMS = ('bucket_name', 'bucket_region')
BUCKET_NAME, BUCKET_REGION = OSS_CONFIG_ITEMS


# Human-readable description and examples for each config key. They are
# attached as dataclass field metadata on GeneralConfig / ProjectConfig below.
METADATA: Mapping[str, Mapping[str, Any]] = {
    AUTH: {
        'Description': 'Aliyun authentication section.',
        'Examples': [{
            'name': 'default',
            'location': '~/.aliyun/config.json'
        }]
    },
    REGIONS: {
        'Description': 'List of aliyun regions.',
        'Examples': [['cn-hangzhou', 'cn-beijing']]
    },
    PARAMETERS: {
        'Description': 'Parameter key-values to pass to template.',
        'Examples': [{
            'MyParameterKey1': 'MyParameterValue1',
            'MyParameterKey2': 'MyParameterValue2'
        }]
    },
    OSS_CONFIG: {
        'Description': 'Oss bucket configuration, include BucketName, BucketRegion and etc.',
        'Examples': [{
            'bucket_name': 'ExampleName',
            'bucket_region': 'cn-hangzhou'
        }]
    },
    TEMPLATE_CONFIG: {
        # Previously a copy of the OSS_CONFIG description; describe the template source instead.
        'Description': 'Template configuration: template body, URL, ID (with optional version) '
                       'or local template location.',
        'Examples': [{
            'template_body': '{"ROSTemplateFormatVersion": "2015-09-01"}',
            'template_url': 'oss://ros-template/demo',
            'template_id': '5ecd1e10-b0e9-4389-a565-e4c15efc****',
            'template_version': 'v1',
            'template_location': 'ros-template/'
        }]
    },
    TAGS: {
        'Description': 'TAGS to apply to template.',
        'Examples': [{
            'env': 'Product',
            'app': 'MyApp'
        }],
    },
    ROLE_NAME: {
        'Description': 'Role name to use while running test.',
        'Examples': ['my-test-role']
    }
}

# types
ParameterKey = NewType('ParameterKey', str)
TagKey = NewType('TagKey', str)
TagValue = NewType('TagValue', str)
Region = NewType('Region', str)
OssConfigKey = NewType('OssConfigKey', str)
OssConfigValue = NewType('OssConfigValue', str)


@dataclass
class Auth(JsonSchemaMixin, allow_additional_props=False):
    # Selects an aliyun CLI profile. `name` is the profile name and `location`
    # is the path of the CLI config file. When unset, the file's `current`
    # profile and DEFAULT_AUTH_FILE are used instead.
    name: Optional[str] = field(default=None)
    location: Optional[str] = field(default=None)

    def __post_init__(self):
        # Resolve eagerly so `self.credential` always exists. It is None when no
        # usable profile could be loaded.
        self.credential = self._get_credential()

    def __hash__(self):
        # Hashes the same two fields the generated __eq__ compares.
        return hash((self.name, self.location))

    def _get_credential(self) -> Union[CredentialClient, None]:
        """Return a CredentialClient for the selected profile, or None.

        Results are cached per ``(name, location)`` pair. Equal Auth objects
        therefore share one client, and the cache does not keep the Auth
        instances themselves alive.
        """
        return self._load_credential(self.name, self.location)

    @staticmethod
    @lru_cache(maxsize=128)
    def _load_credential(name: Optional[str], location: Optional[str]) -> Union[CredentialClient, None]:
        """Build a CredentialClient from an aliyun CLI JSON config file.

        Args:
            name: profile name. Defaults to the file's ``current`` profile.
            location: config file path, with ``~`` expanded. Defaults to
                DEFAULT_AUTH_FILE.

        Returns:
            A CredentialClient, or None in any of these cases: the file is
            missing or unreadable, no profile is selected or matched, or the
            profile ``mode`` is not AK, StsToken, RamRoleArn or EcsRamRole.
        """
        file_path = Path(location).expanduser().resolve() if location else DEFAULT_AUTH_FILE
        if not file_path.is_file():
            return None
        try:
            with open(file_path, 'r', encoding='utf-8') as file_handle:
                config = json.load(file_handle)
            profile_name = name or config.get('current')
            if not profile_name:
                return None
            # `profiles` may be absent or null in a hand-edited file.
            profiles = config.get('profiles') or []
            specified_profile = next(
                (profile for profile in profiles if profile.get('name') == profile_name),
                None,
            )
        except Exception as e:
            # Best effort: a malformed config file simply means "no credential".
            LOG.debug(str(e), exc_info=True)
            return None

        if not specified_profile:
            return None
        mode = specified_profile.get('mode')
        if mode == 'AK':
            specified_config = Config(
                type='access_key',
                access_key_id=specified_profile.get('access_key_id'),
                access_key_secret=specified_profile.get('access_key_secret')
            )
        elif mode == 'StsToken':
            specified_config = Config(
                type='sts',
                access_key_id=specified_profile.get('access_key_id'),
                access_key_secret=specified_profile.get('access_key_secret'),
                security_token=specified_profile.get('sts_token')
            )
        elif mode == 'RamRoleArn':
            specified_config = Config(
                type='ram_role_arn',
                access_key_id=specified_profile.get('access_key_id'),
                access_key_secret=specified_profile.get('access_key_secret'),
                security_token=specified_profile.get('sts_token'),
                role_arn=specified_profile.get('ram_role_arn'),
                role_session_name=specified_profile.get('ram_session_name'),
                policy=specified_profile.get('policy', ''),
                role_session_expiration=specified_profile.get('expired_seconds', 900)
            )
        elif mode == 'EcsRamRole':
            specified_config = Config(
                type='ecs_ram_role',
                role_name=specified_profile.get('ram_role_name')
            )
        else:
            return None
        return CredentialClient(config=specified_config)


@dataclass
class OssCallbackConfig(JsonSchemaMixin, allow_additional_props=False):
    # OSS callback settings (presumably for upload callbacks; confirm against the
    # OSS plugin). Every field is optional and defaults to None.
    callback_url: Optional[str] = None
    callback_host: Optional[str] = None
    callback_body: Optional[str] = None
    callback_body_type: Optional[str] = None
    callback_var_params: Optional[dict] = None


@dataclass
class OssConfig(JsonSchemaMixin, allow_additional_props=False):
    # OSS bucket settings. Every field is optional; callback_params defaults to
    # an empty OssCallbackConfig rather than None.
    bucket_name: Optional[str] = None
    bucket_region: Optional[str] = None
    object_prefix: Optional[str] = None
    callback_params: Optional[OssCallbackConfig] = field(default_factory=OssCallbackConfig)

    def validate_bucket(self, plugin: OssPlugin):
        """Raise Iact3Exception unless ``plugin`` reports that the bucket exists."""
        if plugin.bucket_exist():
            return
        raise Iact3Exception(f'oss bucket {self.bucket_name} in {self.bucket_region} region is not exist')


@dataclass
class TemplateConfig(JsonSchemaMixin, allow_additional_props=False):
    # Where the template comes from. generate_template_args() checks, in order:
    # template_id / template_body, then template_url, then template_location (a
    # local file or directory). tf_version selects the ROS Terraform transform
    # used when a directory of .tf files is packaged.
    template_body: Optional[str] = field(default=None)
    template_url: Optional[str] = field(default=None)
    template_id: Optional[str] = field(default=None)
    template_version: Optional[str] = field(default=None)
    template_location: Optional[str] = field(default=None)
    tf_version: Optional[str] = field(default=None)

    def __hash__(self):
        # tf_version is left out. This is still consistent with the generated
        # __eq__, because equal objects always hash equally.
        return hash((
            self.template_body, self.template_url,
            self.template_id, self.template_version,
            self.template_location))

    def _get_tf_version(self):
        """Return the configured Terraform transform, defaulting to 'Aliyun::Terraform-v1.2'."""
        return self.tf_version or 'Aliyun::Terraform-v1.2'

    def _get_template_location(self, template: Optional[Union[str, Path]] = None) -> Union[Path, dict, None]:
        """Locate a template under ``template`` (default: DEFAULT_TEMPLATE_PATH).

        Returns:
            The path itself if it is a file. Otherwise, the first
            ``*.template.{json,yaml,yml}`` file found in that directory. Failing
            that, an inline Terraform-transform template built from the ``.tf``
            files below it, or None if nothing was found.
        """
        # expanduser() keeps '~' handling consistent with generate_template_args.
        template_path = Path(template).expanduser().resolve() if template else DEFAULT_TEMPLATE_PATH
        if template_path.is_file():
            return template_path
        for suffix in ('json', 'yaml', 'yml'):
            template_file = next(template_path.glob(f'*.template.{suffix}'), None)
            if template_file is not None:
                return template_file
        return self._build_terraform_template(template_path)

    def _build_terraform_template(self, root: Path) -> Optional[dict]:
        """Package every ``.tf`` file under ``root`` (recursively) into a ROS template.

        Workspace keys are '/'-separated paths relative to the deepest directory
        that contains all the .tf files. Returns None when no .tf file exists.
        """
        tf_content = {}
        for path, _dirs, files in os.walk(root):
            for file in files:
                if not file.endswith('.tf'):
                    continue
                file_path = os.path.join(path, file)
                with open(file_path, 'r', encoding='utf-8') as f:
                    tf_content[file_path] = f.read()

        if not tf_content:
            return None

        # Use the common *directory*, not the common path of the files. For a
        # single file, the latter is the file itself, which would leave the
        # absolute path as its workspace key. relpath() strips only the leading
        # prefix, unlike str.replace().
        common_dir = os.path.commonpath([os.path.dirname(p) for p in tf_content])
        work_space = {
            Path(os.path.relpath(p, common_dir)).as_posix(): content
            for p, content in tf_content.items()
        }
        return {
            'ROSTemplateFormatVersion': '2015-09-01',
            'Transform': self._get_tf_version(),
            'Workspace': work_space
        }

    def generate_template_args(self) -> dict:
        """Return the template arguments for this config as a dict.

        - ``template_id`` or ``template_body`` set: the config's fields are
          returned unchanged.
        - ``template_url`` with an oss/http/https scheme: returned unchanged.
        - ``file://`` URL: the URL is read and replaced by ``template_body``.
        - Otherwise: ``template_location`` (or the default path) is resolved and
          its content is stored as JSON in ``template_body``.

        The computation is cached, but a fresh dict is returned on every call,
        so callers may mutate it without corrupting the cache.

        Raises:
            Iact3Exception: unsupported URL scheme, unreadable URL, no template
                found, or a template file that cannot be parsed.
        """
        return dict(self._cached_template_args())

    @lru_cache
    def _cached_template_args(self) -> dict:
        """Compute generate_template_args(). The result must not be mutated."""
        result = self.to_dict()
        if self.template_id or self.template_body:
            return result

        if self.template_url:
            template_url = self.template_url
            components = urlparse(template_url)
            if components.scheme in ('oss', 'http', 'https'):
                return result
            elif components.scheme == 'file':
                try:
                    # Close the response instead of leaking the handle.
                    with urlopen(template_url) as response:
                        tpl_body = response.read()
                    result[TEMPLATE_BODY] = tpl_body.decode('utf-8')
                    result.pop(TEMPLATE_URL)
                    return result
                except Exception as ex:
                    raise Iact3Exception(f'failed to retrieve {template_url}: {ex}') from ex
            else:
                raise Iact3Exception(f'template url {template_url} is not legally.')

        tpl_path = result.pop(TEMPLATE_LOCATION, None)
        tpl_item = self._get_template_location(tpl_path)
        if tpl_item is None:
            msg = (f'Could not find template in {tpl_path or DEFAULT_TEMPLATE_PATH} directory. '
                   'Template files need to end with .template.json or .template.yaml or .template.yml, '
                   'or the directory must contain Terraform .tf files.')
            raise Iact3Exception(msg)
        elif isinstance(tpl_item, dict):
            result[TEMPLATE_BODY] = json.dumps(tpl_item)
        else:
            file_path = Path(tpl_item).expanduser().resolve()
            if not file_path.is_file():
                return result
            try:
                with open(str(file_path), 'r', encoding='utf-8') as file_handle:
                    tpl_body = yaml.load(file_handle, Loader=CustomSafeLoader)
                    result[TEMPLATE_BODY] = json.dumps(tpl_body)
            except Exception as e:
                LOG.debug(str(e), exc_info=True)
                raise Iact3Exception(f'can not find a template: {str(e)}') from e
        return result


@dataclass
class GeneralConfig(JsonSchemaMixin):
    """General configuration settings."""

    # Field order matters: ProjectConfig and TestConfig append their own fields
    # after these.
    auth: Auth = field(metadata=METADATA[AUTH], default_factory=Auth)
    regions: Optional[List[Region]] = field(metadata=METADATA[REGIONS], default_factory=list)
    parameters: Optional[Dict[ParameterKey, Any]] = field(metadata=METADATA[PARAMETERS], default_factory=dict)
    # No METADATA entry exists for this key.
    parameters_order: Optional[List[str]] = field(default_factory=list)
    tags: Optional[Dict[TagKey, TagValue]] = field(metadata=METADATA[TAGS], default_factory=dict)
    oss_config: Optional[OssConfig] = field(metadata=METADATA[OSS_CONFIG], default_factory=OssConfig)


@dataclass
class ProjectConfig(GeneralConfig):
    """Project specific configuration section"""

    # Project-level settings, added after every field inherited from GeneralConfig.
    name: Optional[str] = field(default='iact3-default-project-name')
    role_name: Optional[str] = field(metadata=METADATA[ROLE_NAME], default_factory=str)
    template_config: TemplateConfig = field(metadata=METADATA[TEMPLATE_CONFIG], default_factory=TemplateConfig)


@dataclass
class TestConfig(ProjectConfig):
    """Test specific configuration section."""

    def __post_init__(self):
        # Runtime-only attributes, not dataclass fields. BaseConfig.get_all_configs
        # fills them in for each per-region copy of a test.
        self.test_name = self.region = self.error = None


T = TypeVar('T', bound='BaseConfig')


@dataclass
class BaseConfig(JsonSchemaMixin):
    # Root of the config tree: the `general`, `project` and `tests` sections.
    general: GeneralConfig = field(default_factory=GeneralConfig)
    project: ProjectConfig = field(default_factory=ProjectConfig)
    tests: Dict[str, TestConfig] = field(default_factory=dict)

    def __post_init__(self):
        # Lazily filled caches, see _get_test_regions() and get_credential().
        self._all_regions = None
        self._credential = None

    @classmethod
    def merge(cls, base: Dict, new: Dict) -> Dict:
        """Recursively merge ``new`` on top of ``base`` and return a new dict.

        Keys in ``new`` override keys in ``base``, and nested dicts are merged
        recursively. The exception is the ``parameters`` key, where values from
        ``base`` take precedence over values from ``new``.

        ``None`` is treated as an empty dict. Neither input is modified.
        """
        result = dict(base or {})
        for item, value in (new or {}).items():
            if item == PARAMETERS:
                # Reversed precedence (base wins). Build a new dict rather than
                # updating `value` in place, which would mutate the caller's input.
                merged_params = dict(value or {})
                merged_params.update(result.get(item) or {})
                result[item] = merged_params
                continue
            current = result.get(item)
            if isinstance(value, dict) and isinstance(current, dict):
                result[item] = cls.merge(current, value)
            else:
                result[item] = value
        return result

    @classmethod
    def generate_from_file(cls, file_path: Path, fail_ok=True, validate=True) -> dict:
        """Load a YAML config file into a dict.

        With ``fail_ok``, a missing file yields ``{}``. Load and validation
        errors are then only logged, and whatever was loaded (possibly ``{}``,
        possibly an unvalidated dict) is returned. Without ``fail_ok`` these
        errors are re-raised.
        """
        config_dict = {}
        if not file_path.is_file() and fail_ok:
            return config_dict
        try:
            with open(str(file_path), 'r', encoding='utf-8') as file_handle:
                config_dict = yaml.load(file_handle, Loader=CustomSafeLoader)
            if validate:
                try:
                    cls.from_dict(config_dict)
                except ValidationError as e:
                    LOG.warning(f'config from {file_path} is illegal.')
                    LOG.debug(str(e), exc_info=True)
                    if not fail_ok:
                        raise e
            return config_dict
        except Exception as e:
            try:
                with open(str(file_path), 'r', encoding='utf-8') as file_handle:
                    file_content = file_handle.read()
            except Exception as ex:
                LOG.warning(f'failed to load config from {file_path}')
                raise ex

            LOG.warning(f'failed to load config from {file_path}, file content is {file_content}')
            LOG.debug(str(e), exc_info=True)
            if not fail_ok:
                raise e
        return config_dict

    @classmethod
    def create(cls: Type[T],
               global_config_path: Path = GENERAL_CONFIG_FILE,
               project_config_file: Path = DEFAULT_CONFIG_FILE,
               args: Optional[dict] = None,
               project_path: str = None,
               fail_ok: bool = False) -> T:
        """Build a config from the global file, the project file and ``args``.

        Sources are merged in that order, with later sources overriding earlier
        ones (except ``parameters``, see merge). The general section is then
        merged into the project section, and the result into every test.
        """
        if not project_path:
            project_path = DEFAULT_PROJECT_ROOT
        project_root: Path = Path(project_path).expanduser().resolve()
        project_config_path = project_root / project_config_file
        sources = [
            cls.generate_from_file(global_config_path),
            cls.generate_from_file(project_config_path, fail_ok=fail_ok),
            args or {}
        ]
        config = reduce(cls.merge, sources)
        general_config = config.get(GENERAL, {})
        merged_project_config = cls.merge(general_config, config.get(PROJECT, {}))
        merged_test_configs = {
            key: cls.merge(merged_project_config, value) for key, value in config.get(TESTS, {}).items()
        }
        return cls.from_dict({
            GENERAL: general_config,
            PROJECT: merged_project_config,
            TESTS: merged_test_configs
        })

    async def get_all_configs(self, test_names: str = None):
        """Expand the selected tests into one TestConfig per region.

        Args:
            test_names: comma-separated test names to include. All tests are
                included when it is empty. Whitespace around names is ignored.

        Returns:
            A list of TestConfig objects with ``region``, ``test_name``,
            resolved ``parameters`` and, if resolution failed, ``error`` set.

        Raises:
            Iact3Exception: if a configured OSS bucket does not exist.
        """
        results = []
        base = self.tests
        test_names = [item.strip() for item in test_names.split(',')] if test_names else []
        param_tasks = []
        for name, config in base.items():
            if test_names and name not in test_names:
                continue
            # A null `regions` (e.g. an empty YAML key) is treated like no regions.
            regions = [region.lower() for region in config.regions or []]
            # 'all', or no regions at all, means every available region.
            if 'all' in regions or not regions:
                all_regions = await self._get_test_regions()
                regions = all_regions

            template_args = config.template_config.generate_template_args()
            if TEMPLATE_LOCATION in template_args:
                template_args.pop(TEMPLATE_LOCATION)
            config.template_config = TemplateConfig.from_dict(template_args)

            for region in regions:
                region_config = TestConfig.from_dict(config.to_dict())
                region_config.region = region
                region_config.test_name = name
                oss_config = region_config.oss_config
                bucket_name = oss_config.bucket_name
                if bucket_name:
                    plugin = OssPlugin(
                        region_id=region,
                        bucket_name=bucket_name,
                        credential=region_config.auth.credential
                    )
                    oss_config.validate_bucket(plugin)
                resolved_parameters_task = ParamGenerator.result(region_config)
                param_tasks.append(asyncio.create_task(resolved_parameters_task))
                results.append(region_config)
        resolved_parameters = await asyncio.gather(*param_tasks)
        for config, params in zip(results, resolved_parameters):
            assert config.test_name == params.name
            assert config.region == params.region
            if params.error:
                config.error = params.error
            config.parameters = params.parameters
        return results

    def get_oss_config(self):
        """Return ``(bucket_name, bucket_region)`` from the project OSS config."""
        oss_config = self.project.oss_config
        return oss_config.bucket_name, oss_config.bucket_region

    async def _get_test_regions(self):
        """Return the region list from StackPlugin.get_regions(), cached after the first call.

        The query is sent through cn-hangzhou using the general auth credential.
        """
        if self._all_regions is None:
            plugin = StackPlugin('cn-hangzhou', self.general.auth.credential)
            self._all_regions = await plugin.get_regions()
        return self._all_regions

    def get_credential(self):
        """Return the project auth credential (may be None), cached after the first call."""
        if self._credential is None:
            self._credential = self.project.auth.credential
        return self._credential
