source/idea/idea-sdk/src/ideasdk/context/soca_cli_context.py (300 lines of code) (raw):

#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
#  with the License. A copy of the License is located at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
#  OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
#  and limitations under the License.

from ideasdk.context import SocaContext, SocaContextOptions
from ideasdk.artwork import ascii_banner
from ideasdk.shell import ShellInvoker
from ideasdk.utils import Utils
from ideasdk.client import SocaClient, SocaClientOptions
from ideadatamodel import (
    exceptions,
    errorcodes,
    SocaKeyValue,
    SocaUserInputChoice,
    SocaUserInputParamMetadata,
    SocaUserInputSectionMetadata,
    SocaUserInputModuleMetadata,
    SocaInputParamSpec
)
from ideasdk.aws import AwsClientProvider, AWSUtil, AWSClientProviderOptions, AwsResources
from ideasdk.user_input.framework import (
    SocaUserInputArgs,
    SocaUserInputParamRegistry,
    SocaUserInputModule,
    SocaUserInputSection,
    SocaPromptRegistry,
    SocaPrompt
)

from typing import Any, Union, Optional, Dict, List
from pydantic import BaseModel
import os
import sys
from rich.console import Console
from rich.text import Text
from rich.rule import Rule
from rich.markdown import Markdown
from rich.markup import escape
from rich.style import Style
from questionary import unsafe_prompt, Choice
from contextlib import contextmanager
import arrow

if sys.version_info >= (3, 8):
    from typing import Literal
else:
    from typing_extensions import Literal  # pragma: no cover


class SocaCliContext(SocaContext):
    """
    SocaContext specialization used by command line tools.

    On top of SocaContext it provides:
    * a ShellInvoker for running shell commands (see ``shell`` / ``cd``),
    * a rich Console with helpers for formatted output,
    * interactive prompts backed by questionary,
    * a lazily created SocaClient that talks to the local API over a unix socket,
    * helpers to (re)initialize the AWS client provider from the CLI.
    """

    def __init__(self, options: Optional[SocaContextOptions] = None,
                 api_context_path: Optional[str] = None,
                 unix_socket_timeout: Optional[int] = None):
        """
        :param options: context options, forwarded to SocaContext
        :param api_context_path: URL path appended to ``http://localhost`` when the
            unix socket client is created; required by ``unix_socket_client``
        :param unix_socket_timeout: timeout forwarded to the unix socket client options
        """
        super().__init__(
            options=options
        )
        self._shell = ShellInvoker()
        self._console = Console()
        self._api_context_path = api_context_path
        # created on first access of the unix_socket_client property
        self._unix_socket_client: Optional[SocaClient] = None
        # NOTE(review): self._aws / self._aws_util are presumably initialized by SocaContext.__init__ - confirm
        self._aws_resources = AwsResources(context=self, aws=self._aws, aws_util=self._aws_util)
        self._unix_socket_timeout: Optional[int] = unix_socket_timeout

    @property
    def shell(self) -> ShellInvoker:
        """The ShellInvoker used to run shell commands."""
        return self._shell

    @property
    def user_home(self) -> str:
        """Home directory of the current user (``os.path.expanduser('~')``)."""
        return os.path.expanduser('~')

    @contextmanager
    def cd(self, path: str):
        """
        Temporarily set the working directory used by the shell invoker to ``path``.

        The previous value is restored on exit (even on error), so nested
        ``cd()`` blocks return to the outer directory instead of resetting it.
        """
        # getattr keeps this safe if ShellInvoker does not expose cwd for reading.
        previous_cwd = getattr(self._shell, 'cwd', None)
        self._shell.cwd = path
        try:
            yield
        finally:
            self._shell.cwd = previous_cwd

    @staticmethod
    def is_windows() -> bool:
        """True when running on Windows (``os.name == 'nt'``)."""
        return os.name == 'nt'

    @staticmethod
    def is_root() -> bool:
        """
        True when the effective user id is 0.

        ``os.geteuid`` is only available on POSIX platforms. Where it is missing
        (e.g. Windows), the user is reported as non-root instead of raising
        AttributeError.
        """
        geteuid = getattr(os, 'geteuid', None)
        if geteuid is None:
            return False
        return geteuid() == 0

    def check_root_access(self):
        """
        :raises SocaException: with NOT_AUTHORIZED when not running as root.
        """
        if not self.is_root():
            raise exceptions.soca_exception(
                error_code=errorcodes.NOT_AUTHORIZED,
                message='Access Denied: root access is required to use resctl.'
            )

    # BEGIN: STDOUT FUNCTIONS

    @property
    def console(self) -> Console:
        """The rich Console used for all formatted output."""
        return self._console

    def print_banner(self, meta_info: Optional[List[SocaKeyValue]] = None):
        """Print the ASCII banner, followed by a blank line."""
        if self.is_windows():
            # TODO: investigate why rich banner printing on windows gives index out of bounds error
            banner = ascii_banner(meta_info=meta_info, rich=False)
            print(banner)
        else:
            banner = ascii_banner(meta_info=meta_info, rich=True)
            self.console.print(banner)
        self.new_line()

    def spinner(self, message: Any):
        """Return a rich status spinner (context manager) displaying ``message``."""
        return self.console.status(status=message)

    def print(self, message: Any, style=None, end=f'{os.linesep}', new_line=False, highlight=True, **kwargs):
        """
        Print ``message`` via the rich console, skipping it when Utils.is_not_empty() is False.

        If the console raises BlockingIOError, the builtin print() is used as a fallback.
        When ``new_line`` is True, an extra blank line is printed afterwards (even for
        an empty message).
        """
        if Utils.is_not_empty(message):
            try:
                self.console.print(message, end=end, style=style, highlight=highlight, **kwargs)
            except BlockingIOError:
                print(message, end='\r\n')
        if new_line:
            self.new_line()

    def print_title(self, message: str, new_line=False):
        """Print ``message`` in bold italic bright white."""
        self.print(Text(text=message, style=Style(bold=True, italic=True, color='bright_white')), new_line=new_line)

    def new_line(self, count: int = 1):
        """Print ``count`` blank lines."""
        self.console.line(count)

    def print_rule(self, title: Optional[str] = None, align: Literal["left", "center", "right"] = 'left', style=None):
        """Print a horizontal rule with an optional bold italic title; ``style`` defaults to 'rule.line'."""
        if Utils.is_empty(title):
            text = ''
        else:
            text = Text(text=title, style=Style(bold=True, italic=True, color='bright_white'))
        if style is None:
            style = 'rule.line'
        self.print(Rule(text, align=align, style=style))

    def print_json(self, payload: Union[str, BaseModel, Any]):
        """Pretty print ``payload`` as JSON; non-str payloads are serialized with Utils.to_json()."""
        if isinstance(payload, str):
            json_content = payload
        else:
            json_content = Utils.to_json(payload)
        self.console.print_json(json_content)

    @staticmethod
    def get_label(label: str, style: str) -> Text:
        """Return ``label`` as a rich Text with ``style`` applied."""
        text = Text(label)
        text.stylize(style)
        return text

    def echo(self, message: str, end=f'{os.linesep}', highlight=False):
        """Write ``message`` with Console.out(), i.e. without markup processing."""
        self.console.out(message, end=end, highlight=highlight)

    @staticmethod
    def _escape_markup(message: Any) -> Any:
        """
        Escape rich markup in ``message`` when it is a str.

        rich.markup.escape only accepts strings, so other objects (rich renderables,
        dicts, ...) are passed through unchanged. None becomes '' so the console
        does not print a literal "None".
        """
        if message is None:
            return ''
        if isinstance(message, str):
            return escape(message)
        return message

    def debug(self, message: Any):
        """Print a dimmed debug message; str messages are prefixed with '[DEBUG] '."""
        if isinstance(message, str):
            message = f'[DEBUG] {message}'
        self.console.print(self._escape_markup(message), style='dim')

    def success(self, message: Any):
        """Print ``message`` in bold green."""
        self.console.print(self._escape_markup(message), style='bold green')

    def warning(self, message=None):
        """Print ``message`` in bold yellow."""
        self.console.print(self._escape_markup(message), style='bold yellow')

    def info(self, message=None):
        """Print ``message`` in cyan."""
        self.console.print(self._escape_markup(message), style='cyan')

    def error(self, message: Any = None):
        """Print ``message`` in bold red."""
        self.console.print(self._escape_markup(message), style='bold red')

    def exception(self, message=None, markdown=False, error_code=errorcodes.CLI_ERROR, ref=None):
        """
        Build a CLI error.

        When ``markdown`` is True, ``message`` is rendered as Markdown and the
        process exits by raising SystemExit. Otherwise a SocaException is
        *returned* (not raised); presumably callers raise it themselves.
        """
        if markdown:
            md_message = Markdown(message)
            self.console.print(md_message)
            raise SystemExit
        else:
            return exceptions.SocaException(
                error_code=error_code,
                message=message,
                ref=ref
            )

    def get_log_message(self, tag: str = None, message: Optional[Union[str, Dict]] = None) -> str:
        """
        Format a log line.

        When ``tag`` is set, the line starts with ``[<timestamp>] (<tag>) ``.
        str messages are right-stripped; other messages are serialized with Utils.to_json().
        """
        if Utils.is_not_empty(tag):
            log = f'[{arrow.now()}] ({tag}) '
        else:
            log = ''
        if message is not None:
            if isinstance(message, str):
                log += message.rstrip()
            else:
                log += Utils.to_json(message)
        return log

    def log(self, tag: str = None, message: Optional[Union[str, Dict]] = None):
        """Write a formatted log line (see get_log_message) unless it is empty."""
        log_message = self.get_log_message(tag, message)
        if Utils.is_empty(log_message):
            return
        self.console.out(log_message.rstrip())

    @staticmethod
    def prompt(message: str, default: Union[bool, str, SocaUserInputChoice] = None,
               auto_enter=True, icon='?',
               choices: List[Union[str, SocaUserInputChoice]] = None) -> Union[bool, str]:
        """
        Ask the user a question on the terminal.

        Without ``choices``, a yes/no confirm prompt is shown and a bool is returned
        (``default`` falls back to False). With ``choices``, a select prompt is shown
        and the selected value is returned as a str.

        ``default`` is matched against the choices by value. It may be given either
        as a str or as a SocaUserInputChoice, whichever form the choices use. When
        nothing matches, the first choice is pre-selected.
        """
        if Utils.is_empty(choices):
            if default is None:
                default = False
            result = unsafe_prompt(questions=[{
                'type': 'confirm',
                'name': 'result',
                'message': message,
                'default': default,
                'auto_enter': auto_enter,
                'qmark': icon
            }])
            return Utils.get_value_as_bool('result', result, default)

        # Normalize the default to a plain value so str and SocaUserInputChoice
        # defaults can be compared against either kind of choice.
        if default is None or isinstance(default, str):
            default_value = default
        else:
            default_value = getattr(default, 'value', default)

        choices_ = []
        default_choice_ = None
        for choice in choices:
            if isinstance(choice, str):
                choice_ = Choice(
                    title=choice,
                    value=choice
                )
                choice_value = choice
            else:
                choice_ = Choice(
                    title=choice.title,
                    value=choice.value,
                    disabled=choice.disabled,
                    checked=choice.checked
                )
                choice_value = choice.value
            choices_.append(choice_)
            # no break: when several choices match, the last one wins
            if default_value is not None and default_value == choice_value:
                default_choice_ = choice_

        if default_choice_ is None:
            default_choice_ = choices_[0]

        result = unsafe_prompt(questions=[{
            'type': 'select',
            'name': 'result',
            'message': message,
            'choices': choices_,
            'default': default_choice_,
            'qmark': icon
        }])
        return Utils.get_value_as_string('result', result)

    @property
    def unix_socket_client(self) -> SocaClient:
        """
        SocaClient for the local API over the '/run/idea.sock' unix socket (created lazily).

        :raises SocaException: (general exception) when no api_context_path was provided.
        """
        if Utils.is_empty(self._api_context_path):
            raise exceptions.general_exception('API Context Path not found')
        if self._unix_socket_client is None:
            self._unix_socket_client = SocaClient(
                context=self,
                options=SocaClientOptions(
                    enable_logging=False,
                    endpoint=f'http://localhost{self._api_context_path}',
                    unix_socket='/run/idea.sock',
                    timeout=self._unix_socket_timeout
                )
            )
        return self._unix_socket_client

    def aws_init(self, aws_profile: Optional[str] = None, aws_region: Optional[str] = None):
        """
        (Re)create the AWS client provider, AWSUtil and AwsResources for the given profile/region.

        Any exception raised while doing so is delegated to ``aws_util().handle_aws_exception``.
        """
        try:
            self._aws = AwsClientProvider(
                options=AWSClientProviderOptions(
                    profile=aws_profile,
                    region=aws_region
                ))
            self._aws_util = AWSUtil(context=self, aws=self._aws)
            self._aws_resources = AwsResources(
                context=self,
                aws=self._aws,
                aws_util=self._aws_util
            )
        except Exception as e:
            self.aws_util().handle_aws_exception(e)

    def refresh_aws_credentials(self):
        """Re-initialize AWS clients with the current profile/region if credentials have expired."""
        if not self.aws().are_credentials_expired():
            return
        self.aws_init(self._aws.aws_profile(), self._aws.aws_region())

    def get_aws_resources(self):
        """Return the AwsResources helper bound to the current AWS client provider."""
        return self._aws_resources

    def aws(self) -> AwsClientProvider:
        """Return the AWS client provider, re-initializing it from the context options if credentials expired."""
        if self._aws.are_credentials_expired():
            self.aws_init(self._options.aws_profile, self._options.aws_region)
        return self._aws

    def ask(self, questions: List[SocaUserInputParamMetadata],
            title: str = None,
            description: str = None,
            module_name: str = None) -> Dict:
        """
        Interactively ask ``questions`` and return the collected answers.

        The questions are wrapped in a single module (``module_name``, default
        'default') with a single section named 'default'. The prompts are run with
        SocaUserInputModule.safe_ask(), and the result of SocaUserInputArgs.build()
        is returned.
        """
        if Utils.is_empty(module_name):
            module_name = 'default'

        spec = SocaInputParamSpec(
            params=questions,
            modules=[
                SocaUserInputModuleMetadata(
                    name=module_name,
                    title=title,
                    description=description,
                    sections=[
                        SocaUserInputSectionMetadata(
                            name='default',
                            params=questions
                        )
                    ]
                )
            ]
        )

        param_registry = SocaUserInputParamRegistry(context=self, spec=spec)
        args = SocaUserInputArgs(context=self, param_registry=param_registry)
        prompt_registry = SocaPromptRegistry(context=self, param_spec=param_registry)

        module_meta = param_registry.get_module(module=module_name)

        section_prompts = []
        params = param_registry.get_params(module=module_name)
        for param_meta in params:
            section_prompts.append(SocaPrompt(
                context=self,
                args=args,
                param=param_meta,
                registry=prompt_registry
            ))

        # every section receives the same prompt list; the spec above defines exactly one section
        sections = []
        for section_meta in module_meta.sections:
            sections.append(SocaUserInputSection(
                context=self,
                section=section_meta,
                prompts=section_prompts
            ))

        user_input_module = SocaUserInputModule(
            context=self,
            module=module_meta,
            sections=sections,
            restart_errorcodes=[errorcodes.INSTALLER_MISSING_PERMISSIONS]
        )

        user_input_module.safe_ask()

        return args.build()