src/sagemaker_core/tools/templates.py (208 lines of code) (raw):

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Templates for generating resources.""" RESOURCE_CLASS_TEMPLATE = """ class {class_name}: {data_class_members} {init_method} {class_methods} {object_methods} """ RESOURCE_METHOD_EXCEPTION_DOCSTRING = """ Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. The error message and error code can be parsed from the exception as follows: ``` try: # AWS service call here except botocore.exceptions.ClientError as e: error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ```""" CREATE_METHOD_TEMPLATE = """ @classmethod @populate_inputs_decorator @Base.add_validate_call def create( cls, {create_args} session: Optional[Session] = None, region: Optional[str] = None, ) -> Optional["{resource_name}"]: {docstring} logger.info("Creating {resource_lower} resource.") client = Base.get_sagemaker_client(session=session, region_name=region, service_name='{service_name}') operation_input_args = {{ {operation_input_args} }} operation_input_args = Base.populate_chained_attributes(resource_name='{resource_name}', operation_input_args=operation_input_args) logger.debug(f"Input request: {{operation_input_args}}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") # create the resource response = client.{operation}(**operation_input_args) logger.debug(f"Response: {{response}}") return cls.get({get_args}, session=session, region=region) """ CREATE_METHOD_TEMPLATE_WITHOUT_DEFAULTS = """ @classmethod @Base.add_validate_call def create( cls, {create_args} session: Optional[Session] = None, region: Optional[str] = None, ) -> Optional["{resource_name}"]: {docstring} logger.info("Creating {resource_lower} resource.") client = Base.get_sagemaker_client(session=session, region_name=region, service_name='{service_name}') operation_input_args = {{ {operation_input_args} }} operation_input_args = Base.populate_chained_attributes(resource_name='{resource_name}', operation_input_args=operation_input_args) logger.debug(f"Input request: {{operation_input_args}}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") # create the resource response = client.{operation}(**operation_input_args) logger.debug(f"Response: {{response}}") return cls.get({get_args}, session=session, region=region) """ IMPORT_METHOD_TEMPLATE = """ @classmethod @Base.add_validate_call def load( cls, {import_args} session: Optional[Session] = None, region: Optional[str] = None, ) -> Optional["{resource_name}"]: {docstring} logger.info(f"Importing {resource_lower} resource.") client = SageMakerClient(session=session, region_name=region, service_name='{service_name}').client operation_input_args = {{ {operation_input_args} }} logger.debug(f"Input request: {{operation_input_args}}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") # import the resource response = client.{operation}(**operation_input_args) logger.debug(f"Response: {{response}}") return cls.get({get_args}, session=session, region=region) """ GET_NAME_METHOD_TEMPLATE = """ def get_name(self) -> str: attributes = vars(self) resource_name = '{resource_lower}_name' resource_name_split = resource_name.split('_') attribute_name_candidates = [] l = len(resource_name_split) for i in range(0, l): attribute_name_candidates.append("_".join(resource_name_split[i:l])) for attribute, value in attributes.items(): if attribute == 'name' or attribute in attribute_name_candidates: return value logger.error("Name attribute not found for object {resource_lower}") return None """ UPDATE_METHOD_TEMPLATE = """ @populate_inputs_decorator @Base.add_validate_call def update( self, {update_args} ) -> Optional["{resource_name}"]: {docstring} logger.info("Updating {resource_lower} resource.") client = Base.get_sagemaker_client() operation_input_args = {{ {operation_input_args} }} logger.debug(f"Input request: {{operation_input_args}}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") # create the resource response = client.{operation}(**operation_input_args) logger.debug(f"Response: {{response}}") self.refresh() return self """ UPDATE_METHOD_TEMPLATE_WITHOUT_DECORATOR = """ @Base.add_validate_call def update( self, {update_args} ) -> Optional["{resource_name}"]: {docstring} logger.info("Updating {resource_lower} resource.") client = Base.get_sagemaker_client() operation_input_args = {{ {operation_input_args} }} logger.debug(f"Input request: {{operation_input_args}}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") # create the resource response = client.{operation}(**operation_input_args) logger.debug(f"Response: {{response}}") self.refresh() return self """ POPULATE_DEFAULTS_DECORATOR_TEMPLATE = """ def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): config_schema_for_resource = \\ {config_schema_for_resource} return create_func(*args, **Base.get_updated_kwargs_with_configured_attributes(config_schema_for_resource, "{resource_name}", **kwargs)) return wrapper """ GET_METHOD_TEMPLATE = """ @classmethod @Base.add_validate_call def get( cls, {describe_args} session: Optional[Session] = None, region: Optional[str] = None, ) -> Optional["{resource_name}"]: {docstring} operation_input_args = {{ {operation_input_args} }} # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") client = Base.get_sagemaker_client(session=session, region_name=region, service_name='{service_name}') response = client.{operation}(**operation_input_args) logger.debug(response) # deserialize the response transformed_response = transform(response, '{describe_operation_output_shape}') {resource_lower} = cls(**transformed_response) return {resource_lower} """ REFRESH_METHOD_TEMPLATE = """ @Base.add_validate_call def refresh( self, {refresh_args} ) -> Optional["{resource_name}"]: {docstring} operation_input_args = {{ {operation_input_args} }} # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") client = Base.get_sagemaker_client() response = client.{operation}(**operation_input_args) # deserialize response and update self transform(response, '{describe_operation_output_shape}', self) return self """ FAILED_STATUS_ERROR_TEMPLATE = """ if "failed" in current_status.lower(): raise FailedStatusError(resource_type="{resource_name}", status=current_status, reason={reason}) """ INIT_WAIT_LOGS_TEMPLATE = """ instance_count = {get_instance_count} if logs: multi_stream_logger = MultiLogStreamHandler( log_group_name=f"/aws/sagemaker/{job_type}s", log_stream_name_prefix=self.get_name(), expected_stream_count=instance_count ) """ PRINT_WAIT_LOGS = """ if logs and multi_stream_logger.ready(): stream_log_events = multi_stream_logger.get_latest_log_events() for stream_id, event in stream_log_events: logger.info(f"{stream_id}:\\n{event['message']}") """ WAIT_METHOD_TEMPLATE = ''' @Base.add_validate_call def wait( self, poll: int = 5, timeout: Optional[int] = None, {logs_arg} ) -> None: """ Wait for a {resource_name} resource. Parameters: poll: The number of seconds to wait between each poll. timeout: The maximum number of seconds to wait before timing out. {logs_arg_doc} Raises: TimeoutExceededError: If the resource does not reach a terminal state before the timeout. FailedStatusError: If the resource reaches a failed state. WaiterError: Raised when an error occurs while waiting. """ terminal_states = {terminal_resource_states} start_time = time.time() progress = Progress(SpinnerColumn("bouncingBar"), TextColumn("{{task.description}}"), TimeElapsedColumn(), ) progress.add_task("Waiting for {resource_name}...") status = Status("Current status:") {init_wait_logs} with Live( Panel( Group(progress, status), title="Wait Log Panel", border_style=Style(color=Color.BLUE.value ) ), transient=True ): while True: self.refresh() current_status = self{status_key_path} status.update(f"Current status: [bold]{{current_status}}") {print_wait_logs} if current_status in terminal_states: logger.info(f"Final Resource Status: [bold]{{current_status}}") {failed_error_block} return if timeout is not None and time.time() - start_time >= timeout: raise TimeoutExceededError(resouce_type="{resource_name}", status=current_status) time.sleep(poll) ''' WAIT_FOR_STATUS_METHOD_TEMPLATE = ''' @Base.add_validate_call def wait_for_status( self, target_status: Literal{resource_states}, poll: int = 5, timeout: Optional[int] = None ) -> None: """ Wait for a {resource_name} resource to reach certain status. Parameters: target_status: The status to wait for. poll: The number of seconds to wait between each poll. timeout: The maximum number of seconds to wait before timing out. Raises: TimeoutExceededError: If the resource does not reach a terminal state before the timeout. FailedStatusError: If the resource reaches a failed state. WaiterError: Raised when an error occurs while waiting. """ start_time = time.time() progress = Progress(SpinnerColumn("bouncingBar"), TextColumn("{{task.description}}"), TimeElapsedColumn(), ) progress.add_task(f"Waiting for {resource_name} to reach [bold]{{target_status}} status...") status = Status("Current status:") with Live( Panel( Group(progress, status), title="Wait Log Panel", border_style=Style(color=Color.BLUE.value ) ), transient=True ): while True: self.refresh() current_status = self{status_key_path} status.update(f"Current status: [bold]{{current_status}}") if target_status == current_status: logger.info(f"Final Resource Status: [bold]{{current_status}}") return {failed_error_block} if timeout is not None and time.time() - start_time >= timeout: raise TimeoutExceededError(resouce_type="{resource_name}", status=current_status) time.sleep(poll) ''' WAIT_FOR_DELETE_METHOD_TEMPLATE = ''' @Base.add_validate_call def wait_for_delete( self, poll: int = 5, timeout: Optional[int] = None, ) -> None: """ Wait for a {resource_name} resource to be deleted. Parameters: poll: The number of seconds to wait between each poll. timeout: The maximum number of seconds to wait before timing out. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. The error message and error code can be parsed from the exception as follows: ``` try: # AWS service call here except botocore.exceptions.ClientError as e: error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` TimeoutExceededError: If the resource does not reach a terminal state before the timeout. DeleteFailedStatusError: If the resource reaches a failed state. WaiterError: Raised when an error occurs while waiting. """ start_time = time.time() progress = Progress(SpinnerColumn("bouncingBar"), TextColumn("{{task.description}}"), TimeElapsedColumn(), ) progress.add_task("Waiting for {resource_name} to be deleted...") status = Status("Current status:") with Live(Panel(Group(progress, status), title="Wait Log Panel", border_style=Style(color=Color.BLUE.value))): while True: try: self.refresh() current_status = self{status_key_path} status.update(f"Current status: [bold]{{current_status}}") {delete_failed_error_block} {deleted_status_check} if timeout is not None and time.time() - start_time >= timeout: raise TimeoutExceededError(resouce_type="{resource_name}", status=current_status) except botocore.exceptions.ClientError as e: error_code = e.response["Error"]["Code"] if "ResourceNotFound" in error_code or "ValidationException" in error_code: logger.info("Resource was not found. It may have been deleted.") return raise e time.sleep(poll) ''' DELETE_FAILED_STATUS_CHECK = """ if "delete_failed" in current_status.lower() or "deletefailed" in current_status.lower(): raise DeleteFailedStatusError(resource_type="{resource_name}", reason={reason}) """ DELETED_STATUS_CHECK = """ if current_status.lower() == "deleted": print("Resource was deleted.") return """ DELETE_METHOD_TEMPLATE = """ @Base.add_validate_call def delete( self, {delete_args} ) -> None: {docstring} client = Base.get_sagemaker_client() operation_input_args = {{ {operation_input_args} }} # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") client.{operation}(**operation_input_args) logger.info(f"Deleting {{self.__class__.__name__}} - {{self.get_name()}}") """ STOP_METHOD_TEMPLATE = """ @Base.add_validate_call def stop(self) -> None: {docstring} client = SageMakerClient().client operation_input_args = {{ {operation_input_args} }} # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") client.{operation}(**operation_input_args) logger.info(f"Stopping {{self.__class__.__name__}} - {{self.get_name()}}") """ GET_ALL_METHOD_WITH_ARGS_TEMPLATE = """ @classmethod @Base.add_validate_call def get_all( cls, {get_all_args} session: Optional[Session] = None, region: Optional[str] = None, ) -> ResourceIterator["{resource}"]: {docstring} client = Base.get_sagemaker_client(session=session, region_name=region, service_name="{service_name}") operation_input_args = {{ {operation_input_args} }} {custom_key_mapping} # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") return ResourceIterator( {resource_iterator_args} ) """ GET_ALL_METHOD_NO_ARGS_TEMPLATE = ''' @classmethod @Base.add_validate_call def get_all( cls, session: Optional[Session] = None, region: Optional[str] = None, ) -> ResourceIterator["{resource}"]: """ Get all {resource} resources. Parameters: session: Boto3 session. region: Region name. Returns: Iterator for listed {resource} resources. """ client = Base.get_sagemaker_client(session=session, region_name=region, service_name="{service_name}") {custom_key_mapping} return ResourceIterator( {resource_iterator_args} ) ''' GENERIC_METHOD_TEMPLATE = """ {decorator} @Base.add_validate_call def {method_name}( {method_args} ) -> {return_type}: {docstring} {serialize_operation_input} {initialize_client} {call_operation_api} {deserialize_response} """ SERIALIZE_INPUT_TEMPLATE = """ operation_input_args = {{ {operation_input_args} }} # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}")""" INITIALIZE_CLIENT_TEMPLATE = """ client = Base.get_sagemaker_client(session=session, region_name=region, service_name='{service_name}')""" CALL_OPERATION_API_TEMPLATE = """ logger.debug(f"Calling {operation} API") response = client.{operation}(**operation_input_args) logger.debug(f"Response: {{response}}")""" CALL_OPERATION_API_NO_ARG_TEMPLATE = """ logger.debug(f"Calling {operation} API") response = client.{operation}() logger.debug(f"Response: {{response}}")""" DESERIALIZE_RESPONSE_TEMPLATE = """ transformed_response = transform(response, '{operation_output_shape}') return {return_type_conversion}(**transformed_response)""" DESERIALIZE_RESPONSE_TO_BASIC_TYPE_TEMPLATE = """ return list(response.values())[0]""" RETURN_ITERATOR_TEMPLATE = """ return ResourceIterator( {resource_iterator_args} )""" DESERIALIZE_INPUT_AND_RESPONSE_TO_CLS_TEMPLATE = """ transformed_response = transform(response, '{operation_output_shape}') return cls(**operation_input_args, **transformed_response)""" RESOURCE_BASE_CLASS_TEMPLATE = """ class Base(BaseModel): model_config = ConfigDict(protected_namespaces=(), validate_assignment=True, extra="forbid") @classmethod def get_sagemaker_client(cls, session = None, region_name = None, service_name = 'sagemaker'): return SageMakerClient(session=session, region_name=region_name).get_client(service_name=service_name) @staticmethod def get_updated_kwargs_with_configured_attributes( config_schema_for_resource: dict, resource_name: str, **kwargs ): try: for configurable_attribute in config_schema_for_resource: if kwargs.get(configurable_attribute) is None: resource_defaults = load_default_configs_for_resource_name( resource_name=resource_name ) global_defaults = load_default_configs_for_resource_name( resource_name="GlobalDefaults" ) if config_value := get_config_value( configurable_attribute, resource_defaults, global_defaults ): resource_name = snake_to_pascal(configurable_attribute) class_object = getattr(shapes, resource_name, None) or globals().get(resource_name) kwargs[configurable_attribute] = class_object(**config_value) except BaseException as e: logger.debug("Could not load Default Configs. Continuing.", exc_info=True) # Continue with existing kwargs if no default configs found return kwargs @staticmethod def populate_chained_attributes(resource_name: str, operation_input_args: Union[dict, object]): resource_name_in_snake_case = pascal_to_snake(resource_name) updated_args = vars(operation_input_args) if type(operation_input_args) == object else operation_input_args unassigned_args = [] keys = operation_input_args.keys() for arg in keys: value = operation_input_args.get(arg) arg_snake = pascal_to_snake(arg) if value == Unassigned() : unassigned_args.append(arg) elif value == None or not value: continue elif ( arg_snake.endswith("name") and arg_snake[: -len("_name")] != resource_name_in_snake_case and arg_snake != "name" ): if value and value != Unassigned() and type(value) != str: updated_args[arg] = value.get_name() elif isinstance(value, list) and is_primitive_list(value): continue elif isinstance(value, list) and value != []: updated_args[arg] = [ Base._get_chained_attribute(list_item) for list_item in value ] elif is_not_primitive(value) and is_not_str_dict(value) and type(value) == object: updated_args[arg] = Base._get_chained_attribute(item_value=value) for unassigned_arg in unassigned_args: del updated_args[unassigned_arg] return updated_args @staticmethod def _get_chained_attribute(item_value: Any): resource_name = type(item_value).__name__ class_object = globals().get(resource_name) or getattr(shapes, resource_name, None) if class_object is None: return item_value return class_object(**Base.populate_chained_attributes( resource_name=resource_name, operation_input_args=vars(item_value) )) @staticmethod def add_validate_call(func): @functools.wraps(func) def wrapper(*args, **kwargs): config = dict(arbitrary_types_allowed=True) return validate_call(config=config)(func)(*args, **kwargs) return wrapper """ SHAPE_BASE_CLASS_TEMPLATE = """ class {class_name}: model_config = ConfigDict(protected_namespaces=(), validate_assignment=True, extra="forbid") """ SHAPE_CLASS_TEMPLATE = ''' class {class_name}: """ {docstring} """ {data_class_members} ''' RESOURCE_METHOD_EXCEPTION_DOCSTRING = """ Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. The error message and error code can be parsed from the exception as follows: ``` try: # AWS service call here except botocore.exceptions.ClientError as e: error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ```"""