pai/session.py (407 lines of code) (raw):

# Copyright 2023 Alibaba, Inc. or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import import json import os.path import posixpath from datetime import datetime from typing import Any, Dict, Optional, Tuple, Union import oss2 from alibabacloud_credentials.client import Client as CredentialClient from alibabacloud_credentials.exceptions import CredentialException from alibabacloud_credentials.models import Config as CredentialConfig from alibabacloud_credentials.utils import auth_constant from Tea.exceptions import TeaException from .api.api_container import ResourceAPIsContainerMixin from .api.base import ServiceName from .api.client_factory import ClientFactory from .api.workspace import WorkspaceAPI, WorkspaceConfigKeys from .common.consts import DEFAULT_CONFIG_PATH, PAI_VPC_ENDPOINT, Network from .common.logging import get_logger from .common.oss_utils import CredentialProviderWrapper, OssUriObj from .common.utils import is_domain_connectable, make_list_resource_iterator from .libs.alibabacloud_pai_dsw20220101.models import GetInstanceRequest logger = get_logger(__name__) # Environment variable that indicates where the config path is located. # If it is not provided, "$HOME/.pai/config.json" is used as the default config path. ENV_PAI_CONFIG_PATH = "PAI_CONFIG_PATH" INNER_REGION_IDS = ["center"] # Global default session used by the program. _default_session = None # Default config keys. 
# Session attributes that are persisted to, and loaded from, the config file
# (see ``load_default_config_file`` and ``Session.save_config``).
_DEFAULT_CONFIG_KEYS = [
    "region_id",
    "oss_bucket_name",
    "workspace_id",
    "oss_endpoint",
]


def setup_default_session(
    access_key_id: Optional[str] = None,
    access_key_secret: Optional[str] = None,
    security_token: Optional[str] = None,
    region_id: Optional[str] = None,
    credential_config: Optional[CredentialConfig] = None,
    oss_bucket_name: Optional[str] = None,
    oss_endpoint: Optional[str] = None,
    workspace_id: Optional[Union[str, int]] = None,
    network: Optional[Union[str, Network]] = None,
    **kwargs,
) -> "Session":
    """Set up the default session used in the program.

    The function constructs a session that is used for communicating with PAI
    services and sets it as the global default instance. Any argument that is
    not provided falls back to the value held by the current default session
    (as returned by :func:`get_default_session`), if there is one.

    Args:
        access_key_id (str, optional): The access key ID used to access the
            Alibaba Cloud. Only used if ``access_key_secret`` is also given.
        access_key_secret (str, optional): The access key secret used to access
            the Alibaba Cloud. Only used if ``access_key_id`` is also given.
        security_token (str, optional): The security token used to access the
            Alibaba Cloud. When given together with the access key pair, an STS
            credential is built instead of a plain access-key credential.
        region_id (str, optional): The ID of the Alibaba Cloud region where the
            service is located.
        credential_config (:class:`alibabacloud_credentials.models.Config`, optional):
            The credential config used to access the Alibaba Cloud. Mutually
            exclusive with the access key pair.
        oss_bucket_name (str, optional): The name of the OSS bucket used in the
            session.
        oss_endpoint (str, optional): The endpoint for the OSS bucket.
        workspace_id (Union[str, int], optional): ID of the workspace used in
            the default session.
        network (Union[str, Network], optional): The network used for the
            connection. Supported values are "VPC" and "PUBLIC". If it is not
            provided here or by the current default session, it is passed on as
            None. According to the original documentation, resolution then falls
            back to the ``PAI_NETWORK_TYPE`` environment variable, then to the
            VPC endpoint if it is reachable, and finally to PUBLIC. That logic
            is not part of this module.
        **kwargs: Extra keyword arguments forwarded to :class:`Session`
            (for example ``header`` or ``runtime``).

    Returns:
        :class:`pai.session.Session`: Initialized default session.

    Raises:
        ValueError: If both an access key pair and ``credential_config`` are
            provided.
    """
    global _default_session

    if (access_key_id and access_key_secret) and credential_config:
        raise ValueError("Please provide either access_key or credential_config.")
    elif not credential_config and (access_key_id and access_key_secret):
        # Build an explicit credential from the given access key pair.
        if security_token:
            credential_config = CredentialConfig(
                access_key_id=access_key_id,
                access_key_secret=access_key_secret,
                security_token=security_token,
                type=auth_constant.STS,
            )
        else:
            credential_config = CredentialConfig(
                access_key_id=access_key_id,
                access_key_secret=access_key_secret,
                type=auth_constant.ACCESS_KEY,
            )
    elif access_key_id or access_key_secret:
        # Only half of the key pair was given. It is ignored (as before), but
        # now the caller is told instead of the input being silently dropped.
        logger.warning(
            "Both access_key_id and access_key_secret are required; the partially"
            " provided access key is ignored."
        )

    # Fill in unspecified values from the current default session.
    default_session = get_default_session()
    if default_session:
        region_id = region_id or default_session.region_id
        workspace_id = workspace_id or default_session.workspace_id
        oss_bucket_name = oss_bucket_name or default_session.oss_bucket_name
        oss_endpoint = oss_endpoint or default_session.oss_endpoint
        credential_config = credential_config or default_session.credential_config
        network = network or default_session.network

    session = Session(
        region_id=region_id,
        credential_config=credential_config,
        oss_bucket_name=oss_bucket_name,
        oss_endpoint=oss_endpoint,
        workspace_id=workspace_id,
        network=network,
        **kwargs,
    )
    _default_session = session
    return session


def get_default_session() -> Optional["Session"]:
    """Get the default session used by the program.

    If the global default session is not set yet, the function tries to
    initialize one from the config file. If that fails, it tries the
    environment (for example, inside a PAI DSW instance). A session found this
    way is stored as the global default session.

    Returns:
        :class:`pai.session.Session`: The default session, or None if no
        session could be initialized.
    """
    global _default_session
    if not _default_session:
        config = load_default_config_file()
        if config:
            _default_session = Session(**config)
        else:
            _default_session = _init_default_session_from_env()
    return _default_session


def _init_default_session_from_env() -> Optional["Session"]:
    """Try to initialize a session from the runtime environment.

    The following are required:
    - a credential from the default credential provider chain;
    - a region id, read from ``REGION`` or the legacy ``dsw_region``;
    - a DSW instance id, read from ``DSW_INSTANCE_ID``.

    The workspace id is read from ``PAI_WORKSPACE_ID`` or
    ``PAI_AI_WORKSPACE_ID``. If neither is set, it is looked up from the DSW
    instance.

    Returns:
        Optional[Session]: The session, or None if a required piece of
        information is missing.
    """
    credential_client = Session._get_default_credential_client()
    if not credential_client:
        logger.debug("Not found credential from default credential provider chain.")
        return None

    # REGION takes precedence over the legacy DSW region env var.
    region_id = os.getenv("REGION", os.getenv("dsw_region"))
    if not region_id:
        logger.debug(
            "No region id found(env var: REGION or dsw_region), skip init default session"
        )
        return None

    dsw_instance_id = os.getenv("DSW_INSTANCE_ID")
    if not dsw_instance_id:
        logger.debug(
            "No dsw instance id (env var: DSW_INSTANCE_ID) found, skip init default session"
        )
        return None

    # PAI_WORKSPACE_ID takes precedence over PAI_AI_WORKSPACE_ID.
    workspace_id = os.getenv("PAI_WORKSPACE_ID", os.getenv("PAI_AI_WORKSPACE_ID"))

    # Prefer the VPC endpoint when it is reachable (probed with a 1s timeout).
    network = (
        Network.VPC
        if is_domain_connectable(PAI_VPC_ENDPOINT.format(region_id), timeout=1)
        else Network.PUBLIC
    )

    if not workspace_id:
        logger.debug("Getting workspace id by dsw instance id: %s", dsw_instance_id)
        workspace_id = Session._get_workspace_id_by_dsw_instance_id(
            dsw_instance_id=dsw_instance_id,
            cred=credential_client,
            region_id=region_id,
            network=network,
        )
        if not workspace_id:
            logger.warning(
                "Failed to get workspace id by dsw instance id: %s", dsw_instance_id
            )
            return None

    bucket_name, oss_endpoint = Session.get_default_oss_storage(
        workspace_id, credential_client, region_id, network
    )
    if not bucket_name:
        logger.warning(
            "Default OSS storage is not configured for the workspace: %s", workspace_id
        )

    return Session(
        region_id=region_id,
        workspace_id=workspace_id,
        credential_config=None,
        oss_bucket_name=bucket_name,
        oss_endpoint=oss_endpoint,
        network=network,
    )


def load_default_config_file() -> Optional[Dict[str, Any]]:
    """Read the session config file at ``DEFAULT_CONFIG_PATH``.

    Returns:
        Optional[Dict[str, Any]]: None if the file does not exist. Otherwise,
        the entries whose keys are in ``_DEFAULT_CONFIG_KEYS``, plus a
        ``credential_config`` entry when the file contains a legacy
        ``access_key_id`` / ``access_key_secret`` pair.
    """
    config_path = DEFAULT_CONFIG_PATH
    if not os.path.exists(config_path):
        return None

    with open(config_path, "r", encoding="utf-8") as f:
        raw_config = json.load(f)

    config = {
        key: value for key, value in raw_config.items() if key in _DEFAULT_CONFIG_KEYS
    }

    # For backward compatibility, try to read the credential from the config
    # file. This must check the *unfiltered* config: the access-key entries are
    # not in _DEFAULT_CONFIG_KEYS, so the filtered dict never contains them.
    access_key_id = raw_config.get("access_key_id")
    access_key_secret = raw_config.get("access_key_secret")
    if access_key_id and access_key_secret:
        config["credential_config"] = CredentialConfig(
            access_key_id=access_key_id,
            access_key_secret=access_key_secret,
            type=auth_constant.ACCESS_KEY,
        )
    return config


class Session(ResourceAPIsContainerMixin):
    """A class responsible for communicating with PAI services."""

    def __init__(
        self,
        region_id: str,
        workspace_id: Optional[str] = None,
        credential_config: Optional[CredentialConfig] = None,
        oss_bucket_name: Optional[str] = None,
        oss_endpoint: Optional[str] = None,
        **kwargs,
    ):
        """PAI Session Initializer.

        Args:
            credential_config (:class:`alibabacloud_credentials.models.Config`, optional):
                The credential config used to access the Alibaba Cloud.
            region_id (str): The ID of the Alibaba Cloud region where the
                service is located.
            workspace_id (str, optional): ID of the workspace used in the
                session. Non-string values (for example int) are converted
                to str.
            oss_bucket_name (str, optional): The name of the OSS bucket used in
                the session.
            oss_endpoint (str, optional): The endpoint for the OSS bucket.
            **kwargs: ``header``, ``network`` and ``runtime`` are forwarded to
                :class:`ResourceAPIsContainerMixin`. Any other keyword
                argument is ignored with a warning.

        Raises:
            ValueError: If ``region_id`` is empty.
        """
        if not region_id:
            raise ValueError("Region ID must be provided.")

        self._credential_config = credential_config
        self._region_id = region_id
        # Normalize to str (the id may be passed as an int) but keep None as
        # None. str(None) would yield the truthy string "None", which defeats
        # the "workspace not set" checks and ends up in the saved config file.
        self._workspace_id = None if workspace_id is None else str(workspace_id)
        self._oss_bucket_name = oss_bucket_name
        self._oss_endpoint = oss_endpoint

        header = kwargs.pop("header", None)
        network = kwargs.pop("network", None)
        runtime = kwargs.pop("runtime", None)
        if kwargs:
            logger.warning(
                "Unused arguments found in session initialization: %s", kwargs
            )
        super(Session, self).__init__(header=header, network=network, runtime=runtime)

    @property
    def region_id(self) -> str:
        """ID of the Alibaba Cloud region used by the session."""
        return self._region_id

    @property
    def is_inner(self) -> bool:
        """Whether the session region is one of ``INNER_REGION_IDS``."""
        return self._region_id in INNER_REGION_IDS

    @property
    def oss_bucket_name(self) -> str:
        """Name of the OSS bucket used by the session (may be None)."""
        return self._oss_bucket_name

    @property
    def oss_endpoint(self) -> str:
        """Endpoint of the OSS bucket used by the session (may be None)."""
        return self._oss_endpoint

    @property
    def credential_config(self) -> CredentialConfig:
        """Credential config of the session (may be None)."""
        return self._credential_config

    @property
    def workspace_name(self) -> str:
        """Name of the session's workspace, fetched once via the workspace API.

        Raises:
            ValueError: If the session has no workspace id.
        """
        cached_name = getattr(self, "_workspace_name", None)
        if cached_name:
            return cached_name
        if not self._workspace_id:
            raise ValueError("Workspace id is not set.")
        workspace_api_obj = self.workspace_api.get(workspace_id=self._workspace_id)
        self._workspace_name = workspace_api_obj["WorkspaceName"]
        return self._workspace_name

    @property
    def provider(self) -> str:
        """Account ID of the caller, from the STS GetCallerIdentity response."""
        caller_identity = self._acs_sts_client.get_caller_identity().body
        return caller_identity.account_id

    @property
    def workspace_id(self) -> str:
        """ID of the workspace used by the session (None if not set)."""
        return self._workspace_id

    @property
    def console_uri(self) -> str:
        """The web console URI for PAI service."""
        if self.is_inner:
            return "https://pai-next.alibaba-inc.com"
        return "https://pai.console.aliyun.com/console"

    def _init_oss_config(self):
        """Fill in a missing OSS bucket name and/or endpoint.

        Raises:
            RuntimeError: If no bucket name was given and the workspace has no
                default OSS storage URI configured.
        """
        if not self._oss_bucket_name:
            # If OSS bucket name is not provided, use the default OSS storage
            # URI that is configured for the workspace.
            default_oss_uri = self.workspace_api.get_default_storage_uri(
                self.workspace_id
            )
            if not default_oss_uri:
                raise RuntimeError(
                    "No default OSS URI is configured for the workspace."
                )
            self._oss_bucket_name = OssUriObj(default_oss_uri).bucket_name

        if not self._oss_endpoint:
            self._oss_endpoint = self._get_default_oss_endpoint()

    def _get_oss_auth(self):
        """Build an ``oss2.ProviderAuth`` backed by the session credential."""
        return oss2.ProviderAuth(
            credentials_provider=CredentialProviderWrapper(
                config=self._credential_config,
            )
        )

    @property
    def oss_bucket(self):
        """An ``oss2.Bucket`` for the session's bucket (a new instance per access)."""
        if not self._oss_bucket_name or not self._oss_endpoint:
            self._init_oss_config()
        return oss2.Bucket(
            auth=self._get_oss_auth(),
            endpoint=self._oss_endpoint,
            bucket_name=self._oss_bucket_name,
        )

    def save_config(self, config_path=None):
        """Save the configuration of the session to a local JSON file.

        Only the non-None attributes listed in ``_DEFAULT_CONFIG_KEYS`` are
        written. Credentials are not saved.

        Args:
            config_path (str, optional): Target file. Defaults to
                ``DEFAULT_CONFIG_PATH``.
        """
        attrs = {key.lstrip("_"): value for key, value in vars(self).items()}
        config = {
            key: value
            for key, value in attrs.items()
            if key in _DEFAULT_CONFIG_KEYS and value is not None
        }

        config_path = config_path or DEFAULT_CONFIG_PATH
        config_dir = os.path.dirname(config_path)
        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # directory when the path actually has one.
        if config_dir:
            os.makedirs(config_dir, exist_ok=True)
        with open(config_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(config, indent=4))
        logger.info("Write PAI config succeed: config_path=%s", config_path)

    def patch_oss_endpoint(self, oss_uri: str) -> str:
        """Return ``oss_uri`` with an explicit endpoint.

        A URI that already carries an endpoint is returned unchanged. Otherwise,
        the endpoint of the session's OSS bucket is embedded, producing
        ``oss://{bucket_name}.{endpoint}/{key}``.
        """
        oss_uri_obj = OssUriObj(oss_uri)
        if oss_uri_obj.endpoint:
            return oss_uri

        # Patch the endpoint using the current OSS bucket endpoint, with its
        # URL scheme removed. The scheme is removed as a *prefix*: str.lstrip()
        # strips any leading characters from a set, so lstrip("http://") would
        # also eat the start of hosts such as "pai-...".
        endpoint = self.oss_bucket.endpoint
        for scheme in ("http://", "https://"):
            if endpoint.startswith(scheme):
                endpoint = endpoint[len(scheme):]
                break

        return "oss://{bucket_name}.{endpoint}/{key}".format(
            bucket_name=oss_uri_obj.bucket_name,
            endpoint=endpoint,
            key=oss_uri_obj.object_key,
        )

    def _get_default_oss_endpoint(self) -> str:
        """Return the default OSS endpoint for the session region.

        The internal endpoint is returned when it is reachable. Otherwise, the
        internet endpoint is returned.
        """
        # OSS Endpoint document:
        # https://help.aliyun.com/document_detail/31837.html
        internet_endpoint = "oss-{}.aliyuncs.com".format(self.region_id)
        internal_endpoint = "oss-{}-internal.aliyuncs.com".format(self.region_id)
        return (
            internal_endpoint
            if is_domain_connectable(internal_endpoint)
            else internet_endpoint
        )

    def get_oss_bucket(self, bucket_name: str, endpoint: str = None) -> oss2.Bucket:
        """Get a OSS bucket using the credentials of the session.

        Args:
            bucket_name (str): The name of the bucket.
            endpoint (str, optional): Endpoint of the bucket. Defaults to the
                session's OSS endpoint, then to the default endpoint for the
                region.

        Returns:
            :class:`oss2.Bucket`: A OSS bucket instance.
        """
        endpoint = endpoint or self._oss_endpoint or self._get_default_oss_endpoint()
        return oss2.Bucket(
            auth=self._get_oss_auth(),
            endpoint=endpoint,
            bucket_name=bucket_name,
        )

    @classmethod
    def get_storage_path_by_category(
        cls, category: str, dir_name: Optional[str] = None
    ) -> str:
        """Get an OSS storage path for the resource.

        Args:
            category (str): The category of the resource.
            dir_name (str, optional): The directory name of the resource.
                Defaults to a timestamp of the form ``%Y%m%d_%H%M%S_%f``.

        Returns:
            str: A OSS storage path of the form ``pai/{category}/{dir_name}/``
            (always ending with "/").
        """
        dir_name = dir_name or datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        storage_path = posixpath.join("pai", category, dir_name).strip()
        if not storage_path.endswith("/"):
            storage_path += "/"
        return storage_path

    def _get_training_instance_spec(self, instance_type: str) -> Optional[Any]:
        """Return the training ECS spec entry for ``instance_type``, or None."""
        instance_generator = make_list_resource_iterator(self.job_api.list_ecs_specs)
        return next(
            (
                item
                for item in instance_generator
                if item["InstanceType"] == instance_type
            ),
            None,
        )

    def _get_inference_instance_spec(self, instance_type: str) -> Optional[Any]:
        """Return the inference machine meta for ``instance_type``, or None."""
        instance_metas = self.service_api.describe_machine()["InstanceMetas"]
        return next(
            (item for item in instance_metas if item["InstanceType"] == instance_type),
            None,
        )

    def is_supported_training_instance(self, instance_type: str) -> bool:
        """Check if the instance type is supported for training."""
        return bool(self._get_training_instance_spec(instance_type))

    def is_gpu_training_instance(self, instance_type: str) -> bool:
        """Check if the instance type is GPU instance for training.

        Raises:
            ValueError: If the instance type is not supported for training.
        """
        machine_spec = self._get_training_instance_spec(instance_type)
        if not machine_spec:
            raise ValueError(
                f"Instance type {instance_type} is not supported for training job. "
                "Please provide a supported instance type."
            )
        return machine_spec["AcceleratorType"] == "GPU"

    def is_supported_inference_instance(self, instance_type: str) -> bool:
        """Check if the instance type is supported for inference."""
        return bool(self._get_inference_instance_spec(instance_type))

    def is_gpu_inference_instance(self, instance_type: str) -> bool:
        """Check if the instance type is GPU instance for inference.

        Raises:
            ValueError: If the instance type is not supported for deploying.
        """
        spec = self._get_inference_instance_spec(instance_type)
        if not spec:
            raise ValueError(
                f"Instance type {instance_type} is not supported for deploying. "
                "Please provide a supported instance type."
            )
        return bool(spec["GPU"])

    @staticmethod
    def get_default_oss_storage(
        workspace_id: str, cred: CredentialClient, region_id: str, network: Network
    ) -> Tuple[Optional[str], Optional[str]]:
        """Look up the default OSS storage configured for a workspace.

        Returns:
            Tuple[Optional[str], Optional[str]]: ``(bucket_name, endpoint)``.
            The endpoint is the internal OSS endpoint for VPC networks and the
            internet endpoint otherwise. Returns ``(None, None)`` if the
            workspace has no default OSS storage URI.
        """
        acs_ws_client = ClientFactory.create_client(
            service_name=ServiceName.PAI_WORKSPACE,
            credential_client=cred,
            region_id=region_id,
            network=network,
        )
        workspace_api = WorkspaceAPI(acs_client=acs_ws_client)
        resp = workspace_api.list_configs(
            workspace_id=workspace_id,
            config_keys=WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI,
        )
        oss_storage_uri = next(
            (
                item["ConfigValue"]
                for item in resp["Configs"]
                if item["ConfigKey"] == WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI
            ),
            None,
        )
        # Default OSS storage uri is not set.
        if not oss_storage_uri:
            return None, None

        uri_obj = OssUriObj(oss_storage_uri)
        if network == Network.VPC:
            endpoint = "oss-{}-internal.aliyuncs.com".format(region_id)
        else:
            endpoint = "oss-{}.aliyuncs.com".format(region_id)
        return uri_obj.bucket_name, endpoint

    @staticmethod
    def _get_default_credential_client() -> Optional[CredentialClient]:
        """Create a credential client from the default credential chain.

        Returns None if no credential can be resolved.
        """
        try:
            # Initialize the credential client with default credential chain.
            # see: https://help.aliyun.com/zh/sdk/developer-reference/v2-manage-python-access-credentials#3ca299f04bw3c
            return CredentialClient()
        except CredentialException:
            return None

    @staticmethod
    def _get_workspace_id_by_dsw_instance_id(
        dsw_instance_id: str, cred: CredentialClient, region_id: str, network: Network
    ) -> Optional[str]:
        """Get the workspace id of a DSW instance.

        Returns None (after logging a warning) if the DSW API call fails.
        """
        dsw_client = ClientFactory.create_client(
            service_name=ServiceName.PAI_DSW,
            credential_client=cred,
            region_id=region_id,
            network=network,
        )
        try:
            resp = dsw_client.get_instance(
                dsw_instance_id, request=GetInstanceRequest()
            )
            return resp.body.workspace_id
        except TeaException as e:
            logger.warning("Failed to get instance info by dsw instance id: %s", e)
            return None