#  Copyright 2024 Alibaba, Inc. or its affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import os
import posixpath
import time
import typing
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict, Field
from pydantic.alias_generators import to_pascal
from Tea.exceptions import TeaException

from ..api.base import PaginatedResult
from ..common.consts import StoragePathCategory
from ..common.logging import get_logger
from ..common.oss_utils import OssUriObj, is_oss_uri, upload
from ..common.utils import (
    is_dataset_id,
    is_filesystem_uri,
    is_odps_table_uri,
    name_from_base,
    print_table,
    random_str,
    retry,
    to_plain_text,
)
from ..exception import UnexpectedStatusException
from ..session import Session, get_default_session

if typing.TYPE_CHECKING:
    from ..estimator import FileSystemInputBase

# Module-level logger, named after this module.
logger = get_logger(__name__)


def as_oss_dir_uri(uri: str):
    """Return ``uri`` in directory form, i.e. guaranteed to end with ``/``.

    A URI that already ends with a slash is returned unchanged.
    """
    if uri.endswith("/"):
        return uri
    return f"{uri}/"


# Names of the output channels every training job gets by default: trained model
# artifacts, checkpoints, and TensorBoard logs. When no explicit output is given for
# a channel, its data is written under "<job base output path>/<channel name>/".
DEFAULT_OUTPUT_MODEL_CHANNEL_NAME = "model"
DEFAULT_CHECKPOINT_CHANNEL_NAME = "checkpoints"
DEFAULT_TENSORBOARD_CHANNEL_NAME = "tensorboard"


class SpotStrategy(str, Enum):
    """Pricing strategy used when a training job runs on spot instances."""

    SpotWithPriceLimit = "SpotWithPriceLimit"
    SpotAsPriceGo = "SpotAsPriceGo"

    def __repr__(self):
        # Show the bare strategy name instead of "<SpotStrategy.X: 'X'>".
        strategy_name = self.value
        return strategy_name


class ResourceType(str, Enum):
    """Type of resource a training job runs on.

    ``General`` is the default: when submitting a job, no resource type is sent
    for it, while ``Lingjun`` is sent explicitly.
    """

    Lingjun = "Lingjun"
    General = "General"


class BaseAPIModel(BaseModel):
    """Base class for models exchanged with the PAI API.

    Attributes are declared in snake_case and serialized with PascalCase aliases.
    Both the attribute name and the alias are accepted when building a model.
    """

    model_config = ConfigDict(
        populate_by_name=True,
        alias_generator=to_pascal,
    )

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        # API payloads always use alias (PascalCase) keys and omit None-valued
        # fields; these two options take precedence over caller-supplied values.
        dump_options = dict(kwargs, by_alias=True, exclude_none=True)
        return super().model_dump(**dump_options)

    def to_dict(self):
        """Return the API-style dictionary form of the model."""
        return self.model_dump()


class TrainingJobStatus(object):
    """Status values of a training job as reported by the PAI service."""

    CreateFailed = "CreateFailed"
    InitializeFailed = "InitializeFailed"
    Succeed = "Succeed"
    Failed = "Failed"
    Terminated = "Terminated"
    Creating = "Creating"
    Created = "Created"
    Initializing = "Initializing"
    Submitted = "Submitted"
    Running = "Running"

    @classmethod
    def completed_status(cls):
        """Return the statuses in which a job is considered finished.

        Every status in :meth:`failed_status` is included. ``CreateFailed`` in
        particular must be here; otherwise waiting on a job whose creation failed
        would poll forever.
        """
        return [
            cls.CreateFailed,
            cls.InitializeFailed,
            cls.Succeed,
            cls.Failed,
            cls.Terminated,
        ]

    @classmethod
    def failed_status(cls):
        """Return the statuses that mean the job failed (a subset of completed)."""
        return [
            cls.InitializeFailed,
            cls.Failed,
            cls.CreateFailed,
        ]


class UserVpcConfig(BaseAPIModel):
    """UserVpcConfig represents the VPC configuration for the training job instance.

    ``vpc_id`` and ``security_group_id`` are required. The vSwitch ID and the
    extended CIDR blocks are optional.
    """

    vpc_id: str = Field(
        ...,
        description="Specifies the ID of the VPC that training job instance connects to.",
    )
    security_group_id: str = Field(
        ...,
        description="The ID of the security group that training job instances belong to.",
    )
    switch_id: Optional[str] = Field(
        None,
        description="The ID of the vSwitch to which the instance belongs. Defaults to None.",
    )
    extended_cidrs: Optional[List[str]] = Field(
        None,
        description="The CIDR blocks configured for the ENI of the training job instance. "
        "If it is not specified, the CIDR block will be configured as the same as the VPC "
        "network segmentation, which means that the training job instance can access all "
        "resources in the VPC. Defaults to None.",
    )


class ExperimentConfig(BaseAPIModel):
    """ExperimentConfig is used to configure the experiment to which the job belongs.

    The experiment is referenced by its ID, which is required.
    """

    experiment_id: str = Field(
        ...,
        description="Specifies the ID of the experiment that training job instance belongs to.",
    )


class OssLocation(BaseAPIModel):
    """OSS location: a bucket and object key, with an optional service endpoint."""

    bucket: str = Field(..., description="OSS bucket name.")
    key: str = Field(..., description="Object key in the OSS bucket.")
    endpoint: Optional[str] = Field(None, description="OSS service endpoint URL.")


class CodeDir(BaseAPIModel):
    """Source code location.

    ``location_value`` is either an :class:`OssLocation` or a raw dictionary.
    ``location_type`` names the kind of location; this module uses ``"oss"``.
    """

    location_value: Union[OssLocation, Dict[str, Any]] = Field(
        ..., description="Location of the code directory."
    )
    location_type: str = Field(
        ..., description="Type of the code directory location, e.g., OSS."
    )


# HyperParameter
class HyperParameter(BaseAPIModel):
    """A hyperparameter for a training job, given as a name/value pair.

    The value is always carried as a string.
    """

    value: str = Field(..., description="Value of the hyperparameter.")
    name: str = Field(..., description="Name of the hyperparameter.")


class InstanceSpec(BaseAPIModel):
    """Instance resource configuration.

    All quantities are carried as strings.
    """

    memory: str = Field(..., description="Memory allocation for the instance.")
    # Explicit aliases: the PascalCase alias generator would otherwise produce
    # "Cpu" / "Gpu" for these fields.
    cpu: str = Field(..., alias="CPU", description="CPU allocation for the instance.")
    gpu: str = Field(..., alias="GPU", description="GPU allocation for the instance.")
    shared_memory: Optional[str] = Field(
        None, description="Shared memory allocation, if applicable."
    )


class ComputeResource(BaseAPIModel):
    """Compute Resource Configuration.

    All fields are optional; only the ones that are set are serialized.
    """

    ecs_count: Optional[int] = Field(None, description="Number of ECS instances.")
    ecs_spec: Optional[str] = Field(None, description="Specification of ECS instances.")
    instance_count: Optional[int] = Field(None, description="Number of instances.")
    instance_spec: Optional[InstanceSpec] = Field(
        None, description="Specification for instances."
    )


# URI Input and Output
class UriInput(BaseAPIModel):
    """URI Input for a training job: a named input channel bound to a data URI."""

    name: str = Field(..., description="Name of the input.")
    input_uri: str = Field(..., description="URI of the input data.")


class UriOutput(BaseAPIModel):
    """URI Output for a training job: a named output channel bound to a data URI."""

    name: str = Field(..., description="Name of the output.")
    output_uri: str = Field(..., description="URI of the output data.")


class DatasetConfig(BaseAPIModel):
    """Dataset Configuration.

    References a dataset by ID. Used as an input or output channel of a
    training job, where ``name`` is the channel name.
    """

    dataset_id: str = Field(..., description="Unique ID of the dataset.")
    name: Optional[str] = Field(None, description="Name of the dataset.")
    dataset_name: Optional[str] = Field(
        None, description="Alternative name of the dataset."
    )


class Channel(BaseAPIModel):
    """Channel Configuration.

    Declares an input or output channel of an algorithm: its name, whether it
    must be provided, and optional extra properties.
    """

    name: str = Field(..., description="Name of the channel.")
    description: Optional[str] = Field(None, description="Description of the channel.")
    required: Optional[bool] = Field(
        None, description="Indicates if the channel is required."
    )
    supported_channel_types: Optional[List[str]] = Field(
        None, description="Supported types for this channel."
    )
    properties: Optional[Dict[str, Any]] = Field(
        None, description="Additional properties of the channel."
    )


# HyperParameter Definition
class HyperParameterDefinition(BaseAPIModel):
    """HyperParameter Definition.

    Describes a hyperparameter accepted by an algorithm. Hyperparameters are
    optional unless ``required`` is set.
    """

    name: str = Field(..., description="Name of the hyperparameter.")
    type: Optional[str] = Field(None, description="Type of the hyperparameter.")
    default_value: Optional[str] = Field(
        None, description="Default value of the hyperparameter."
    )
    description: Optional[str] = Field(
        None, description="Description of the hyperparameter."
    )
    required: bool = Field(
        False, description="Indicates if the hyperparameter is required."
    )


class SchedulerConfig(BaseAPIModel):
    """Scheduling settings of a training job."""

    # Maximum running time of the job, in seconds (per the field name); None means
    # no limit is sent.
    max_running_time_in_seconds: Optional[int] = None


class MetricDefinition(BaseAPIModel):
    """Definition of a training metric captured via a regular expression."""

    description: Optional[str] = Field(None, description="Description of the metric.")
    name: str = Field(..., description="Name of the metric.")
    regex: str = Field(
        ..., description="Regular expression used for capturing the metric."
    )


class AlgorithmSpec(BaseAPIModel):
    """Algorithm Specification.

    Describes how a training algorithm runs: the container image and command,
    its input/output channels, supported instance types, metric definitions and
    the hyperparameters it accepts.
    """

    command: List[str] = Field(..., description="Command to run the training job.")
    image: str = Field(..., description="Docker image for the training job.")
    supported_channel_types: List[str] = Field(default_factory=list)
    output_channels: List[Channel] = Field(
        default_factory=list, description="Output channels."
    )
    input_channels: List[Channel] = Field(
        default_factory=list, description="Input channels."
    )
    supports_distributed_training: Optional[bool] = Field(
        True, description="Whether the algorithm supports distributed training."
    )
    supported_instance_types: Optional[List[str]] = Field(
        None, description="Supported instance types."
    )
    metric_definitions: Optional[List[MetricDefinition]] = Field(
        None, description="Metric definitions."
    )
    # Serialized under the "HyperParameters" key rather than the alias that would
    # be generated from the attribute name.
    hyperparameter_definitions: List[HyperParameterDefinition] = Field(
        default_factory=list,
        alias="HyperParameters",
        description="Hyperparameter definitions.",
    )
    job_type: str = Field(default="PyTorchJob")
    code_dir: Optional[CodeDir] = Field(None, description="Source code location.")
    # NOTE(review): typed as a dict although the description reads like a boolean
    # flag; confirm the expected payload shape.
    customization: Optional[Dict[str, Any]] = Field(
        None, description="Whether the algorithm supports customize code."
    )


class ModelRecipeSpec(BaseAPIModel):
    """Specification of a model recipe.

    Bundles the compute resources, hyperparameters, input channels, scheduler
    settings, algorithm, environment variables and requirements of a recipe.
    """

    compute_resource: Optional[ComputeResource] = None
    hyperparameters: List[HyperParameter] = Field(
        default_factory=list, alias="HyperParameters"
    )
    # Serialized under the "InputChannels" key.
    inputs: List[Union[UriInput, DatasetConfig]] = Field(
        default_factory=list, alias="InputChannels"
    )
    scheduler: Optional[SchedulerConfig] = None
    supported_instance_types: Optional[List[str]] = None
    algorithm_spec: Optional[AlgorithmSpec] = None
    algorithm_version: Optional[str] = None
    algorithm_provider: Optional[str] = None
    algorithm_name: Optional[str] = None
    environments: Optional[Dict[str, str]] = None
    requirements: Optional[List[str]] = None


class SpotSpec(BaseAPIModel):
    """Spot instance configuration of a training job."""

    spot_strategy: SpotStrategy = Field(
        ...,
        description="Spot instance strategy, support 'SpotWithPriceLimit', 'SpotAsPriceGo'",
    )
    # The description fragments are implicitly concatenated; keep the trailing
    # space so the sentences do not run together.
    spot_discount_limit: Optional[float] = Field(
        None,
        description="Spot instance discount limit, maximum 2 decimal places, "
        "required when spot_strategy is 'SpotWithPriceLimit'. "
        "For example, 0.5 means 50% off the original price.",
    )


class TrainingJob(BaseAPIModel):
    """TrainingJob represents a training job in the PAI service.

    Instances are built from training job API responses (see :meth:`get` and
    :meth:`list`). API field names are mapped to the snake_case attributes below.
    """

    algorithm_id: Optional[str] = None
    algorithm_name: Optional[str] = None
    algorithm_provider: Optional[str] = None
    algorithm_version: Optional[str] = None
    algorithm_spec: Optional[AlgorithmSpec] = None
    compute_resource: Optional[ComputeResource] = None
    scheduler: Optional[SchedulerConfig] = None
    experiment_config: Optional[Dict[str, Any]] = None
    # Use default_factory, not default=list. With `default=list`, the `list` type
    # itself would become the field value when the response has no channels.
    inputs: List[Union[UriInput, DatasetConfig]] = Field(
        default_factory=list, alias="InputChannels"
    )
    outputs: List[Union[UriOutput, DatasetConfig]] = Field(
        default_factory=list, alias="OutputChannels"
    )
    hyperparameters: List[HyperParameter] = Field(
        default_factory=list, alias="HyperParameters"
    )
    labels: Optional[List[Dict[str, str]]] = Field(default_factory=list)
    training_job_description: Optional[str] = None
    training_job_id: Optional[str] = None
    training_job_name: Optional[str] = None
    workspace_id: Optional[str] = None
    training_job_url: Optional[str] = None
    status: Optional[str] = None
    reason_code: Optional[str] = None
    reason_message: Optional[str] = None

    def __hash__(self):
        # Jobs are identified by their ID, so they can be stored in sets.
        return hash(self.training_job_id)

    def __eq__(self, other: "TrainingJob"):
        return (
            isinstance(other, TrainingJob)
            and self.training_job_id == other.training_job_id
        )

    @property
    def id(self):
        """The training job ID (same as ``training_job_id``)."""
        return self.training_job_id

    @classmethod
    def get(cls, training_job_id, session: Session = None) -> "TrainingJob":
        """Load a training job by its ID.

        Args:
            training_job_id (str): ID of the training job.
            session (Session, optional): Session used to call the API. Defaults to
                the default session.

        Returns:
            TrainingJob: The loaded training job.
        """
        session = session or get_default_session()
        res = session.training_job_api.get(training_job_id=training_job_id)
        return cls.model_validate(res)

    @classmethod
    def list(
        cls,
        status: Optional[str] = None,
        session: Optional[Session] = None,
        page_size: int = 50,
        page_number: int = 1,
    ):
        """List one page of training jobs, optionally filtered by status.

        Args:
            status (str, optional): Only return jobs with this status.
            session (Session, optional): Session used to call the API. Defaults to
                the default session.
            page_size (int): Number of jobs per page.
            page_number (int): Page to fetch, starting from 1.

        Returns:
            List[TrainingJob]: The training jobs on the requested page.
        """
        session = session or get_default_session()
        res = session.training_job_api.list(
            status=status, page_size=page_size, page_number=page_number
        )
        return [cls.model_validate(item) for item in res.items]

    def output_path(self, channel_name="model"):
        """Return the output URI of the output channel named ``channel_name``.

        Raises:
            RuntimeError: If the job has no output channel with that name.
        """
        # NOTE(review): a DatasetConfig output has no ``output_uri``; matching one
        # here raises AttributeError.
        for output_channel in self.outputs:
            if output_channel.name == channel_name:
                return output_channel.output_uri
        raise RuntimeError(
            f"Output channel is not specified: channel_name={channel_name}"
        )

    @property
    def console_uri(self):
        """URL of the job in the PAI console.

        Raises:
            ValueError: If the job has not been submitted (no job ID).
        """
        if not self.training_job_id:
            raise ValueError("The TrainingJob is not submitted")

        return self.training_job_url

    def wait(self, interval: int = 5, show_logs: bool = True):
        """Block until the job reaches a completed status.

        Args:
            interval (int): Seconds to sleep between status checks.
            show_logs (bool): Whether to print the job logs while waiting.

        Raises:
            UnexpectedStatusException: If the job ends in a failed status.
        """
        session = get_default_session()
        self._refresh_status()

        if show_logs:
            job_log_printer = _TrainingJobLogPrinter(
                training_job_id=self.training_job_id, page_size=20, session=session
            )
            job_log_printer.start()
        else:
            job_log_printer = None
        try:
            while not self.is_completed():
                time.sleep(interval)
        finally:
            # Always stop the background log printer, even if polling raised.
            if job_log_printer:
                job_log_printer.stop(wait=True)

        self._on_job_completed()

    def _on_job_completed(self):
        """Report the final job status; raise if the job failed."""
        # Print an empty line to separate the training job logs and the following logs
        print()
        if self.status == TrainingJobStatus.Succeed:
            print(
                f"Training job ({self.training_job_id}) succeeded, you can check the"
                f" logs/metrics/output in  the console:\n{self.console_uri}"
            )
        elif self.status == TrainingJobStatus.Terminated:
            print(
                f"Training job is ended with status {self.status}: "
                f"reason_code={self.reason_code}, reason_message={self.reason_message}. "
                f"Check the training job in the console:\n{self.console_uri}"
            )
        elif self.status in TrainingJobStatus.failed_status():
            print(
                f"Training job ({self.training_job_id}) failed, please check the logs"
                f" in the console: \n{self.console_uri}"
            )

            # Parenthesized so that every fragment is part of the message; bare
            # f-string lines after the assignment would be discarded statements.
            message = (
                f"TrainingJob failed: name={self.training_job_name}, "
                f"training_job_id={self.training_job_id}, "
                f"reason_code={self.reason_code}, status={self.status}, "
                f"reason_message={self.reason_message}"
            )

            raise UnexpectedStatusException(message=message, status=self.status)

    def _refresh_status(self):
        """Reload the training job from the PAI service and update its status.

        The failure reason is refreshed together with the status, so that
        completion reports show the latest reason rather than the values captured
        when this object was created.
        """
        session = get_default_session()
        training_job = type(self).model_validate(
            session.training_job_api.get(training_job_id=self.training_job_id)
        )
        self.status = training_job.status
        self.reason_code = training_job.reason_code
        self.reason_message = training_job.reason_message

    def is_succeeded(self):
        """Return True if the training job is succeeded (reloads the status first)."""
        self._refresh_status()
        return self.status == TrainingJobStatus.Succeed

    @retry(wait_secs=10)
    def is_completed(self):
        """Return True if the training job is completed, including failed status.

        A completed status is final, so the service is only queried when the
        cached status is not yet completed.
        """
        if self.status in TrainingJobStatus.completed_status():
            return True
        self._refresh_status()

        return self.status in TrainingJobStatus.completed_status()


class _TrainingJobLogPrinter(object):
    """A class used to print logs for a training job.

    Logs are fetched page by page from the training job API in a background task
    (see :meth:`start`) and printed to stdout until :meth:`stop` is called.
    """

    # Shared by all printers, so at most 5 printers fetch logs concurrently.
    executor = ThreadPoolExecutor(5)

    def __init__(
        self, training_job_id: str, page_size=10, session: Optional["Session"] = None
    ):
        self.training_job_id = training_job_id
        self.session = session
        self.page_size = page_size
        self._future = None
        self._stop = False

    def _list_logs_api(self, page_number: int = 1):
        """Fetch one page of logs; return an empty page if no instance exists yet."""
        try:
            res = self.session.training_job_api.list_logs(
                self.training_job_id,
                page_number=page_number,
                page_size=self.page_size,
            )
            return res
        except TeaException as e:
            # hack: Backend service may raise an exception when the training job
            # instance is not found.
            if e.code == "TRAINING_JOB_INSTANCE_NOT_FOUND":
                return PaginatedResult(items=[], total_count=0)
            raise

    def _list_logs(self):
        """Poll and print new log lines until stopped, then flush what remains.

        ``page_number``/``page_offset`` track the next unprinted line: a page that
        is full is printed from the offset and the cursor moves to the next page.
        A partially filled page is re-polled, printing only lines past the offset.
        """
        page_number, page_offset = 1, 0
        # print training job logs.
        while not self._stop:
            res = self._list_logs_api(page_number=page_number)
            # 1. move to next page
            if len(res.items) == self.page_size:
                # print new logs starting from page_offset
                self._print_logs(logs=res.items[page_offset:])
                page_number += 1
                page_offset = 0
            # 2. stay at the current page.
            else:
                if len(res.items) > page_offset:
                    # print new logs starting from page_offset
                    self._print_logs(logs=res.items[page_offset:])
                    page_offset = len(res.items)
                time.sleep(1)

        # When _stop is True, wait and print remaining logs.
        time.sleep(10)
        while True:
            res = self._list_logs_api(page_number=page_number)
            # There maybe more logs in the next page
            if len(res.items) == self.page_size:
                self._print_logs(logs=res.items[page_offset:])
                page_number += 1
                page_offset = 0
            # No more logs in the next page.
            else:
                if len(res.items) > page_offset:
                    self._print_logs(logs=res.items[page_offset:])
                break

    def _print_logs(self, logs: List[str]):
        """Print each log line to stdout."""
        for log in logs:
            print(log)

    def start(self):
        """Start printing logs in the background.

        Raises:
            ValueError: If this printer has already been started.
        """
        if self._future:
            raise ValueError("The training job log printer is already started")
        self._stop = False
        self._future = self.executor.submit(self._list_logs)

    def stop(self, wait: bool = True):
        """Ask the background task to stop.

        Args:
            wait (bool): If True, block until the task has flushed the remaining
                logs and exited, re-raising any exception it raised. If False,
                return immediately after signalling the task.
        """
        self._stop = True
        if wait and self._future:
            self._future.result()


class _TrainingJobSubmitter(object):
    """A class used to submit a training job to the PAI service.

    The submitter holds job-level settings (output path, VPC, spot, resource and
    environment configuration). It records every job it submits, so callers can
    wait on the latest job or on all of them.
    """

    def __init__(
        self,
        base_job_name: Optional[str] = None,
        output_path: Optional[str] = None,
        experiment_config: Optional[ExperimentConfig] = None,
        user_vpc_config: Optional[UserVpcConfig] = None,
        max_run_time: Optional[int] = None,
        instance_type: Optional[str] = None,
        instance_spec: Optional[Dict] = None,
        instance_count: Optional[int] = None,
        resource_id: Optional[str] = None,
        resource_type: Optional[Union[str, ResourceType]] = None,
        spot_spec: Optional[SpotSpec] = None,
        environments: Optional[Dict] = None,
        requirements: Optional[List[str]] = None,
        labels: Optional[Dict[str, str]] = None,
        settings: Optional[Dict[str, Any]] = None,
    ):
        self.session = get_default_session()
        self._training_jobs = []
        self.base_job_name = base_job_name or type(self).__name__.lower()
        self.output_path = output_path
        self.user_vpc_config = user_vpc_config
        self.spot_spec = spot_spec
        self.experiment_config = experiment_config
        self.max_run_time = max_run_time
        self.instance_type = instance_type
        self.instance_spec = instance_spec
        self.instance_count = instance_count or 1
        self.resource_id = resource_id
        # Accept either the enum or its string value; unknown values raise ValueError.
        self.resource_type = ResourceType(resource_type) if resource_type else None
        self.environments = environments
        self.requirements = requirements
        self.settings = settings
        self.labels = labels

    def wait(self, interval: int = 5, show_logs: bool = True, all_jobs: bool = False):
        """Block until the submitted job(s) complete.

        Args:
            interval(int): Interval, in seconds, between job status reloads.
            show_logs(bool): Specifies whether to fetch and print the logs produced by
                the job. Ignored when ``all_jobs`` is True.
            all_jobs(bool): If True, wait for every job submitted by this instance
                and print a status report; otherwise wait only for the latest job.

        Returns:
            TrainingJob: The latest job when ``all_jobs`` is False, otherwise None.

        Raises:
            RuntimeError: If no job is submitted.
            UnexpectedStatusException: If the latest job ends in a failed status
                (only when ``all_jobs`` is False).
        """
        if all_jobs:
            if not self._training_jobs:
                raise RuntimeError("Could not find any submitted job.")
            remains = set(self._training_jobs)
            while remains:
                for job in self._training_jobs:
                    if job in remains and job.is_completed():
                        remains.remove(job)
                # No need to sleep once every job has completed.
                if remains:
                    time.sleep(interval)
            self._generate_jobs_report()
        else:
            latest_job = self.latest_job
            if not latest_job:
                raise RuntimeError("Could not find a submitted job.")
            latest_job.wait(interval=interval, show_logs=show_logs)
            return latest_job

    def _generate_jobs_report(self):
        """Generate current jobs report and output to stdout"""
        print(f"Jobs status report, total jobs count: {len(self._training_jobs)}")
        rows = []
        headers = ["JobName", "JobID", "Status"]
        for job in self._training_jobs:
            rows.append([job.training_job_name, job.id, job.status])
        print_table(headers, rows)

    def job_name(self, job_name: Optional[str] = None):
        """Return ``job_name`` if given, otherwise a name derived from the base name."""
        if job_name:
            return job_name
        sep = "-"
        base_name = self.base_job_name
        return name_from_base(base_name, sep)

    def build_inputs(
        self,
        inputs: Dict[str, Any],
        input_channels: List[Channel],
        default_inputs: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, str]]:
        """Build the input channel payload of a training job.

        Args:
            inputs: Mapping of channel name to input data. Explicit inputs override
                ``default_inputs`` with the same name.
            input_channels: Channels declared by the algorithm; required ones must
                be present in the merged inputs.
            default_inputs: Inputs used when not given explicitly.

        Returns:
            List of serialized input configurations.

        Raises:
            ValueError: If a required input channel is missing, or an input is of
                an unsupported type.
        """
        res = []
        inputs = inputs or dict()
        input_channels = input_channels or []
        default_inputs = default_inputs or {}

        inputs = {**default_inputs, **inputs}
        requires = {ch.name for ch in input_channels if ch.required} - set(
            inputs.keys()
        )
        if requires:
            # Sort the missing names so the error message is deterministic.
            raise ValueError(
                "Required input channels are not provided: {}".format(
                    ",".join(sorted(requires))
                )
            )
        for name, item in inputs.items():
            input_config = self._get_input_config(name, item)
            res.append(input_config.model_dump())

        return res

    @staticmethod
    def _default_training_output_channels() -> List[Channel]:
        """Return the default output channels: model (required), checkpoints and
        TensorBoard logs (both optional)."""
        channels = [
            Channel(
                name=DEFAULT_OUTPUT_MODEL_CHANNEL_NAME,
                description="Training output models",
                required=True,
            ),
            Channel(
                name=DEFAULT_CHECKPOINT_CHANNEL_NAME,
                description="Training checkpoints channel",
                required=False,
            ),
            Channel(
                name=DEFAULT_TENSORBOARD_CHANNEL_NAME,
                properties={"ossAppendable": "true"},
                description="TensorBoard logs channel",
                required=False,
            ),
        ]

        return channels

    def _training_job_base_output(self, job_name):
        """Return a unique OSS base URI under which the job's outputs are stored.

        Uses ``self.output_path`` if set, otherwise the training job storage path
        in the default session's bucket.

        Raises:
            ValueError: If ``self.output_path`` is set but is not an OSS URI.
        """
        job_name = to_plain_text(job_name)
        if self.output_path:
            if not is_oss_uri(self.output_path):
                raise ValueError("Output path should be an OSS path.")
            # OSS URIs always use "/" separators; os.path.join would insert "\\"
            # on Windows.
            return posixpath.join(self.output_path, f"{job_name}_{random_str(6)}")

        session = get_default_session()
        bucket_name = session.oss_bucket.bucket_name
        storage_path = session.get_storage_path_by_category(
            StoragePathCategory.TrainingJob,
            f"{to_plain_text(job_name)}_{random_str(6)}",
        )
        base_output_path = f"oss://{bucket_name}/{storage_path}"
        return base_output_path

    def build_outputs(
        self,
        job_name: str,
        output_channels: List[Channel],
        outputs: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, str]]:
        """Build the output channel payload of a training job.

        Each declared channel without an explicit output is written to
        ``<base output path>/<channel name>/``. Explicit outputs for names not
        among ``output_channels`` are appended as extra outputs.
        """
        base_output_path = self._training_job_base_output(job_name)
        res = []
        outputs = outputs or dict()

        for ch in output_channels:
            if ch.name in outputs:
                output = self._get_output_config(name=ch.name, item=outputs[ch.name])
            else:
                output_uri = as_oss_dir_uri(posixpath.join(base_output_path, ch.name))
                output = UriOutput(name=ch.name, output_uri=output_uri)
            res.append(output)

        extra_outputs = set(outputs.keys()) - {ch.name for ch in output_channels}

        for name in extra_outputs:
            output = self._get_output_config(
                name=name,
                item=outputs[name],
            )
            res.append(output)

        return [item.model_dump() for item in res]

    # TODO: get arguments, such as VPCConfig, instance_type etc, from self instance.
    def _submit(
        self,
        job_name: str,
        algorithm_spec: Optional[AlgorithmSpec] = None,
        algorithm_name: Optional[str] = None,
        algorithm_version: Optional[str] = None,
        algorithm_provider: Optional[str] = None,
        instance_count: int = 1,
        instance_type: Optional[str] = None,
        instance_spec: Optional[InstanceSpec] = None,
        resource_id: Optional[str] = None,
        inputs: Optional[List[Dict[str, Any]]] = None,
        outputs: Optional[List[Dict[str, Any]]] = None,
        hyperparameters: Optional[Dict[str, str]] = None,
        max_run_time: Optional[int] = None,
        environments: Optional[Dict[str, str]] = None,
        user_vpc_config: Optional[Dict[str, str]] = None,
        requirements: Optional[List[str]] = None,
        experiment_config: Optional[Dict[str, Any]] = None,
        labels: Optional[Dict[str, str]] = None,
        wait: bool = True,
        show_logs: bool = False,
    ):
        """Create a training job and record it in this submitter.

        The resource type, spot spec, user VPC config and settings are taken from
        ``self``. The ``user_vpc_config`` argument is currently ignored in favor
        of ``self.user_vpc_config``.

        Returns:
            TrainingJob: The created job (after completion if ``wait`` is True).
        """
        session = get_default_session()

        # The General resource type is the default and is not sent explicitly.
        if not self.resource_type or self.resource_type == ResourceType.General:
            resource_type = None
        else:
            resource_type = self.resource_type.value

        if self.spot_spec:
            spot_spec = {
                "SpotStrategy": self.spot_spec.spot_strategy.value,
            }
            if self.spot_spec.spot_discount_limit:
                spot_spec["SpotDiscountLimit"] = self.spot_spec.spot_discount_limit
        else:
            spot_spec = None

        # user vpc
        if self.user_vpc_config:
            user_vpc_config = {
                "VpcId": self.user_vpc_config.vpc_id,
                "SecurityGroupId": self.user_vpc_config.security_group_id,
            }
            # Forward the optional settings too; they used to be silently dropped.
            if self.user_vpc_config.switch_id:
                user_vpc_config["SwitchId"] = self.user_vpc_config.switch_id
            if self.user_vpc_config.extended_cidrs:
                user_vpc_config["ExtendedCIDRs"] = self.user_vpc_config.extended_cidrs
        else:
            user_vpc_config = None

        training_job_id = session.training_job_api.create(
            instance_count=instance_count,
            instance_spec=instance_spec.model_dump() if instance_spec else None,
            algorithm_name=algorithm_name,
            algorithm_provider=algorithm_provider,
            experiment_config=(
                experiment_config.model_dump()
                if experiment_config and isinstance(experiment_config, ExperimentConfig)
                else experiment_config
            ),
            spot_spec=spot_spec,
            algorithm_version=algorithm_version,
            instance_type=instance_type,
            resource_id=resource_id,
            resource_type=resource_type,
            job_name=job_name,
            hyperparameters=hyperparameters,
            max_running_in_seconds=max_run_time,
            input_channels=inputs,
            output_channels=outputs,
            algorithm_spec=algorithm_spec.model_dump() if algorithm_spec else None,
            requirements=requirements,
            user_vpc_config=user_vpc_config,
            labels=labels,
            environments=environments,
            settings=self.settings,
        )
        training_job = TrainingJob.get(training_job_id)
        self._training_jobs.append(training_job)
        print(
            f"View the job detail by accessing the console URI: {training_job.console_uri}"
        )
        if wait:
            training_job.wait(show_logs=show_logs)
        return training_job

    @classmethod
    def _get_input_config(
        cls, name: str, item: Union[str, "FileSystemInputBase", DatasetConfig]
    ) -> Union[UriInput, DatasetConfig]:
        """Get input uri for training_job from given input.

        Local paths are uploaded to OSS first; OSS, file system and MaxCompute
        table URIs are used as-is.

        Raises:
            ValueError: If the input type is unsupported or a local path does not
                exist.
        """
        from pai.estimator import FileSystemInputBase

        if not isinstance(item, (str, FileSystemInputBase, DatasetConfig)):
            raise ValueError(f"Input data of type {type(item)} is not supported.")

        if isinstance(item, FileSystemInputBase):
            input_ = UriInput(
                name=name,
                input_uri=item.to_input_uri(),
            )
        elif isinstance(item, DatasetConfig):
            input_ = DatasetConfig(
                name=name,
                dataset_id=item.dataset_id,
            )
        elif is_oss_uri(item) or is_filesystem_uri(item) or is_odps_table_uri(item):
            input_ = UriInput(
                name=name,
                input_uri=item,
            )
        elif isinstance(item, str):
            if os.path.exists(item):
                store_path = Session.get_storage_path_by_category(
                    StoragePathCategory.InputData
                )
                input_ = UriInput(name=name, input_uri=upload(item, store_path))
            else:
                raise ValueError(f"Invalid input data path, file not found: {item}.")
        else:
            raise ValueError(
                f"Invalid input data, supported inputs are OSS, NAS, MaxCompute "
                f"table or local path: {type(item)}."
            )
        return input_

    @classmethod
    def _get_output_config(
        cls, name: str, item: Union[str, "FileSystemInputBase", DatasetConfig]
    ) -> Union[UriOutput, DatasetConfig]:
        """Get the output configuration of channel ``name`` from the given output.

        Raises:
            ValueError: If the output type or URI is not supported.
        """
        from pai.estimator import FileSystemInputBase

        if not isinstance(item, (str, FileSystemInputBase, DatasetConfig)):
            raise ValueError(f"Output data of type {type(item)} is not supported.")

        if isinstance(item, FileSystemInputBase):
            output = UriOutput(
                name=name,
                output_uri=item.to_input_uri(),
            )
        elif isinstance(item, DatasetConfig):
            output = DatasetConfig(name=name, dataset_id=item.dataset_id)
        elif is_oss_uri(item) or is_filesystem_uri(item) or is_odps_table_uri(item):
            output = UriOutput(
                name=name,
                output_uri=as_oss_dir_uri(item),
            )
        else:
            raise ValueError(
                "Invalid output data, supported outputs are OSS, NAS, MaxCompute "
                f"table or dataset: {item}."
            )

        return output

    @property
    def latest_job(self) -> Optional["TrainingJob"]:
        """The most recently submitted job, or None if nothing was submitted."""
        return self._training_jobs[-1] if self._training_jobs else None

    def _build_code_input(
        self, job_name: str, source_dir: Optional[str], code_dest: Optional[str] = None
    ) -> Optional[CodeDir]:
        """Upload source files to OSS and return the code input for training job.

        An OSS ``source_dir`` is referenced directly. A local directory is
        uploaded to ``code_dest``, or to the training source storage path of the
        job if ``code_dest`` is not given. Returns None if no ``source_dir``.

        Raises:
            ValueError: If a local ``source_dir`` does not exist.
        """
        if not source_dir:
            return
        if is_oss_uri(source_dir):
            code_uri = source_dir
        elif not os.path.exists(source_dir):
            raise ValueError(f"Source directory {source_dir} does not exist.")
        else:
            code_dest = code_dest or self.session.get_storage_path_by_category(
                StoragePathCategory.TrainingSrc, to_plain_text(job_name)
            )
            code_uri = upload(
                source_path=source_dir,
                oss_path=code_dest,
                bucket=self.session.oss_bucket,
            )
        oss_uri_obj = OssUriObj(uri=self.session.patch_oss_endpoint(code_uri))
        code_dir = CodeDir(
            location_type="oss",
            location_value=OssLocation(
                bucket=oss_uri_obj.bucket_name,
                key=oss_uri_obj.object_key,
                endpoint=oss_uri_obj.endpoint,
            ),
        )

        return code_dir
