pai/api/training_job.py (201 lines of code) (raw):

# Copyright 2023 Alibaba, Inc. or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, List, Optional from ..api.base import PaginatedResult, ServiceName, WorkspaceScopedResourceAPI from ..libs.alibabacloud_paistudio20220112.models import ( AlgorithmSpec, CreateTrainingJobRequest, CreateTrainingJobRequestComputeResource, CreateTrainingJobRequestComputeResourceInstanceSpec, CreateTrainingJobRequestComputeResourceSpotSpec, CreateTrainingJobRequestExperimentConfig, CreateTrainingJobRequestHyperParameters, CreateTrainingJobRequestInputChannels, CreateTrainingJobRequestLabels, CreateTrainingJobRequestOutputChannels, CreateTrainingJobRequestScheduler, CreateTrainingJobRequestSettings, CreateTrainingJobRequestUserVpc, CreateTrainingJobResponseBody, GetTrainingJobRequest, GetTrainingJobResponseBody, ListTrainingJobLogsRequest, ListTrainingJobLogsResponseBody, ListTrainingJobsRequest, ) class TrainingJobAPI(WorkspaceScopedResourceAPI): BACKEND_SERVICE_NAME = ServiceName.PAI_STUDIO _list_method = "list_training_jobs_with_options" _create_method = "create_training_job_with_options" _get_method = "get_training_job_with_options" _list_logs_method = "list_training_job_logs_with_options" # _list_method = "list_training_jobs_with_options" def list( self, page_size: int = 20, page_number: int = 1, order: str = None, sort_by: str = None, status: str = None, training_job_name: str = None, ) -> PaginatedResult: request = ListTrainingJobsRequest( page_size=page_size, page_number=page_number, status=status, training_job_name=training_job_name, order=order, sort_by=sort_by, ) res = self._do_request( method_=self._list_method, tmp_req=request, ) return self.make_paginated_result(res) def get_api_object_by_resource_id(self, resource_id) -> Dict[str, Any]: res: GetTrainingJobResponseBody = self._do_request( method_=self._get_method, training_job_id=resource_id, request=GetTrainingJobRequest(), ) return res.to_map() def get(self, training_job_id) -> Dict[str, Any]: return self.get_api_object_by_resource_id(training_job_id) def create( self, instance_type, instance_count, job_name, spot_spec: Optional[Dict[str, Any]] = None, instance_spec: Optional[Dict[str, str]] = None, resource_id: Optional[str] = None, resource_type: Optional[str] = None, hyperparameters: Optional[Dict[str, Any]] = None, input_channels: Optional[List[Dict[str, Any]]] = None, output_channels: Optional[List[Dict[str, Any]]] = None, environments: Dict[str, str] = None, requirements: List[str] = None, labels: Optional[Dict[str, str]] = None, max_running_in_seconds: Optional[int] = None, description: Optional[str] = None, algorithm_name: Optional[str] = None, algorithm_version: Optional[str] = None, algorithm_provider: Optional[str] = None, algorithm_spec: Optional[Dict[str, Any]] = None, user_vpc_config: Optional[Dict[str, Any]] = None, experiment_config: Optional[Dict[str, Any]] = None, settings: Optional[Dict[str, Any]] = None, ) -> str: """Create a TrainingJob.""" if algorithm_spec and ( algorithm_name or algorithm_version or algorithm_provider ): raise ValueError( "Please provide algorithm_spec or a tuple of (algorithm_name, " "algorithm_version or algorithm_provider), but not both." ) if algorithm_spec: algo_spec = AlgorithmSpec().from_map(algorithm_spec) else: algo_spec = None input_channels = [ CreateTrainingJobRequestInputChannels().from_map(ch) for ch in input_channels ] output_channels = [ CreateTrainingJobRequestOutputChannels().from_map(ch) for ch in output_channels ] if instance_type: spot_spec = ( CreateTrainingJobRequestComputeResourceSpotSpec().from_map(spot_spec) if spot_spec else None ) compute_resource = CreateTrainingJobRequestComputeResource( ecs_count=instance_count, ecs_spec=instance_type, # use_spot_instance=bool(spot_spec), spot_spec=spot_spec, ) elif instance_spec: compute_resource = CreateTrainingJobRequestComputeResource( resource_id=resource_id, instance_count=instance_count, instance_spec=CreateTrainingJobRequestComputeResourceInstanceSpec().from_map( instance_spec ), ) else: raise ValueError("Please provide instance_type or instance_spec.") hyperparameters = hyperparameters or dict() hyper_parameters = [ CreateTrainingJobRequestHyperParameters( name=name, value=str(value), ) for name, value in hyperparameters.items() ] labels = ( [ CreateTrainingJobRequestLabels(key=key, value=value) for key, value in labels.items() ] if labels else None ) scheduler = CreateTrainingJobRequestScheduler( max_running_time_in_seconds=max_running_in_seconds ) request = CreateTrainingJobRequest( algorithm_name=algorithm_name, algorithm_provider=algorithm_provider, algorithm_version=algorithm_version, compute_resource=compute_resource, hyper_parameters=hyper_parameters, input_channels=input_channels, resource_type=resource_type, environments=environments, python_requirements=requirements, labels=labels, output_channels=output_channels, scheduler=scheduler, training_job_description=description, training_job_name=job_name, algorithm_spec=algo_spec, user_vpc=CreateTrainingJobRequestUserVpc().from_map(user_vpc_config), experiment_config=CreateTrainingJobRequestExperimentConfig().from_map( experiment_config ), settings=( CreateTrainingJobRequestSettings().from_map(settings) if settings else None ), ) resp: CreateTrainingJobResponseBody = self._do_request( method_=self._create_method, request=request ) return resp.training_job_id def list_logs( self, training_job_id, worker_id=None, page_size=10, page_number=1, start_time=None, end_time=None, ) -> PaginatedResult: request = ListTrainingJobLogsRequest( page_size=page_size, page_number=page_number, start_time=start_time, end_time=end_time, worker_id=worker_id, ) resp: ListTrainingJobLogsResponseBody = self._do_request( method_=self._list_logs_method, training_job_id=training_job_id, request=request, ) # resp.logs may be None logs = resp.logs or [] total_count = resp.total_count or 0 return PaginatedResult( items=logs, total_count=total_count, )