pai/api/client_factory.py (67 lines of code) (raw):
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from typing import Optional
from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_sts20150401.client import Client as StsClient
from alibabacloud_tea_openapi.models import Config
from ..common.consts import Network
from ..common.logging import get_logger
from ..common.utils import http_user_agent
from ..libs.alibabacloud_aiworkspace20210204.client import Client as WorkspaceClient
from ..libs.alibabacloud_eas20210701.client import Client as EasClient
from ..libs.alibabacloud_pai_dlc20201203.client import Client as DlcClient
from ..libs.alibabacloud_pai_dsw20220101.client import Client as DswClient
from ..libs.alibabacloud_paiflow20210202.client import Client as FlowClient
from ..libs.alibabacloud_paistudio20220112.client import Client as PaiClient
from .base import ServiceName
_logger = get_logger(__name__)
DEFAULT_SERVICE_ENDPOINT_PATTERN = "{}.{}.aliyuncs.com"
class ClientFactory(object):
ClientByServiceName = {
ServiceName.PAI_DLC: DlcClient,
ServiceName.PAI_EAS: EasClient,
ServiceName.PAI_WORKSPACE: WorkspaceClient,
ServiceName.PAIFLOW: FlowClient,
ServiceName.PAI_STUDIO: PaiClient,
ServiceName.STS: StsClient,
ServiceName.PAI_DSW: DswClient,
}
@staticmethod
def _is_inner_client(acs_client):
return acs_client.get_region_id() == "center"
@classmethod
def create_client(
cls,
service_name,
region_id: str,
credential_client: CredentialClient,
network: Optional[Network] = None,
**kwargs,
):
"""Create an API client which is responsible to interacted with the Alibaba
Cloud service."""
config = Config(
region_id=region_id,
credential=credential_client,
endpoint=cls.get_endpoint(
service_name=service_name,
region_id=region_id,
network=network,
),
signature_algorithm="v2",
user_agent=http_user_agent(),
**kwargs,
)
client = cls.ClientByServiceName.get(service_name)(config)
return client
@classmethod
def get_endpoint(
cls, service_name: str, region_id: str, network: Optional[Network] = None
) -> str:
"""Get the endpoint for the service client."""
if not region_id:
raise ValueError("Please provide region_id to get the endpoint.")
if network and network != Network.PUBLIC:
if service_name == "pai-eas":
# see endpoint list provided by PAI-EAS
# https://next.api.aliyun.com/product/eas
subdomain = f"pai-eas-manage-{network.value.lower()}"
else:
subdomain = f"{service_name}-{network.value.lower()}"
else:
subdomain = service_name
return DEFAULT_SERVICE_ENDPOINT_PATTERN.format(subdomain, region_id)