#  Copyright 2023 Alibaba, Inc. or its affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import six

from pai.common.yaml_utils import dump as yaml_dump
from pai.common.yaml_utils import safe_load as yaml_safe_load
from pai.pipeline.component._base import ComponentBase, UnRegisteredComponent
from pai.pipeline.types.spec import load_input_output_spec
from pai.session import get_default_session


class RegisteredComponent(ComponentBase):
    """A pipeline/component registered in the PAI pipeline service.

    A RegisteredComponent wraps the "Workflow" manifest used by the PAI pipeline
    service together with the unique ``pipeline_id`` the service assigned to it.
    Instances can be fetched from the remote service (see :meth:`get` and
    :meth:`get_by_identifier`) or built from a service response dict (see
    :meth:`deserialize`).
    """

    def __init__(self, pipeline_id, manifest=None, workspace_id=None):
        """Construct a RegisteredComponent.

        Args:
            pipeline_id (str): Unique ID of the pipeline in the PAI service.
            manifest (Union[str, dict], optional): "Workflow" definition of the
                pipeline, either as a YAML string or an already-parsed dict. When
                empty/None, the manifest of ``pipeline_id`` is fetched from the
                pipeline service of the default session.
            workspace_id (str, optional): ID of the workspace which the pipeline
                belongs to.

        Raises:
            KeyError: If the manifest lacks the ``spec`` or ``metadata`` entries
                read here.
        """
        if not manifest:
            # Fetch the schema of *this* pipeline, the same way `get` and
            # `get_by_identifier` do.
            session = get_default_session()
            schema = session.pipeline_api.get_schema(pipeline_id=pipeline_id)
            manifest = schema["Manifest"]
        if isinstance(manifest, six.string_types):
            manifest = yaml_safe_load(manifest)

        self._manifest = manifest
        self._pipeline_id = pipeline_id
        self._workspace_id = workspace_id

        inputs, outputs = load_input_output_spec(self, manifest["spec"])

        metadata = manifest["metadata"]
        self._identifier = metadata["identifier"]
        self._provider = metadata["provider"]
        self._version = metadata["version"]
        super(RegisteredComponent, self).__init__(
            inputs=inputs,
            outputs=outputs,
        )

    def __repr__(self):
        return (
            f"{type(self).__name__}:Id={self._pipeline_id},"
            f"Identifier={self.identifier},Provider={self.provider},"
            f"Version={self.version}"
        )

    def __eq__(self, other):
        """Registered components are equal when they share the same pipeline_id."""
        return isinstance(other, type(self)) and other.pipeline_id == self.pipeline_id

    def __hash__(self):
        # Defining __eq__ implicitly sets __hash__ to None, making instances
        # unhashable. Hash by the same key __eq__ compares; the type is not mixed
        # in because __eq__ also accepts instances of subclasses.
        return hash(self.pipeline_id)

    @property
    def identifier(self):
        """str: Identifier taken from the manifest's ``metadata``."""
        return self._identifier

    @property
    def provider(self):
        """str: Provider taken from the manifest's ``metadata``."""
        return self._provider

    @property
    def version(self):
        """str: Version taken from the manifest's ``metadata``."""
        return self._version

    @property
    def pipeline_id(self):
        """Unique ID of the pipeline in PAI pipeline service.

        Returns:
            str: Unique pipeline ID of the component instance.
        """
        return self._pipeline_id

    @property
    def manifest(self):
        """Pipeline manifest schema.

        Returns:
            dict: Pipeline manifest schema in dict.
        """
        return self._manifest

    @property
    def raw_manifest(self):
        """Pipeline manifest in YAML format.

        Returns:
            str: Pipeline manifest dumped to YAML.
        """
        return yaml_dump(self._manifest)

    @classmethod
    def get_by_identifier(cls, identifier, provider=None, version="v1"):
        """Get a registered component by its identifier-provider-version tuple.

        Args:
            identifier (str): Pipeline identifier.
            provider (str, optional): Provider of the pipeline; the provider of
                the default session is used when not given.
            version (str): Version of the pipeline, "v1" by default.

        Returns:
            RegisteredComponent: Instance of ``cls`` for the matched pipeline.

        Raises:
            ValueError: If no pipeline matches the tuple, or its schema cannot
                be fetched.
        """
        session = get_default_session()
        provider = provider or session.provider

        res = session.pipeline_api.get_by_identifier(
            identifier=identifier,
            provider=provider,
            version=version,
        )
        if not res:
            raise ValueError(
                f"Not found the specific pipeline/component: "
                f"identifier={identifier} provider={provider} version={version}"
            )
        pipeline_info = session.pipeline_api.get_schema(pipeline_id=res["PipelineId"])
        if not pipeline_info:
            raise ValueError(
                "Not found pipeline with specific information: "
                f"identifier={identifier}, provider={provider}, version={version}"
            )

        return cls(
            manifest=pipeline_info["Manifest"],
            pipeline_id=pipeline_info["PipelineId"],
            # Keep the workspace information, as `deserialize` does for the
            # same get_schema response.
            workspace_id=pipeline_info.get("WorkspaceId"),
        )

    @classmethod
    def list(
        cls,
        identifier=None,
        provider=None,
        version=None,
        session=None,
        page_size=10,
        page_number=1,
    ):
        """List registered pipelines/components in the PAI pipeline service.

        Only a single page of results is fetched; call again with another
        ``page_number`` to get further pages.

        Args:
            identifier (str, optional): Pipeline identifier filter.
            provider (str, optional): Pipeline provider filter; the provider of
                the session is used when not given.
            version (str, optional): Pipeline version filter.
            session (optional): PAI session; the default session is used when
                not given.
            page_size (int): Number of items per page.
            page_number (int): Page to fetch (the default of 1 suggests pages
                are 1-based — confirm against the pipeline API).

        Returns:
            The ``items`` of the page returned by ``pipeline_api.list``.
            NOTE(review): the element type is defined by the pipeline API, not
            here; confirm whether these are component objects or raw dicts.
        """
        session = session or get_default_session()
        provider = provider or session.provider

        result = session.pipeline_api.list(
            identifier=identifier,
            provider=provider,
            version=version,
            page_size=page_size,
            page_number=page_number,
        )
        return result.items

    def update(self, component):
        """Update the remote pipeline using the manifest of the given component.

        Args:
            component (Union[UnRegisteredComponent, str, dict]): New
                pipeline/component spec: an unregistered component, a manifest
                dict, or a manifest in YAML string.

        Raises:
            ValueError: If ``component`` is none of the supported types.
        """
        session = get_default_session()

        if isinstance(component, UnRegisteredComponent):
            # Reuse this component's identifier and version in the new manifest.
            manifest = component.to_manifest(
                identifier=self.identifier, version=self.version
            )
        elif isinstance(component, six.string_types):
            manifest = component
        elif isinstance(component, dict):
            manifest = yaml_dump(component)
        else:
            raise ValueError(
                "Please provide ContainerOperator, Pipeline or Manifest in string"
                " to update the registered component."
            )

        # NOTE: only the remote pipeline is updated; the manifest (and derived
        # inputs/outputs) cached on this instance is not refreshed.
        session.pipeline_api.update(self._pipeline_id, manifest)

    def delete(self):
        """Delete this registered component/pipeline."""
        get_default_session().pipeline_api.delete(self.pipeline_id)

    @classmethod
    def deserialize(cls, obj_dict):
        """Build an instance from a pipeline-service response dict.

        Args:
            obj_dict (dict): Must contain "PipelineId"; "Manifest" and
                "WorkspaceId" are optional. A missing manifest makes the
                constructor fetch it from the pipeline service.

        Returns:
            RegisteredComponent: Instance of ``cls``.
        """
        pipeline_id = obj_dict["PipelineId"]
        return cls(
            workspace_id=obj_dict.get("WorkspaceId", None),
            pipeline_id=pipeline_id,
            manifest=obj_dict.get("Manifest"),
        )

    @classmethod
    def _has_impl(cls, manifest):
        """Return whether a manifest's ``spec`` carries an implementation.

        Args:
            manifest (Union[str, dict]): Manifest dict or YAML string.

        Returns:
            bool: True if ``spec`` contains ``pipelines`` or ``container``.
        """
        if isinstance(manifest, six.string_types):
            manifest = yaml_safe_load(manifest)

        # An empty document may load as None; treat it as "no implementation"
        # instead of failing with a TypeError on the membership test.
        if not manifest or "spec" not in manifest:
            return False
        spec = manifest["spec"] or {}
        return "pipelines" in spec or "container" in spec

    @classmethod
    def get(cls, pipeline_id, session=None):
        """Get a registered component by its pipeline_id.

        Args:
            pipeline_id (str): Unique pipeline id.
            session (optional): PAI session; the default session is used when
                not given.

        Returns:
            RegisteredComponent: Instance of ``cls`` with the specific
            pipeline_id.
        """
        session = session or get_default_session()
        schema = session.pipeline_api.get_schema(pipeline_id=pipeline_id)
        return cls.deserialize(schema)

    def save(self, identifier=None, version=None):
        """Not supported: a registered component is already saved."""
        raise NotImplementedError("SaveTemplate is not savable.")

    def _submit(self, job_name, args):
        """Start a pipeline run of this registered pipeline.

        Args:
            job_name (str): Name of the run.
            args: Run arguments, forwarded unchanged to ``PipelineRun.run``.

        Returns:
            The value returned by ``PipelineRun.run`` (presumably the run id).
        """
        # Local import, as in the original code (presumably to avoid an import cycle).
        from pai.pipeline.run import PipelineRun

        session = get_default_session()
        return PipelineRun.run(
            job_name,
            args,
            no_confirm_required=True,
            pipeline_id=self._pipeline_id,
            session=session,
        )

    def io_spec_to_dict(self):
        """Return the inputs/outputs spec as a dict with "inputs" and "outputs" keys."""
        return {
            "inputs": self.inputs.to_dict(),
            "outputs": self.outputs.to_dict(),
        }
