pai/toolkit/helper/utils.py (321 lines of code) (raw):

# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for the PAI toolkit.

Contains the ``UserProfile`` wrapper around a credential config (caller
identity, OSS bucket helpers, workspace API access) plus small terminal
helpers: localized text, secret masking, interactive prompts and colored
output.
"""

import locale
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import oss2
from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_credentials.models import Config as CredentialConfig
from alibabacloud_sts20150401.client import Client
from alibabacloud_sts20150401.models import (
    GetCallerIdentityResponseBody as CallerIdentity,
)
from alibabacloud_tea_openapi import models as open_api_models
from oss2.models import BucketInfo, SimplifiedBucketInfo
from prompt_toolkit import Application
from prompt_toolkit.key_binding import KeyBindings, merge_key_bindings
from prompt_toolkit.key_binding.defaults import load_key_bindings
from prompt_toolkit.layout import HSplit, Layout
from prompt_toolkit.shortcuts import confirm as prompt_confirm
from prompt_toolkit.widgets import Label, RadioList

from ...api.base import ServiceName
from ...api.client_factory import ClientFactory
from ...api.workspace import WorkspaceAPI, WorkspaceConfigKeys
from ...common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT, Network
from ...common.logging import get_logger
from ...common.oss_utils import CredentialProviderWrapper, OssUriObj
from ...common.utils import is_domain_connectable, make_list_resource_iterator
from ...libs.alibabacloud_pai_dsw20220101.client import Client as DswClient
from ...libs.alibabacloud_pai_dsw20220101.models import GetInstanceRequest
from ...session import Session

logger = get_logger(__name__)

# NOTE(review): ``locale.getdefaultlocale`` is deprecated since Python 3.11.
# It is kept to preserve the existing locale detection; ``locale_code`` may be
# None when the environment defines no locale.
locale_code, _ = locale.getdefaultlocale()

# Valid OSS bucket name: 3-63 chars of lowercase letters, digits and hyphens,
# not starting or ending with a hyphen.
OSS_NAME_PATTERN = re.compile(pattern="^[a-z0-9][a-z0-9-]{1,61}[a-z0-9]$")

ZH_CN_LOCAL = "zh_CN"

# RoleARN pattern for AssumedRole CallerIdentity; group(1) is the role name.
ASSUMED_ROLE_ARN_PATTERN = re.compile(r"acs:ram::\d+:assumed-role/([^/]+)/.*")

# DSW Notebook default role name (compared lower-cased).
PAI_DSW_DEFAULT_ROLE_NAME = "aliyunpaidswdefaultrole"

# RAM roles queried by ``UserProfile.get_production_authorizations``.
DEFAULT_PRODUCT_RAM_ROLE_NAMES = [
    "AliyunODPSPAIDefaultRole",
    "AliyunPAIAccessingOSSRole",
    "AliyunPAIDLCAccessingOSSRole",
    "AliyunPAIDLCDefaultRole",
]


class WorkspaceRoles(object):
    """Workspace roles."""

    AlgoDeveloper = "PAI.AlgoDeveloper"
    WorkspaceAdmin = "PAI.WorkspaceAdmin"
    WorkspaceOwner = "PAI.WorkspaceOwner"
    LabelManager = "PAI.LabelManager"

    @classmethod
    def recommend_roles(cls) -> List[str]:
        """Recommend roles for user to use."""
        return [
            cls.AlgoDeveloper,
            cls.WorkspaceAdmin,
            cls.WorkspaceOwner,
        ]


class CallerIdentityType(object):
    """Values of ``identity_type`` returned by STS GetCallerIdentity.

    Document: https://help.aliyun.com/document_detail/371868.html
    """

    # - Account: an Alibaba Cloud account
    Account = "Account"
    # - RamUser: a RAM user
    RamUser = "RAMUser"
    # - AssumedRoleUser: a RAM role
    AssumedRoleUser = "AssumedRoleUser"


class UserProfile(object):
    """Caller identity and service helpers for one credential config and region.

    Construction performs network I/O: it may probe the PAI VPC endpoint and
    always calls STS GetCallerIdentity.
    """

    # Created lazily by ``_get_credential_client`` and cached on the instance.
    _credential_client = None

    def __init__(
        self,
        credential_config: CredentialConfig,
        region_id: str,
    ):
        """Initialize the profile.

        Args:
            credential_config: Alibaba Cloud credential configuration.
            region_id: Region ID, e.g. used to build the OSS endpoint.
        """
        self.region_id = region_id
        self.credential_config = credential_config
        if DEFAULT_NETWORK_TYPE:
            # An explicitly configured network type wins over auto-detection.
            self.network = Network.from_string(DEFAULT_NETWORK_TYPE)
        else:
            # Prefer VPC when the region's PAI VPC endpoint is reachable.
            self.network = (
                Network.VPC
                if is_domain_connectable(PAI_VPC_ENDPOINT.format(self.region_id))
                else Network.PUBLIC
            )
        self._caller_identify = self._get_caller_identity()

    def _get_credential_client(self) -> CredentialClient:
        """Return the cached credential client, creating it on first use."""
        if self._credential_client:
            return self._credential_client
        self._credential_client = CredentialClient(self.credential_config)
        return self._credential_client

    def get_access_key_id(self):
        """Return the access key ID resolved by the credential client."""
        return self._get_credential_client().get_access_key_id()

    def get_access_key_secret(self):
        """Return the access key secret resolved by the credential client."""
        return self._get_credential_client().get_access_key_secret()

    def get_security_token(self):
        """Return the security token resolved by the credential client."""
        return self._get_credential_client().get_security_token()

    def _get_caller_identity(self) -> CallerIdentity:
        """Call STS GetCallerIdentity and return the response body."""
        network = None if self.network == Network.PUBLIC else self.network.value.lower()
        sts_client = Client(
            config=open_api_models.Config(
                credential=self._get_credential_client(),
                region_id=self.region_id,
                network=network,
            )
        )
        return sts_client.get_caller_identity().body

    def is_dsw_default_role(self) -> bool:
        """Return True if the caller is the DSW notebook default assumed role."""
        if self._caller_identify.identity_type != CallerIdentityType.AssumedRoleUser:
            return False
        m = ASSUMED_ROLE_ARN_PATTERN.match(self._caller_identify.arn or "")
        # bool() so a non-matching ARN yields False rather than None.
        return bool(m) and m.group(1).lower() == PAI_DSW_DEFAULT_ROLE_NAME

    def get_acs_dsw_client(self) -> DswClient:
        """Create a PAI-DSW service client for this profile."""
        return ClientFactory.create_client(
            service_name=ServiceName.PAI_DSW,
            credential_client=self._get_credential_client(),
            region_id=self.region_id,
            network=self.network,
        )

    def get_instance_info(self, instance_id: str) -> Dict[str, Any]:
        """Return the DSW GetInstance response body as a dict."""
        dsw_client = self.get_acs_dsw_client()
        return dsw_client.get_instance(
            instance_id, request=GetInstanceRequest()
        ).body.to_map()

    def get_credential(self):
        """Return the access key ID (same value as ``get_access_key_id``)."""
        # Use the lazy accessor so this never dereferences an unset client.
        return self._get_credential_client().get_access_key_id()

    @property
    def is_ram_user(self) -> bool:
        return self._caller_identify.identity_type == CallerIdentityType.RamUser

    @property
    def is_account(self) -> bool:
        return self._caller_identify.identity_type == CallerIdentityType.Account

    @property
    def account_id(self):
        """Return Alibaba Cloud account ID of the current user profile"""
        return self._caller_identify.account_id

    @property
    def user_id(self):
        """Return the Alibaba Cloud user ID of the current user profile"""
        return self._caller_identify.user_id

    @property
    def identify_type(self):
        """Return the STS ``identity_type`` of the caller."""
        return self._caller_identify.identity_type

    def get_default_oss_endpoint(self) -> str:
        """Return the public OSS endpoint URL for this profile's region."""
        return "https://oss-{}.aliyuncs.com".format(self.region_id)

    def _oss_auth(self) -> oss2.ProviderAuth:
        """Build an OSS auth object backed by this profile's credentials."""
        return oss2.ProviderAuth(
            credentials_provider=CredentialProviderWrapper(
                config=self.credential_config,
            ),
        )

    def list_oss_buckets(self, prefix: str = "") -> List[SimplifiedBucketInfo]:
        """List OSS buckets whose location contains this profile's region ID.

        Follows ``next_marker`` until the listing is no longer truncated.

        Args:
            prefix: Optional bucket-name prefix filter passed to OSS.
        """
        service = oss2.Service(
            auth=self._oss_auth(),
            endpoint=self.get_default_oss_endpoint(),
        )
        buckets: List[SimplifiedBucketInfo] = []
        marker = ""
        while True:
            res: oss2.models.ListBucketsResult = service.list_buckets(
                prefix=prefix, marker=marker
            )
            # NOTE(review): substring match, so e.g. "cn-shanghai" also
            # matches a location like "oss-cn-shanghai-finance-1" — confirm.
            buckets.extend(b for b in res.buckets if self.region_id in b.location)
            if not res.is_truncated:
                break
            marker = res.next_marker
        return buckets

    def get_bucket_info(self, bucket_name) -> BucketInfo:
        """Return OSS ``BucketInfo`` for ``bucket_name``."""
        bucket = oss2.Bucket(
            self._oss_auth(), self.get_default_oss_endpoint(), bucket_name=bucket_name
        )
        return bucket.get_bucket_info()

    def create_oss_bucket(self, bucket_name):
        """Create an OSS bucket named ``bucket_name`` in this region."""
        bucket = oss2.Bucket(
            bucket_name=bucket_name,
            auth=self._oss_auth(),
            endpoint=self.get_default_oss_endpoint(),
        )
        bucket.create_bucket()

    def get_production_authorizations(self):
        """Return ``AuthorizationDetails`` for the default product RAM roles."""
        workspace_api = self.get_workspace_api()
        res = workspace_api.list_product_authorizations(
            ram_role_names=DEFAULT_PRODUCT_RAM_ROLE_NAMES
        )
        return res["AuthorizationDetails"]

    def get_workspace_api(self) -> WorkspaceAPI:
        """Create a ``WorkspaceAPI`` backed by a PAI-Workspace client."""
        acs_ws_client = ClientFactory.create_client(
            service_name=ServiceName.PAI_WORKSPACE,
            credential_client=self._get_credential_client(),
            region_id=self.region_id,
            network=self.network,
        )
        return WorkspaceAPI(
            acs_client=acs_ws_client,
        )

    def get_default_oss_storage_uri(
        self, workspace_id: str
    ) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(oss_uri, endpoint)`` of the workspace default OSS storage.

        ``oss_uri`` is ``None`` when no bucket name is returned (presumably the
        workspace has no default storage configured — confirm with Session).
        """
        bucket_name, endpoint = Session.get_default_oss_storage(
            workspace_id=workspace_id,
            cred=self._get_credential_client(),
            region_id=self.region_id,
            network=self.network,
        )
        if not bucket_name:
            # Avoid returning the bogus (and truthy) URI "oss://None/".
            return None, endpoint
        return "oss://{}/".format(bucket_name), endpoint

    def set_default_oss_storage(
        self, workspace_id, bucket_name: str, intranet_endpoint: str
    ):
        """Set the workspace default OSS storage URI to the given bucket."""
        workspace_api = self.get_workspace_api()
        oss_uri = "oss://{}.{}/".format(bucket_name, intranet_endpoint)
        configs = {WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI: oss_uri}
        workspace_api.update_configs(workspace_id, configs=configs)

    def get_roles_in_workspace(
        self, workspace_id, user_id: Optional[str] = None
    ) -> List[str]:
        """Return the roles of ``user_id`` (default: current user) in a workspace."""
        workspace_api = self.get_workspace_api()
        user_id = user_id or self.user_id
        member_info = next(
            (
                mem
                for mem in make_list_resource_iterator(
                    workspace_api.list_members,
                    workspace_id=workspace_id,
                )
                if mem["UserId"] == user_id
            ),
            None,
        )
        # If user has PAIFullAccess policy, 'member_info' may be None.
        return member_info["Roles"] if member_info else []

    def has_permission_edit_config(self, workspace_id: str) -> bool:
        """Return True if the current user has permission to edit workspace config.

        Only members with the role of WorkspaceAdmin or WorkspaceOwner can edit
        workspace config.
        """
        roles = self.get_roles_in_workspace(workspace_id)
        return any(
            r in roles
            for r in [WorkspaceRoles.WorkspaceAdmin, WorkspaceRoles.WorkspaceOwner]
        )


def localized_text(en_text: str, cn_text: Optional[str] = None):
    """Return ``cn_text`` under the zh_CN locale (falling back to English)."""
    if locale_code == ZH_CN_LOCAL:
        return cn_text or en_text
    else:
        return en_text


def mask_secret(secret, mask_count=4):
    """Mask the middle of ``secret``, keeping ``mask_count`` chars at each end.

    Secrets too short to keep both ends without revealing (or duplicating)
    characters are masked entirely.
    """
    mask_count = max(mask_count, 0)
    length = len(secret)
    if length <= mask_count * 2:
        return "*" * length
    # ``secret[length - mask_count:]`` instead of ``secret[-mask_count:]``,
    # which would return the whole string when mask_count == 0.
    return (
        secret[:mask_count]
        + "*" * (length - mask_count * 2)
        + secret[length - mask_count:]
    )


def mask_and_trim(secret, max_size=20):
    """Mask ``secret`` and truncate it to ``max_size`` chars plus "..."."""
    masked = mask_secret(secret, mask_count=8)
    if len(masked) > max_size:
        masked = masked[:max_size] + "..."
    return masked


def radio_list_prompt(
    title: str = "",
    values=None,
    cancel_value=None,
    style=None,
    async_: bool = False,
    **kwargs,
):
    """Show an inline radio list and return the value chosen with Enter.

    Ctrl-C returns ``cancel_value``. With ``async_`` the result of
    ``Application.run_async()`` is returned instead of running synchronously.
    Extra ``kwargs`` are forwarded to ``prompt_toolkit.Application``.
    """
    # Create the radio list
    radio_list = RadioList(values)
    # Remove the enter key binding so that we can augment it
    radio_list.control.key_bindings.remove("enter")
    bindings = KeyBindings()

    # Replace the enter key binding to select the value and also submit it
    @bindings.add("enter")
    def exit_with_value(event):
        """
        Pressing Enter will exit the user interface, returning the highlighted value.
        """
        radio_list._handle_enter()
        event.app.exit(result=radio_list.current_value)

    @bindings.add("c-c")
    def backup_exit_with_value(event):
        """
        Pressing Ctrl-C will exit the user interface with the cancel_value.
        """
        event.app.exit(result=cancel_value)

    # Create and run the mini inline application
    application = Application(
        layout=Layout(HSplit([Label(title), radio_list])),
        key_bindings=merge_key_bindings([load_key_bindings(), bindings]),
        mouse_support=True,
        style=style,
        **kwargs,
    )
    if async_:
        return application.run_async()
    else:
        return application.run()


def confirm(message: str = "Confirm?", suffix: str = " (y/n, default: y)"):
    """Ask a yes/no question; an empty answer counts as yes."""
    # Input enter key returns an empty string, we assume enter is 'YES'.
    res = prompt_confirm(message, suffix)
    yes = True if isinstance(res, str) and res.strip() == "" else res
    return yes


def not_empty(text: str) -> bool:
    """Return True if ``text`` contains non-whitespace characters."""
    return bool(text.strip())


def print_highlight(msg: str):
    """Print ``msg`` in green (plain when NO_COLOR is set)."""
    print(ColorEscape.green(msg))


def print_warning(msg: str):
    """Print ``msg`` in red (plain when NO_COLOR is set)."""
    print(ColorEscape.red(msg))


def validate_bucket_name(name: str) -> bool:
    """Return True if ``name`` matches ``OSS_NAME_PATTERN``."""
    return bool(OSS_NAME_PATTERN.match(name))


class ColorEscape(object):
    """
    A utility class to wrap a string with color escape code.
    """

    _black = "\u001b[30m"
    _red = "\u001b[31m"
    _green = "\u001b[32m"
    _yellow = "\u001b[33m"
    _blue = "\u001b[34m"
    _magenta = "\u001b[35m"
    _cyan = "\u001b[36m"
    _white = "\u001b[37m"
    _default = "\u001b[39m"
    _reset = "\u001b[0m"

    @classmethod
    def green(cls, msg: str) -> str:
        return cls._format(msg, cls._green)

    @classmethod
    def red(cls, msg: str) -> str:
        return cls._format(msg, cls._red)

    @classmethod
    def _format(cls, msg: str, code: str) -> str:
        """Wrap ``msg`` in ``code`` ... reset, unless NO_COLOR is set."""
        if os.environ.get("NO_COLOR"):
            # See https://no-color.org/
            return msg
        return code + msg + cls._reset