iact3/config.py (432 lines of code) (raw):
import asyncio
import json
import logging
import os
from dataclasses import dataclass, field
from functools import reduce, lru_cache
from pathlib import Path
from typing import Any, Dict, List, Mapping, NewType, Optional, Union, TypeVar, Type
from urllib.parse import urlparse
from urllib.request import urlopen
from iact3.util import yaml, CustomSafeLoader
from alibabacloud_credentials.models import Config
from dataclasses_jsonschema import JsonSchemaMixin, ValidationError
from iact3.generate_params import ParamGenerator, IAC_NAME
from iact3.exceptions import Iact3Exception
from iact3.plugin.base_plugin import CredentialClient
from iact3.plugin.oss import OssPlugin
from iact3.plugin.ros import StackPlugin
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)

# User-wide configuration file: ~/.<IAC_NAME>.yml, made absolute.
GENERAL_CONFIG_FILE = (Path('~') / '.{}.yml'.format(IAC_NAME)).expanduser().resolve()

# The project root defaults to the current working directory; templates are
# looked up in the same place unless configured otherwise.
DEFAULT_PROJECT_ROOT = Path('.').resolve()
DEFAULT_TEMPLATE_PATH = DEFAULT_PROJECT_ROOT

# File and directory names (relative to the project root) derived from IAC_NAME.
DEFAULT_CONFIG_FILE = '.{}.yml'.format(IAC_NAME)
DEFAULT_OUTPUT_DIRECTORY = '{}_outputs'.format(IAC_NAME)
OVERRIDES = '.{}_overrides.yml'.format(IAC_NAME)

# Default credential file that Auth reads when no explicit location is given.
DEFAULT_AUTH_FILE = Path('~', '.aliyun', 'config.json').expanduser().resolve()
# Each group below uses a chained assignment: the tuple of strings is bound to
# the collective name (e.g. CONFIG_KEYS) and, at the same time, unpacked into
# the individual constants (e.g. GENERAL, PROJECT, TESTS).

# Top-level sections of a configuration file.
CONFIG_KEYS = (
    GENERAL, PROJECT, TESTS
) = (
    'general', 'project', 'tests'
)

# Keys that may appear inside a configuration section.
METADATA_KEYS = (
    AUTH, REGIONS, PARAMETERS, TAGS, OSS_CONFIG, TEMPLATE_CONFIG, ROLE_NAME, NAME
) = (
    'auth', 'regions', 'parameters', 'tags', 'oss_config', 'template_config', 'role_name', 'name'
)

# Keys of the `template_config` section.
TEMPLATE_CONFIG_ITEMS = (
    TEMPLATE_BODY, TEMPLATE_URL, TEMPLATE_ID, TEMPLATE_VERSION, TEMPLATE_LOCATION
) = (
    'template_body', 'template_url', 'template_id', 'template_version', 'template_location'
)

# Keys of the `oss_config` section.
OSS_CONFIG_ITEMS = (
    BUCKET_NAME, BUCKET_REGION
) = (
    'bucket_name', 'bucket_region'
)

# Per-key description and example values; the config dataclasses in this
# module attach these entries to their fields through `field(metadata=...)`.
METADATA: Mapping[str, Mapping[str, Any]] = {
    AUTH: {
        'Description': 'Aliyun authentication section.',
        'Examples': [{
            'name': 'default',
            'location': '~/.aliyun/config.json'
        }]
    },
    REGIONS: {
        'Description': 'List of aliyun regions.',
        'Examples': [['cn-hangzhou', 'cn-beijing']]
    },
    PARAMETERS: {
        'Description': 'Parameter key-values to pass to template.',
        'Examples': [{
            'MyParameterKey1': 'MyParameterValue1',
            'MyParameterKey2': 'MyParameterValue2'
        }]
    },
    OSS_CONFIG: {
        'Description': 'Oss bucket configuration, include BucketName, BucketRegion and etc.',
        'Examples': [{
            'bucket_name': 'ExampleName',
            'bucket_region': 'cn-hangzhou'
        }]
    },
    TEMPLATE_CONFIG: {
        # Previously a copy of the OSS_CONFIG description; it now describes
        # the template section that the examples below actually illustrate.
        'Description': 'Template configuration, include TemplateBody, TemplateUrl, TemplateId and etc.',
        'Examples': [{
            'template_body': '{"ROSTemplateFormatVersion": "2015-09-01"}',
            'template_url': 'oss://ros-template/demo',
            'template_id': '5ecd1e10-b0e9-4389-a565-e4c15efc****',
            'template_version': 'v1',
            'template_location': 'ros-template/'
        }]
    },
    TAGS: {
        'Description': 'TAGS to apply to template.',
        'Examples': [{
            'env': 'Product',
            'app': 'MyApp'
        }],
    },
    ROLE_NAME: {
        'Description': 'Role name to use while running test.',
        'Examples': ['my-test-role']
    }
}
# Distinct static types for the string-valued keys/values used in the
# configuration dataclasses. NewType adds no runtime wrapper: calling one of
# these simply returns the string it was given.

# Region identifiers.
Region = NewType('Region', str)

# Template parameter names.
ParameterKey = NewType('ParameterKey', str)

# Tag name / tag value pairs.
TagKey = NewType('TagKey', str)
TagValue = NewType('TagValue', str)

# OSS configuration name / value pairs.
OssConfigKey = NewType('OssConfigKey', str)
OssConfigValue = NewType('OssConfigValue', str)
# Points at one named profile inside a JSON credential file (`location`,
# falling back to DEFAULT_AUTH_FILE) and resolves it to a CredentialClient.
# The resolved client (or None) is stored on `self.credential`.
@dataclass
class Auth(JsonSchemaMixin, allow_additional_props=False):
    # Profile name to look up; when empty, the file's 'current' entry is used.
    name: Optional[str] = None
    # Path of the credential file; when empty, DEFAULT_AUTH_FILE is used.
    location: Optional[str] = None

    def __post_init__(self):
        '''Resolve the credential right after the dataclass fields are set.'''
        self.credential = self._get_credential()

    def __hash__(self):
        '''Hash on (name, location) so instances can serve as lru_cache keys.'''
        return hash((self.name, self.location))

    @lru_cache
    def _get_credential(self) -> Union[CredentialClient, None]:
        '''Build a CredentialClient from the selected profile.

        Returns None when the credential file is missing, cannot be read or
        parsed, no profile name can be determined, no profile matches, or the
        profile's 'mode' is not one of AK / StsToken / RamRoleArn / EcsRamRole.
        '''
        if self.location:
            config_file = Path(self.location).expanduser().resolve()
        else:
            config_file = DEFAULT_AUTH_FILE
        if not config_file.is_file():
            return None

        try:
            with open(config_file, 'r', encoding='utf-8') as stream:
                cli_config = json.load(stream)
            fallback_name = cli_config.get('current')
            wanted = self.name or fallback_name
            if not wanted:
                return None
            # First profile whose 'name' matches, or None when there is none.
            selected = next(
                (entry for entry in cli_config.get('profiles') if entry.get('name') == wanted),
                None
            )
        except Exception as e:
            # Best effort: any read/parse/shape problem means "no credential".
            LOG.debug(str(e), exc_info=True)
            return None

        if not selected:
            return None

        read = selected.get
        mode = read('mode')
        if mode == 'AK':
            settings = dict(
                type='access_key',
                access_key_id=read('access_key_id'),
                access_key_secret=read('access_key_secret'),
            )
        elif mode == 'StsToken':
            settings = dict(
                type='sts',
                access_key_id=read('access_key_id'),
                access_key_secret=read('access_key_secret'),
                security_token=read('sts_token'),
            )
        elif mode == 'RamRoleArn':
            settings = dict(
                type='ram_role_arn',
                access_key_id=read('access_key_id'),
                access_key_secret=read('access_key_secret'),
                security_token=read('sts_token'),
                role_arn=read('ram_role_arn'),
                role_session_name=read('ram_session_name'),
                policy=read('policy', ''),
                role_session_expiration=read('expired_seconds', 900),
            )
        elif mode == 'EcsRamRole':
            settings = dict(
                type='ecs_ram_role',
                role_name=read('ram_role_name'),
            )
        else:
            return None
        return CredentialClient(config=Config(**settings))
# Optional callback settings nested inside OssConfig (`callback_params`).
# Every field defaults to None; unknown keys are rejected
# (allow_additional_props=False).
@dataclass
class OssCallbackConfig(JsonSchemaMixin, allow_additional_props=False):
    callback_url: Optional[str] = None
    callback_host: Optional[str] = None
    callback_body: Optional[str] = None
    callback_body_type: Optional[str] = None
    callback_var_params: Optional[dict] = None
# OSS bucket settings: bucket name/region, an object prefix and optional
# callback parameters. Unknown keys are rejected (allow_additional_props=False).
@dataclass
class OssConfig(JsonSchemaMixin, allow_additional_props=False):
    bucket_name: Optional[str] = None
    bucket_region: Optional[str] = None
    object_prefix: Optional[str] = None
    # A fresh OssCallbackConfig per instance.
    callback_params: Optional[OssCallbackConfig] = field(default_factory=OssCallbackConfig)

    def validate_bucket(self, plugin: OssPlugin):
        '''Raise Iact3Exception unless `plugin.bucket_exist()` reports the bucket.

        The error message names this config's bucket_name and bucket_region.
        '''
        if plugin.bucket_exist():
            return
        raise Iact3Exception(f'oss bucket {self.bucket_name} in {self.bucket_region} region is not exist')
# Describes where the template under test comes from: an inline body, a URL,
# a template id/version, or a local path (a file or a directory to search).
# `generate_template_args` turns this description into a plain dict.
@dataclass
class TemplateConfig(JsonSchemaMixin, allow_additional_props=False):
    template_body: Optional[str] = field(default=None)
    template_url: Optional[str] = field(default=None)
    template_id: Optional[str] = field(default=None)
    template_version: Optional[str] = field(default=None)
    template_location: Optional[str] = field(default=None)
    tf_version: Optional[str] = field(default=None)

    def __hash__(self):
        '''Hash on the template source fields (tf_version is not included).

        Defined so instances are hashable, which the lru_cache on
        `generate_template_args` requires.
        '''
        return hash((
            self.template_body, self.template_url,
            self.template_id, self.template_version,
            self.template_location))

    def _get_tf_version(self):
        '''Return the configured Transform value, or the built-in default.'''
        return self.tf_version or 'Aliyun::Terraform-v1.2'

    def _build_tf_workspace(self, template_path: Path) -> Optional[dict]:
        '''Wrap every `.tf` file found under `template_path` into a template dict.

        Returns None when no `.tf` file exists. Otherwise the 'Workspace'
        entry maps each file's path, relative to the deepest directory shared
        by all of the files and written with '/' separators, to its content.
        '''
        tf_content = {}
        for path, _dirs, files in os.walk(template_path):
            for file in files:
                if not file.endswith('.tf'):
                    continue
                file_path = os.path.join(path, file)
                with open(file_path, 'r', encoding='utf-8') as f:
                    tf_content[file_path] = f.read()
        if not tf_content:
            return None
        # Take the common *directory* rather than the common path of the files:
        # with a single file the common path is the file itself, which would
        # leave its workspace key as the full absolute path.
        common_dir = os.path.commonpath([os.path.dirname(p) for p in tf_content])
        work_space = {
            os.path.relpath(p, common_dir).replace(os.sep, '/'): value
            for p, value in tf_content.items()
        }
        return {
            'ROSTemplateFormatVersion': '2015-09-01',
            'Transform': self._get_tf_version(),
            'Workspace': work_space
        }

    def _get_template_location(self, template: Union[str, Path] = None) -> Union[Path, dict, None]:
        '''Locate the template for `template` (default: DEFAULT_TEMPLATE_PATH).

        Returns:
            * the path itself when it is an existing file;
            * otherwise the first `*.template.json` / `.yaml` / `.yml` file
              directly inside the directory (suffixes tried in that order);
            * otherwise a dict built from the `.tf` files under the directory;
            * None when nothing was found.
        '''
        template_path = Path(template).resolve() if template else DEFAULT_TEMPLATE_PATH
        if template_path.is_file():
            return template_path
        for suffix in ('json', 'yaml', 'yml'):
            template_file = next(template_path.glob(f'*.template.{suffix}'), None)
            if template_file is not None:
                return template_file
        return self._build_tf_workspace(template_path)

    @lru_cache
    def generate_template_args(self) -> dict:
        '''Return this config as a dict, loading a local template into the body.

        * With template_id or template_body set, the dict form is returned as is.
        * With an oss/http/https template_url, the dict form is returned as is;
          a file:// url is read and replaced by 'template_body'.
        * Otherwise the template is looked up at template_location (or the
          default template path) and stored, JSON-encoded, as 'template_body'.

        The result is memoized by lru_cache, so the same dict object is handed
        to every caller; copy it before modifying it.

        Raises:
            Iact3Exception: unsupported url scheme, unreadable file url, no
                template found, or a template file that cannot be loaded.
        '''
        result = self.to_dict()
        if self.template_id or self.template_body:
            return result

        if self.template_url:
            template_url = self.template_url
            components = urlparse(template_url)
            if components.scheme in ('oss', 'http', 'https'):
                return result
            if components.scheme != 'file':
                raise Iact3Exception(f'template url {template_url} is not legally.')
            try:
                # Context manager so the response is closed once read.
                with urlopen(template_url) as response:
                    tpl_body = response.read()
                result[TEMPLATE_BODY] = tpl_body.decode('utf-8')
                result.pop(TEMPLATE_URL)
                return result
            except Exception as ex:
                raise Iact3Exception(f'failed to retrieve {template_url}: {ex}') from ex

        tpl_path = result.pop(TEMPLATE_LOCATION, None)
        tpl_item = self._get_template_location(tpl_path)
        if tpl_item is None:
            msg = f'Could not find template in {tpl_path or DEFAULT_TEMPLATE_PATH} directory. ' \
                  f'Template files need end with .template.json or .template.yaml or .template.yml.'
            raise Iact3Exception(msg)
        if isinstance(tpl_item, dict):
            # Terraform workspace assembled from .tf files.
            result[TEMPLATE_BODY] = json.dumps(tpl_item)
            return result

        file_path = Path(tpl_item).expanduser().resolve()
        if not file_path.is_file():
            return result
        try:
            with open(str(file_path), 'r', encoding='utf-8') as file_handle:
                tpl_body = yaml.load(file_handle, Loader=CustomSafeLoader)
            result[TEMPLATE_BODY] = json.dumps(tpl_body)
        except Exception as e:
            LOG.debug(str(e), exc_info=True)
            raise Iact3Exception(f'can not find a template: {str(e)}') from e
        return result
@dataclass
class GeneralConfig(JsonSchemaMixin):
    '''General configuration settings.'''

    # Which credential profile to use; a fresh Auth() when not configured.
    auth: Auth = field(metadata=METADATA[AUTH], default_factory=Auth)

    # Regions to run in; empty list by default.
    regions: Optional[List[Region]] = field(metadata=METADATA[REGIONS], default_factory=list)

    # Template parameter values, plus an optional explicit ordering of them.
    parameters: Optional[Dict[ParameterKey, Any]] = field(
        metadata=METADATA[PARAMETERS], default_factory=dict)
    parameters_order: Optional[List[str]] = field(default_factory=list)

    # Tag key/value pairs; empty dict by default.
    tags: Optional[Dict[TagKey, TagValue]] = field(metadata=METADATA[TAGS], default_factory=dict)

    # OSS bucket settings; a fresh OssConfig() when not configured.
    oss_config: Optional[OssConfig] = field(metadata=METADATA[OSS_CONFIG], default_factory=OssConfig)
@dataclass
class ProjectConfig(GeneralConfig):
    '''Project specific configuration section'''

    # Project name, with a fixed placeholder default.
    name: Optional[str] = 'iact3-default-project-name'

    # Role name; defaults to the empty string produced by str().
    role_name: Optional[str] = field(metadata=METADATA[ROLE_NAME], default_factory=str)

    # Source of the template; a fresh TemplateConfig() when not configured.
    template_config: TemplateConfig = field(
        metadata=METADATA[TEMPLATE_CONFIG], default_factory=TemplateConfig)
@dataclass
class TestConfig(ProjectConfig):
    '''Test specific configuration section.'''

    def __post_init__(self):
        # Plain attributes (not dataclass fields), all starting out as None;
        # they are filled in later, once a concrete test run is built.
        self.test_name = self.region = self.error = None
T = TypeVar('T', bound='BaseConfig')


# Root configuration object: a `general` section, a `project` section and a
# mapping of test name -> test section. `create` builds it by layering the
# global file, the project file and explicit arguments.
@dataclass
class BaseConfig(JsonSchemaMixin):
    general: GeneralConfig = field(default_factory=GeneralConfig)
    project: ProjectConfig = field(default_factory=ProjectConfig)
    tests: Dict[str, TestConfig] = field(default_factory=dict)

    def __post_init__(self):
        # Lazily filled caches, see _get_test_regions() and get_credential().
        self._all_regions = None
        self._credential = None

    @classmethod
    def merge(cls, base: Optional[Dict], new: Optional[Dict]) -> Dict:
        '''Return a new dict holding `base` overlaid with `new`.

        Nested dicts are merged recursively and, for every key except
        `parameters`, the value from `new` wins. For `parameters` the
        precedence is reversed: entries already present in `base` override
        the ones from `new`.

        Neither argument is modified; None is treated as an empty dict, both
        for the arguments themselves and for a `parameters` value.
        '''
        if base is None:
            base = {}
        if new is None:
            new = {}
        result = base.copy()
        for item, value in new.items():
            if item == PARAMETERS:
                # Work on a copy so the caller's `new` is left untouched, and
                # tolerate a null `parameters` entry on either side.
                merged = dict(value or {})
                merged.update(result.get(item) or {})
                result[item] = merged
                continue
            if item not in result or not isinstance(value, dict):
                result[item] = value
                continue
            result[item] = cls.merge(result[item], value)
        return result

    @classmethod
    def generate_from_file(cls, file_path: Path, fail_ok=True, validate=True) -> dict:
        '''Load a YAML configuration file and return its content.

        Args:
            file_path: file to read.
            fail_ok: when true, a missing file yields an empty dict and load or
                validation problems are logged instead of raised.
            validate: when true, the content is checked with `cls.from_dict`.

        Returns:
            Whatever the YAML loader produced for the file (returned even when
            validation failed and fail_ok is true), or an empty dict when the
            file is missing and fail_ok is true.

        Raises:
            Exception: the original load/validation error when fail_ok is
                false, or the read error when the file cannot even be re-read
                for the diagnostic log message.
        '''
        config_dict = {}
        if not file_path.is_file() and fail_ok:
            return config_dict
        try:
            with open(str(file_path), 'r', encoding='utf-8') as file_handle:
                config_dict = yaml.load(file_handle, Loader=CustomSafeLoader)
            if validate:
                try:
                    cls.from_dict(config_dict)
                except ValidationError as e:
                    LOG.warning(f'config from {file_path} is illegal.')
                    LOG.debug(str(e), exc_info=True)
                    if not fail_ok:
                        # Caught again by the outer handler below, which logs
                        # the file content before re-raising.
                        raise e
            return config_dict
        except Exception as e:
            # Re-read the raw text only to include it in the warning.
            try:
                with open(str(file_path), 'r', encoding='utf-8') as file_handle:
                    file_content = file_handle.read()
            except Exception as ex:
                LOG.warning(f'failed to load config from {file_path}')
                raise ex
            LOG.warning(f'failed to load config from {file_path}, file content is {file_content}')
            LOG.debug(str(e), exc_info=True)
            if not fail_ok:
                raise e
        return config_dict

    @classmethod
    def create(cls: Type[T],
               global_config_path: Path = GENERAL_CONFIG_FILE,
               project_config_file: Union[str, Path] = DEFAULT_CONFIG_FILE,
               args: Optional[dict] = None,
               project_path: Union[str, Path, None] = None,
               fail_ok: bool = False) -> T:
        '''Build a config from the global file, the project file and `args`.

        The three sources are merged in that order with `merge`. The general
        section is then merged into the project section, and the resulting
        project section into every test section.

        Args:
            global_config_path: user-wide config file (always loaded leniently).
            project_config_file: config file name, relative to the project root.
            args: extra configuration merged last.
            project_path: project root; DEFAULT_PROJECT_ROOT when empty.
            fail_ok: forwarded to `generate_from_file` for the project file.
        '''
        if not project_path:
            project_path = DEFAULT_PROJECT_ROOT
        project_root: Path = Path(project_path).expanduser().resolve()
        project_config_path = project_root / project_config_file
        sources = [
            cls.generate_from_file(global_config_path),
            cls.generate_from_file(project_config_path, fail_ok=fail_ok),
            args or {}
        ]
        config = reduce(cls.merge, sources)
        general_config = config.get(GENERAL, {})
        merged_project_config = cls.merge(general_config, config.get(PROJECT, {}))
        merged_test_configs = {
            key: cls.merge(merged_project_config, value) for key, value in config.get(TESTS, {}).items()
        }
        return cls.from_dict({
            GENERAL: general_config,
            PROJECT: merged_project_config,
            TESTS: merged_test_configs
        })

    async def get_all_configs(self, test_names: Optional[str] = None):
        '''Expand the tests into one TestConfig per (test, region) pair.

        Args:
            test_names: optional comma-separated list of test names to keep
                (surrounding whitespace is ignored); all tests when empty.

        Returns:
            A list of TestConfig objects with `test_name`, `region`, resolved
            `parameters` and, when parameter resolution reported one, `error`.

        Raises:
            Iact3Exception: from template loading or bucket validation.
        '''
        results = []
        param_tasks = []
        wanted = [name.strip() for name in test_names.split(',')] if test_names else []
        for name, config in self.tests.items():
            if wanted and name not in wanted:
                continue
            # `regions` is Optional, so guard against None.
            regions = [region.lower() for region in config.regions or []]
            if 'all' in regions or not regions:
                regions = await self._get_test_regions()
            # generate_template_args() is lru_cached, so copy its result
            # before dropping the location key.
            template_args = dict(config.template_config.generate_template_args())
            template_args.pop(TEMPLATE_LOCATION, None)
            config.template_config = TemplateConfig.from_dict(template_args)
            for region in regions:
                # Round-trip through a dict to get an independent copy per region.
                region_config = TestConfig.from_dict(config.to_dict())
                region_config.region = region
                region_config.test_name = name
                oss_config = region_config.oss_config
                bucket_name = oss_config.bucket_name
                if bucket_name:
                    plugin = OssPlugin(
                        region_id=region,
                        bucket_name=bucket_name,
                        credential=region_config.auth.credential
                    )
                    oss_config.validate_bucket(plugin)
                resolved_parameters_task = ParamGenerator.result(region_config)
                param_tasks.append(asyncio.create_task(resolved_parameters_task))
                results.append(region_config)
        # gather() returns results in task order, i.e. aligned with `results`.
        resolved_parameters = await asyncio.gather(*param_tasks)
        for config, params in zip(results, resolved_parameters):
            assert config.test_name == params.name
            assert config.region == params.region
            if params.error:
                config.error = params.error
            config.parameters = params.parameters
        return results

    def get_oss_config(self):
        '''Return the project's (bucket_name, bucket_region) pair.'''
        oss_config = self.project.oss_config
        return oss_config.bucket_name, oss_config.bucket_region

    async def _get_test_regions(self):
        '''Return the region list from StackPlugin, fetched once and cached.'''
        if self._all_regions is None:
            plugin = StackPlugin('cn-hangzhou', self.general.auth.credential)
            self._all_regions = await plugin.get_regions()
        return self._all_regions

    def get_credential(self):
        '''Return the project's credential, remembered after the first lookup.'''
        if self._credential is None:
            self._credential = self.project.auth.credential
        return self._credential