################################################################################
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import os
import pickle
import time
from importlib import import_module
from typing import List, Dict, Any

import cloudpickle
from pyflink.table import StreamTableEnvironment

from pyflink.ml.api import Stage


def save_pipeline(pipeline: Stage, stages: List[Stage], path: str) -> None:
    """
    Saves a Pipeline or PipelineModel with the given list of stages to the given path.

    :param pipeline: A Pipeline or PipelineModel instance.
    :param stages: A list of stages of the given pipeline.
    :param path: The parent directory to save the pipeline metadata and its stages.
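
    Example (a hedged sketch; ``pipeline`` and ``stages`` stand in for an existing Pipeline or
    PipelineModel instance and its list of stages):
    ::

        >>> save_pipeline(pipeline, stages, '/tmp/my_pipeline')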
    """
    # Creates parent directories if not already created.
    os.makedirs(path, exist_ok=True)

    num_stages = len(stages)
    extra_metadata = {'num_stages': num_stages}
    save_metadata(pipeline, path, extra_metadata)

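    # Each stage is written into its own zero-padded sub-directory under {path}/stages, so that
    # load_pipeline() can restore the stages in their original order.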
    for i, stage in enumerate(stages):
        stage_path = get_path_for_pipeline_stage(i, num_stages, path)
        stage.save(stage_path)


def load_pipeline(t_env: StreamTableEnvironment, path: str) -> List[Stage]:
    """
    Loads the stages of a Pipeline or PipelineModel from the given path.

    :param t_env: A StreamTableEnvironment instance.
    :param path: The parent directory from which to load the pipeline metadata and its stages.
    :return: A list of stages.
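
    Example (a hedged sketch; assumes ``t_env`` is an existing StreamTableEnvironment and
    '/tmp/my_pipeline' was previously written by save_pipeline()):
    ::

        >>> stages = load_pipeline(t_env, '/tmp/my_pipeline')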
    """
    meta_data = load_metadata(path)
    num_stages = meta_data['num_stages']
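    # Each stage was saved under {path}/stages/{zero-padded index}; load them back in order.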
    return [load_stage(t_env, get_path_for_pipeline_stage(i, num_stages, path))
            for i in range(num_stages)]


def save_metadata(stage: Stage, path: str, extra_metadata=None) -> None:
    """
    Saves the metadata of the given stage and the extra metadata to a file named `metadata` under
    the given path. The metadata of a stage includes the stage class name, parameter values etc.

    Required: the metadata file under the given path must not already exist.

    :param stage: The stage instance.
    :param path: The parent directory to save the stage metadata.
    :param extra_metadata: The extra metadata to be saved.
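
    Example (a hedged sketch; ``stage`` stands in for any Stage instance):
    ::

        >>> save_metadata(stage, '/tmp/my_stage')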
    """
    if extra_metadata is None:
        extra_metadata = {}
    os.makedirs(path, exist_ok=True)

    metadata = dict(extra_metadata)
    metadata['module_name'] = str(stage.__module__)
    metadata['class_name'] = str(type(stage).__name__)
    metadata['timestamp'] = time.time()
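    # Pickle each Param key with cloudpickle so that load_stage_param() can rebuild the Param
    # objects later; values are encoded with the Param's own json_encode().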
    metadata['param_map'] = {cloudpickle.dumps(k): k.json_encode(v)
                             for k, v in stage.get_param_map().items()}
    # TODO: add version in the metadata.

    metadata_bytes = pickle.dumps(metadata)
    metadata_path = os.path.join(path, 'metadata')
    if os.path.isfile(metadata_path):
        raise IOError(f'File {metadata_path} already exists.')
    with open(metadata_path, 'wb') as fd:
        fd.write(metadata_bytes)


def load_metadata(path: str) -> Dict[str, Any]:
    """
    Loads the metadata from the metadata file under the given path.

    :param path: The parent directory of the metadata file to read from.
    :return: A Dict from metadata name to metadata value.
    """
    metadata_path = os.path.join(path, "metadata")
    with open(metadata_path, 'rb') as fd:
        metadata_bytes = fd.read()
    meta_data = pickle.loads(metadata_bytes)
    return meta_data


def load_stage(t_env: StreamTableEnvironment, path: str) -> Stage:
    """
    Loads the stage from the given path by invoking the static load() method of the stage. The
    stage module name and class name are read from the metadata file under the given path. The
    load() method is expected to construct the stage instance with the saved parameters, model data
    and other metadata, if they exist.

    :param t_env: A StreamTableEnvironment instance.
    :param path: The parent directory of the stage metadata file.
    :return: An instance of Stage.
    """
    metadata = load_metadata(path)
    module_name = metadata.get('module_name')
    class_name = metadata.get('class_name')
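    # Dynamically import the stage's module, look up the stage class by name and delegate to its
    # static load() method.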
    stage_class = getattr(import_module(module_name), class_name)
    return stage_class.load(t_env, path)


def load_stage_param(path: str) -> Stage:
    """
    Loads the stage with the saved parameters from the given path. This method reads the metadata
    file under the given path, instantiates the stage using its no-argument constructor, and
    loads the stage with the param_map from the metadata file.

    Note: This method does not attempt to read model data from the given path. The caller needs
    to read the model data from the given path if the stage has model data.

    :param path: The parent directory of the stage metadata file.
    :return: A Stage instance.
    """
    metadata = load_metadata(path)
    module_name = metadata.get('module_name')
    class_name = metadata.get('class_name')
    stage_class = getattr(import_module(module_name), class_name)
    stage = stage_class()

    param_map = metadata.get('param_map')
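    # Rebuild each Param that was pickled with cloudpickle by save_metadata() and decode its
    # value with the Param's json_decode().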
    for k, v in param_map.items():
        param = cloudpickle.loads(k)
        value = param.json_decode(v)
        stage.set(param, value)
    return stage


def get_path_for_pipeline_stage(stage_idx: int, num_stages: int, parent_path: str) -> str:
    """
    Returns a string with value {parent_path}/stages/{stage_idx}, where the stage_idx is prefixed
    with zero or more `0` so that it has the same number of digits as num_stages. The resulting
    string can be used as the directory to save a stage of the Pipeline or PipelineModel.
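
    For example, with num_stages = 12 and stage_idx = 3 the last path component is ``03``.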
    """
    format_str = ("{:0>%sd}" % (num_stages,))
    return os.path.abspath(os.path.join(parent_path, "stages", format_str.format(stage_idx)))
