# pai/job/_training_job.py
# Copyright 2024 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import posixpath
import time
import typing
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, ConfigDict, Field
from pydantic.alias_generators import to_pascal
from Tea.exceptions import TeaException
from ..api.base import PaginatedResult
from ..common.consts import StoragePathCategory
from ..common.logging import get_logger
from ..common.oss_utils import OssUriObj, is_oss_uri, upload
from ..common.utils import (
is_dataset_id,
is_filesystem_uri,
is_odps_table_uri,
name_from_base,
print_table,
random_str,
retry,
to_plain_text,
)
from ..exception import UnexpectedStatusException
from ..session import Session, get_default_session
if typing.TYPE_CHECKING:
from ..estimator import FileSystemInputBase
logger = get_logger(__name__)
def as_oss_dir_uri(uri: str):
return uri if uri.endswith("/") else uri + "/"
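# Illustrative behavior: as_oss_dir_uri("oss://bucket/path") -> "oss://bucket/path/";
# a URI that already ends with "/" is returned unchanged.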
DEFAULT_OUTPUT_MODEL_CHANNEL_NAME = "model"
DEFAULT_CHECKPOINT_CHANNEL_NAME = "checkpoints"
DEFAULT_TENSORBOARD_CHANNEL_NAME = "tensorboard"
class SpotStrategy(str, Enum):
SpotWithPriceLimit = "SpotWithPriceLimit"
SpotAsPriceGo = "SpotAsPriceGo"
def __repr__(self):
return self.value
class ResourceType(str, Enum):
Lingjun = "Lingjun"
General = "General"
class BaseAPIModel(BaseModel):
model_config = ConfigDict(
alias_generator=to_pascal,
populate_by_name=True,
)
def model_dump(self, **kwargs) -> Dict[str, Any]:
kwargs.update({"by_alias": True, "exclude_none": True})
return super().model_dump(**kwargs)
def to_dict(self):
return self.model_dump()
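
    # Sketch of the serialization behavior (illustrative; assumes pydantic v2):
    # fields are dumped with PascalCase aliases and None values are dropped, so
    # a subclass field declared as ``training_job_id`` appears as
    # ``TrainingJobId`` in the ``to_dict()`` output.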
class TrainingJobStatus(object):
CreateFailed = "CreateFailed"
InitializeFailed = "InitializeFailed"
Succeed = "Succeed"
Failed = "Failed"
Terminated = "Terminated"
Creating = "Creating"
Created = "Created"
Initializing = "Initializing"
Submitted = "Submitted"
Running = "Running"
@classmethod
def completed_status(cls):
return [
cls.InitializeFailed,
cls.Succeed,
cls.Failed,
cls.Terminated,
]
@classmethod
def failed_status(cls):
return [
cls.InitializeFailed,
cls.Failed,
cls.CreateFailed,
]
class UserVpcConfig(BaseAPIModel):
"""UserVpcConfig represents the VPC configuration for the training job instance."""
    vpc_id: str = Field(
        ...,
        description="The ID of the VPC that the training job instance connects to.",
    )
security_group_id: str = Field(
...,
description="The ID of the security group that training job instances belong to.",
)
switch_id: Optional[str] = Field(
None,
description="The ID of the vSwitch to which the instance belongs. Defaults to None.",
)
    extended_cidrs: Optional[List[str]] = Field(
        None,
        description="The CIDR blocks configured for the ENI of the training job "
        "instance. If not specified, they default to the CIDR block of the VPC, "
        "which means that the training job instance can access all resources in "
        "the VPC. Defaults to None.",
    )
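
    # A hedged construction sketch (the IDs below are placeholders, not real
    # resources):
    #
    #   UserVpcConfig(
    #       vpc_id="vpc-xxxxxxxx",
    #       security_group_id="sg-xxxxxxxx",
    #       switch_id="vsw-xxxxxxxx",
    #   ).to_dict()
    #   # -> {"VpcId": "vpc-xxxxxxxx", "SecurityGroupId": "sg-xxxxxxxx",
    #   #     "SwitchId": "vsw-xxxxxxxx"}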
class ExperimentConfig(BaseAPIModel):
"""ExperimentConfig is used to configure the experiment to which the job belongs."""
    experiment_id: str = Field(
        ...,
        description="The ID of the experiment that the training job belongs to.",
    )
class OssLocation(BaseAPIModel):
"""OSS location."""
bucket: str = Field(..., description="OSS bucket name.")
key: str = Field(..., description="Object key in the OSS bucket.")
endpoint: Optional[str] = Field(None, description="OSS service endpoint URL.")
class CodeDir(BaseAPIModel):
"""Source code location"""
location_value: Union[OssLocation, Dict[str, Any]] = Field(
..., description="Location of the code directory."
)
location_type: str = Field(
..., description="Type of the code directory location, e.g., OSS."
)
# HyperParameter
class HyperParameter(BaseAPIModel):
"""A hyperparameter for a training job."""
value: str = Field(..., description="Value of the hyperparameter.")
name: str = Field(..., description="Name of the hyperparameter.")
class InstanceSpec(BaseAPIModel):
"""Instance resource configuration"""
memory: str = Field(..., description="Memory allocation for the instance.")
cpu: str = Field(..., alias="CPU", description="CPU allocation for the instance.")
gpu: str = Field(..., alias="GPU", description="GPU allocation for the instance.")
shared_memory: Optional[str] = Field(
None, description="Shared memory allocation, if applicable."
)
class ComputeResource(BaseAPIModel):
"""Compute Resource Configuration."""
ecs_count: Optional[int] = Field(None, description="Number of ECS instances.")
ecs_spec: Optional[str] = Field(None, description="Specification of ECS instances.")
instance_count: Optional[int] = Field(None, description="Number of instances.")
instance_spec: Optional[InstanceSpec] = Field(
None, description="Specification for instances."
)
# URI Input and Output
class UriInput(BaseAPIModel):
"""URI Input for a training job."""
name: str = Field(..., description="Name of the input.")
input_uri: str = Field(..., description="URI of the input data.")
class UriOutput(BaseAPIModel):
"""URI Output for a training job."""
name: str = Field(..., description="Name of the output.")
output_uri: str = Field(..., description="URI of the output data.")
class DatasetConfig(BaseAPIModel):
"""Dataset Configuration"""
dataset_id: str = Field(..., description="Unique ID of the dataset.")
name: Optional[str] = Field(None, description="Name of the dataset.")
dataset_name: Optional[str] = Field(
None, description="Alternative name of the dataset."
)
class Channel(BaseAPIModel):
"""Channel Configuration."""
name: str = Field(..., description="Name of the channel.")
description: Optional[str] = Field(None, description="Description of the channel.")
required: Optional[bool] = Field(
None, description="Indicates if the channel is required."
)
supported_channel_types: Optional[List[str]] = Field(
None, description="Supported types for this channel."
)
properties: Optional[Dict[str, Any]] = Field(
None, description="Additional properties of the channel."
)
# HyperParameter Definition
class HyperParameterDefinition(BaseAPIModel):
"""HyerParameter Definition."""
name: str = Field(..., description="Name of the hyperparameter.")
type: Optional[str] = Field(None, description="Type of the hyperparameter.")
default_value: Optional[str] = Field(
None, description="Default value of the hyperparameter."
)
description: Optional[str] = Field(
None, description="Description of the hyperparameter."
)
required: bool = Field(
False, description="Indicates if the hyperparameter is required."
)
class SchedulerConfig(BaseAPIModel):
max_running_time_in_seconds: Optional[int] = None
class MetricDefinition(BaseAPIModel):
description: Optional[str] = Field(None, description="Description of the metric.")
name: str = Field(..., description="Name of the metric.")
regex: str = Field(
..., description="Regular expression used for capturing the metric."
)
class AlgorithmSpec(BaseAPIModel):
"""Algorithm Specification."""
command: List[str] = Field(..., description="Command to run the training job.")
image: str = Field(..., description="Docker image for the training job.")
supported_channel_types: List[str] = Field(default_factory=list)
output_channels: List[Channel] = Field(
default_factory=list, description="Output channels."
)
input_channels: List[Channel] = Field(
default_factory=list, description="Input channels."
)
supports_distributed_training: Optional[bool] = Field(
True, description="Whether the algorithm supports distributed training."
)
supported_instance_types: Optional[List[str]] = Field(
None, description="Supported instance types."
)
metric_definitions: Optional[List[MetricDefinition]] = Field(
None, description="Metric definitions."
)
hyperparameter_definitions: List[HyperParameterDefinition] = Field(
default_factory=list,
alias="HyperParameters",
description="Hyperparameter definitions.",
)
job_type: str = Field(default="PyTorchJob")
code_dir: Optional[CodeDir] = Field(None, description="Source code location.")
customization: Optional[Dict[str, Any]] = Field(
None, description="Whether the algorithm supports customize code."
)
class ModelRecipeSpec(BaseAPIModel):
compute_resource: Optional[ComputeResource] = None
hyperparameters: List[HyperParameter] = Field(
default_factory=list, alias="HyperParameters"
)
inputs: List[Union[UriInput, DatasetConfig]] = Field(
default_factory=list, alias="InputChannels"
)
scheduler: Optional[SchedulerConfig] = None
supported_instance_types: Optional[List[str]] = None
algorithm_spec: Optional[AlgorithmSpec] = None
algorithm_version: Optional[str] = None
algorithm_provider: Optional[str] = None
algorithm_name: Optional[str] = None
environments: Optional[Dict[str, str]] = None
requirements: Optional[List[str]] = None
class SpotSpec(BaseAPIModel):
    spot_strategy: SpotStrategy = Field(
        ...,
        description="Spot instance strategy; supported values are "
        "'SpotWithPriceLimit' and 'SpotAsPriceGo'.",
    )
    spot_discount_limit: Optional[float] = Field(
        None,
        description="Spot instance discount limit with at most 2 decimal places; "
        "required when spot_strategy is 'SpotWithPriceLimit'. "
        "For example, 0.5 means 50% off the original price.",
    )
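
    # Illustrative usage (values are examples only): bid for spot capacity at
    # no more than half of the on-demand price.
    #
    #   SpotSpec(
    #       spot_strategy=SpotStrategy.SpotWithPriceLimit,
    #       spot_discount_limit=0.5,
    #   )
    #
    # With SpotStrategy.SpotAsPriceGo, no discount limit is required.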
class TrainingJob(BaseAPIModel):
"""TrainingJob represents a training job in the PAI service."""
algorithm_id: Optional[str] = None
algorithm_name: Optional[str] = None
algorithm_provider: Optional[str] = None
algorithm_version: Optional[str] = None
algorithm_spec: Optional[AlgorithmSpec] = None
compute_resource: Optional[ComputeResource] = None
scheduler: Optional[SchedulerConfig] = None
experiment_config: Optional[Dict[str, Any]] = None
    inputs: List[Union[UriInput, DatasetConfig]] = Field(
        default_factory=list, alias="InputChannels"
    )
    outputs: List[Union[UriOutput, DatasetConfig]] = Field(
        default_factory=list, alias="OutputChannels"
    )
hyperparameters: List[HyperParameter] = Field(
default_factory=list, alias="HyperParameters"
)
labels: Optional[List[Dict[str, str]]] = Field(default_factory=list)
training_job_description: Optional[str] = None
training_job_id: Optional[str] = None
training_job_name: Optional[str] = None
workspace_id: Optional[str] = None
training_job_url: Optional[str] = None
status: Optional[str] = None
reason_code: Optional[str] = None
reason_message: Optional[str] = None
def __hash__(self):
return hash(self.training_job_id)
def __eq__(self, other: "TrainingJob"):
return (
isinstance(other, TrainingJob)
and self.training_job_id == other.training_job_id
)
@property
def id(self):
return self.training_job_id
@classmethod
def get(cls, training_job_id, session: Session = None) -> "TrainingJob":
session = session or get_default_session()
res = session.training_job_api.get(training_job_id=training_job_id)
return cls.model_validate(res)
@classmethod
def list(
cls,
status: Optional[str] = None,
session: Optional[Session] = None,
page_size: int = 50,
page_number: int = 1,
):
session = session or get_default_session()
res = session.training_job_api.list(
status=status, page_size=page_size, page_number=page_number
)
return [cls.model_validate(item) for item in res.items]
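
    # Hedged usage sketch (assumes a configured default session; the job ID is
    # a placeholder):
    #
    #   job = TrainingJob.get("train-xxxxxxxx")
    #   running_jobs = TrainingJob.list(status=TrainingJobStatus.Running)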
def output_path(self, channel_name="model"):
for output_channel in self.outputs:
if output_channel.name == channel_name:
return output_channel.output_uri
raise RuntimeError(
f"Output channel is not specified: channel_name={channel_name}"
)
@property
def console_uri(self):
if not self.training_job_id:
raise ValueError("The TrainingJob is not submitted")
return self.training_job_url
def wait(self, interval: int = 5, show_logs: bool = True):
session = get_default_session()
self._refresh_status()
if show_logs:
job_log_printer = _TrainingJobLogPrinter(
training_job_id=self.training_job_id, page_size=20, session=session
)
job_log_printer.start()
else:
job_log_printer = None
try:
while not self.is_completed():
time.sleep(interval)
finally:
if job_log_printer:
job_log_printer.stop(wait=True)
self._on_job_completed()
def _on_job_completed(self):
# Print an empty line to separate the training job logs and the following logs
print()
if self.status == TrainingJobStatus.Succeed:
print(
f"Training job ({self.training_job_id}) succeeded, you can check the"
f" logs/metrics/output in the console:\n{self.console_uri}"
)
elif self.status == TrainingJobStatus.Terminated:
print(
f"Training job is ended with status {self.status}: "
f"reason_code={self.reason_code}, reason_message={self.reason_message}."
f"Check the training job in the console:\n{self.console_uri}"
)
elif self.status in TrainingJobStatus.failed_status():
print(
f"Training job ({self.training_job_id}) failed, please check the logs"
f" in the console: \n{self.console_uri}"
)
message = f"TrainingJob failed: name={self.training_job_name}, "
f"training_job_id={self.training_job_id}, "
f"reason_code={self.reason_code}, status={self.status}, "
f"reason_message={self.reason_message}"
raise UnexpectedStatusException(message=message, status=self.status)
def _refresh_status(self):
"""Reload the training job from the PAI Service,"""
session = get_default_session()
training_job = type(self).model_validate(
session.training_job_api.get(training_job_id=self.training_job_id)
)
self.status = training_job.status
def is_succeeded(self):
"""Return True if the training job is succeeded"""
self._refresh_status()
return self.status == TrainingJobStatus.Succeed
@retry(wait_secs=10)
def is_completed(self):
"""Return True if the training job is completed, including failed status"""
if self.status in TrainingJobStatus.completed_status():
return True
self._refresh_status()
return self.status in TrainingJobStatus.completed_status()
class _TrainingJobLogPrinter(object):
"""A class used to print logs for a training job"""
executor = ThreadPoolExecutor(5)
def __init__(
self, training_job_id: str, page_size=10, session: Optional[Session] = None
):
self.training_job_id = training_job_id
self.session = session
self.page_size = page_size
self._future = None
self._stop = False
def _list_logs_api(self, page_number: int = 1):
try:
res = self.session.training_job_api.list_logs(
self.training_job_id,
page_number=page_number,
page_size=self.page_size,
)
return res
except TeaException as e:
# hack: Backend service may raise an exception when the training job
# instance is not found.
if e.code == "TRAINING_JOB_INSTANCE_NOT_FOUND":
return PaginatedResult(items=[], total_count=0)
else:
raise e
def _list_logs(self):
page_number, page_offset = 1, 0
# print training job logs.
while not self._stop:
res = self._list_logs_api(page_number=page_number)
# 1. move to next page
if len(res.items) == self.page_size:
# print new logs starting from page_offset
self._print_logs(logs=res.items[page_offset:])
page_number += 1
page_offset = 0
# 2. stay at the current page.
else:
if len(res.items) > page_offset:
# print new logs starting from page_offset
self._print_logs(logs=res.items[page_offset:])
page_offset = len(res.items)
time.sleep(1)
# When _stop is True, wait and print remaining logs.
time.sleep(10)
while True:
res = self._list_logs_api(page_number=page_number)
            # There may be more logs in the next page.
if len(res.items) == self.page_size:
self._print_logs(logs=res.items[page_offset:])
page_number += 1
page_offset = 0
# No more logs in the next page.
else:
if len(res.items) > page_offset:
self._print_logs(logs=res.items[page_offset:])
break
def _print_logs(self, logs: List[str]):
for log in logs:
print(log)
def start(self):
if self._future:
raise ValueError("The training job log printer is already started")
self._stop = False
self._future = self.executor.submit(self._list_logs)
    def stop(self, wait: bool = True):
        self._stop = True
        # Honor the ``wait`` flag: only block on the background thread when
        # the caller asks for it.
        if self._future and wait:
            self._future.result()
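
# Hedged usage sketch of the internal log printer (the job ID is a placeholder):
#
#   printer = _TrainingJobLogPrinter("train-xxxxxxxx", session=get_default_session())
#   printer.start()          # stream logs from a background thread
#   ...                      # wait for the job to complete
#   printer.stop(wait=True)  # flush remaining logs and join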
class _TrainingJobSubmitter(object):
"""A class used to submit a training job to the PAI service."""
def __init__(
self,
base_job_name: Optional[str] = None,
output_path: Optional[str] = None,
experiment_config: Optional[ExperimentConfig] = None,
user_vpc_config: Optional[UserVpcConfig] = None,
max_run_time: Optional[int] = None,
instance_type: Optional[str] = None,
instance_spec: Optional[Dict] = None,
instance_count: Optional[int] = None,
resource_id: Optional[Dict] = None,
resource_type: Optional[Union[str, ResourceType]] = None,
spot_spec: Optional[SpotSpec] = None,
environments: Optional[Dict] = None,
requirements: Optional[List[str]] = None,
labels: Optional[Dict[str, str]] = None,
settings: Optional[Dict[str, Any]] = None,
):
self.session = get_default_session()
self._training_jobs = []
self.base_job_name = base_job_name or type(self).__name__.lower()
self.output_path = output_path
self.user_vpc_config = user_vpc_config
self.spot_spec = spot_spec
self.experiment_config = experiment_config
self.max_run_time = max_run_time
self.instance_type = instance_type
self.instance_spec = instance_spec
self.instance_count = instance_count or 1
self.resource_id = resource_id
self.resource_type = ResourceType(resource_type) if resource_type else None
self.environments = environments
self.requirements = requirements
self.settings = settings
self.labels = labels
def wait(self, interval: int = 5, show_logs: bool = True, all_jobs: bool = False):
"""Block until the jobs is completed.
Args:
interval(int): Interval to reload job status
show_logs(bool): Specifies whether to fetch and print the logs produced by
the job.
all_jobs(bool): Wait latest job or wait all jobs in processor, show_logs disabled while
wait all jobs.
Raises:
RuntimeError: If no job is submitted.
"""
if all_jobs:
if not self._training_jobs:
raise RuntimeError("Could not find any submitted job.")
remains = set(self._training_jobs)
while remains:
for job in self._training_jobs:
if job in remains and job.is_completed():
remains.remove(job)
time.sleep(interval)
self._generate_jobs_report()
else:
latest_job = self.latest_job
if not latest_job:
raise RuntimeError("Could not find a submitted job.")
latest_job.wait(interval=interval, show_logs=show_logs)
return latest_job
def _generate_jobs_report(self):
"""Generate current jobs report and output to stdout"""
print(f"Jobs status report, total jobs count: {len(self._training_jobs)}")
rows = []
headers = ["JobName", "JobID", "Status"]
for job in self._training_jobs:
rows.append([job.training_job_name, job.id, job.status])
print_table(headers, rows)
def job_name(self, job_name: Optional[str] = None):
if job_name:
return job_name
sep = "-"
base_name = self.base_job_name
return name_from_base(base_name, sep)
def build_inputs(
self,
inputs: Dict[str, Any],
input_channels: List[Channel],
default_inputs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, str]]:
res = []
inputs = inputs or dict()
input_channels = input_channels or []
default_inputs = default_inputs or {}
inputs = {**default_inputs, **inputs}
requires = {ch.name for ch in input_channels if ch.required} - set(
inputs.keys()
)
if requires:
raise ValueError(
"Required input channels are not provided: {}".format(
",".join(requires)
)
)
for name, item in inputs.items():
input_config = self._get_input_config(name, item)
res.append(input_config.model_dump())
return res
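
    # Illustrative mapping performed by ``build_inputs`` (the channel name and
    # URI are placeholders):
    #
    #   build_inputs(
    #       inputs={"train": "oss://bucket/data/train/"},
    #       input_channels=[Channel(name="train", required=True)],
    #   )
    #   # -> [{"Name": "train", "InputUri": "oss://bucket/data/train/"}]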
@staticmethod
def _default_training_output_channels() -> List[Channel]:
channels = [
Channel(
name=DEFAULT_OUTPUT_MODEL_CHANNEL_NAME,
description="Training output models",
required=True,
),
Channel(
name=DEFAULT_CHECKPOINT_CHANNEL_NAME,
description="Training checkpoints channel",
required=False,
),
Channel(
name=DEFAULT_TENSORBOARD_CHANNEL_NAME,
properties={"ossAppendable": "true"},
description="TensorBoard logs channel",
required=False,
),
]
return channels
    def _training_job_base_output(self, job_name):
        job_name = to_plain_text(job_name)
        if self.output_path:
            if not is_oss_uri(self.output_path):
                raise ValueError("Output path should be an OSS path.")
            # OSS URIs always use forward slashes, so join with posixpath
            # rather than os.path to stay platform independent.
            return posixpath.join(self.output_path, f"{job_name}_{random_str(6)}")
        session = get_default_session()
        bucket_name = session.oss_bucket.bucket_name
        storage_path = session.get_storage_path_by_category(
            StoragePathCategory.TrainingJob,
            f"{job_name}_{random_str(6)}",
        )
        return f"oss://{bucket_name}/{storage_path}"
def build_outputs(
self,
job_name: str,
output_channels: List[Channel],
outputs: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, str]]:
base_output_path = self._training_job_base_output(job_name)
res = []
outputs = outputs or dict()
for ch in output_channels:
if ch.name in outputs:
output = self._get_output_config(name=ch.name, item=outputs[ch.name])
else:
output_uri = as_oss_dir_uri(posixpath.join(base_output_path, ch.name))
output = UriOutput(name=ch.name, output_uri=output_uri)
res.append(output)
extra_outputs = set(outputs.keys()) - {ch.name for ch in output_channels}
for name in extra_outputs:
output = self._get_output_config(
name=name,
item=outputs[name],
)
res.append(output)
return [item.model_dump() for item in res]
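
    # Illustrative result (bucket and prefix are placeholders): with the
    # default channels and no explicit ``outputs``, each channel receives a
    # per-job OSS prefix, e.g.
    #
    #   [{"Name": "model", "OutputUri": "oss://bucket/.../job_xxx/model/"},
    #    {"Name": "checkpoints", "OutputUri": "oss://bucket/.../job_xxx/checkpoints/"}]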
# TODO: get arguments, such as VPCConfig, instance_type etc, from self instance.
def _submit(
self,
job_name: str,
algorithm_spec: Optional[AlgorithmSpec] = None,
algorithm_name: Optional[str] = None,
algorithm_version: Optional[str] = None,
algorithm_provider: Optional[str] = None,
instance_count: int = 1,
instance_type: Optional[str] = None,
instance_spec: Optional[InstanceSpec] = None,
resource_id: Optional[str] = None,
inputs: Optional[List[Dict[str, Any]]] = None,
outputs: Optional[List[Dict[str, Any]]] = None,
hyperparameters: Optional[Dict[str, str]] = None,
max_run_time: Optional[int] = None,
environments: Optional[Dict[str, str]] = None,
user_vpc_config: Optional[Dict[str, str]] = None,
requirements: Optional[List[str]] = None,
experiment_config: Optional[Dict[str, Any]] = None,
labels: Optional[Dict[str, str]] = None,
wait: bool = True,
show_logs: bool = False,
):
session = get_default_session()
if not self.resource_type or self.resource_type == ResourceType.General:
resource_type = None
else:
resource_type = self.resource_type.value
if self.spot_spec:
spot_spec = {
"SpotStrategy": self.spot_spec.spot_strategy.value,
}
if self.spot_spec.spot_discount_limit:
spot_spec["SpotDiscountLimit"] = self.spot_spec.spot_discount_limit
else:
spot_spec = None
# user vpc
if self.user_vpc_config:
user_vpc_config = {
"VpcId": self.user_vpc_config.vpc_id,
"SecurityGroupId": self.user_vpc_config.security_group_id,
}
else:
user_vpc_config = None
training_job_id = session.training_job_api.create(
instance_count=instance_count,
instance_spec=instance_spec.model_dump() if instance_spec else None,
algorithm_name=algorithm_name,
algorithm_provider=algorithm_provider,
experiment_config=(
experiment_config.model_dump()
if experiment_config and isinstance(experiment_config, ExperimentConfig)
else experiment_config
),
spot_spec=spot_spec,
algorithm_version=algorithm_version,
instance_type=instance_type,
resource_id=resource_id,
resource_type=resource_type,
job_name=job_name,
hyperparameters=hyperparameters,
max_running_in_seconds=max_run_time,
input_channels=inputs,
output_channels=outputs,
algorithm_spec=algorithm_spec.model_dump() if algorithm_spec else None,
requirements=requirements,
user_vpc_config=user_vpc_config,
labels=labels,
environments=environments,
settings=self.settings,
)
training_job = TrainingJob.get(training_job_id)
self._training_jobs.append(training_job)
print(
f"View the job detail by accessing the console URI: {training_job.console_uri}"
)
if wait:
training_job.wait(show_logs=show_logs)
return training_job
@classmethod
def _get_input_config(
cls, name: str, item: Union[str, "FileSystemInputBase", DatasetConfig]
) -> Union[UriInput, DatasetConfig]:
"""Get input uri for training_job from given input."""
from pai.estimator import FileSystemInputBase
if not isinstance(item, (str, FileSystemInputBase, DatasetConfig)):
raise ValueError(f"Input data of type {type(item)} is not supported.")
if isinstance(item, FileSystemInputBase):
input_ = UriInput(
name=name,
input_uri=item.to_input_uri(),
)
elif isinstance(item, DatasetConfig):
input_ = DatasetConfig(
name=name,
dataset_id=item.dataset_id,
)
elif is_oss_uri(item) or is_filesystem_uri(item) or is_odps_table_uri(item):
input_ = UriInput(
name=name,
input_uri=item,
)
elif isinstance(item, str):
if os.path.exists(item):
store_path = Session.get_storage_path_by_category(
StoragePathCategory.InputData
)
input_ = UriInput(name=name, input_uri=upload(item, store_path))
else:
raise ValueError("Invalid input data path, file not found: {item}.")
else:
raise ValueError(
f"Invalid input data, supported inputs are OSS, NAS, MaxCompute "
f"table or local path: {type(item)}."
)
return input_
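
    # How inputs are resolved (illustrative; URIs and paths are placeholders):
    #   "oss://bucket/data/"              -> UriInput, URI passed through
    #   "odps://project/tables/t1"        -> UriInput, MaxCompute table URI
    #   "./local/data/"                   -> uploaded to OSS, then UriInput
    #   DatasetConfig(dataset_id="d-xxx") -> DatasetConfig input channel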
    @classmethod
    def _get_output_config(
        cls, name: str, item: Union[str, "FileSystemInputBase", DatasetConfig]
    ) -> Union[UriOutput, DatasetConfig]:
from pai.estimator import FileSystemInputBase
if not isinstance(item, (str, FileSystemInputBase, DatasetConfig)):
raise ValueError(f"Output data of type {type(item)} is not supported.")
if isinstance(item, FileSystemInputBase):
output = UriOutput(
name=name,
output_uri=item.to_input_uri(),
)
elif isinstance(item, DatasetConfig):
output = DatasetConfig(name=name, dataset_id=item.dataset_id)
elif is_oss_uri(item) or is_filesystem_uri(item) or is_odps_table_uri(item):
output = UriOutput(
name=name,
output_uri=as_oss_dir_uri(item),
)
        else:
            raise ValueError(
                f"Invalid output data of type {type(item)}; supported outputs "
                f"are OSS, NAS, or MaxCompute table URIs."
            )
return output
@property
def latest_job(self) -> "TrainingJob":
return self._training_jobs[-1] if self._training_jobs else None
def _build_code_input(
self, job_name: str, source_dir: Optional[str], code_dest: Optional[str] = None
) -> Optional[CodeDir]:
"""Upload source files to OSS and return the code input for training job."""
if not source_dir:
return
if is_oss_uri(source_dir):
code_uri = source_dir
elif not os.path.exists(source_dir):
raise ValueError(f"Source directory {source_dir} does not exist.")
else:
code_dest = code_dest or self.session.get_storage_path_by_category(
StoragePathCategory.TrainingSrc, to_plain_text(job_name)
)
code_uri = upload(
source_path=source_dir,
oss_path=code_dest,
bucket=self.session.oss_bucket,
)
oss_uri_obj = OssUriObj(uri=self.session.patch_oss_endpoint(code_uri))
code_dir = CodeDir(
location_type="oss",
location_value=OssLocation(
bucket=oss_uri_obj.bucket_name,
key=oss_uri_obj.object_key,
endpoint=oss_uri_obj.endpoint,
),
)
return code_dir
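
    # Illustrative result (bucket, key and endpoint are placeholders): a local
    # ``source_dir`` is uploaded to OSS and wrapped as
    #
    #   CodeDir(
    #       location_type="oss",
    #       location_value=OssLocation(
    #           bucket="my-bucket",
    #           key="pai/training_src/job_xxx/",
    #           endpoint="oss-cn-hangzhou.aliyuncs.com",
    #       ),
    #   )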