#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2022 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import enum
import json
import logging
import os
import pathlib
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast


from basic_pitch import CT_PRESENT, ICASSP_2022_MODEL_PATH, ONNX_PRESENT, TF_PRESENT, TFLITE_PRESENT

try:
    import tensorflow as tf
except ImportError:
    pass

try:
    import coremltools as ct
except ImportError:
    pass

try:
    import tflite_runtime.interpreter as tflite
except ImportError:
    if TF_PRESENT:
        import tensorflow.lite as tflite

try:
    import onnxruntime as ort
except ImportError:
    pass

import numpy as np
import numpy.typing as npt
import librosa
import pretty_midi

from basic_pitch.constants import (
    AUDIO_SAMPLE_RATE,
    AUDIO_N_SAMPLES,
    ANNOTATIONS_FPS,
    FFT_HOP,
)
from basic_pitch.commandline_printing import (
    generating_file_message,
    no_tf_warnings,
    file_saved_confirmation,
    failed_to_save,
)
import basic_pitch.note_creation as infer


class Model:
    """Load a serialized model with whichever available runtime accepts it.

    The constructor tries the runtimes in a fixed order (TensorFlow, CoreML,
    TFLite, ONNX), skipping those whose presence flag is false, and keeps the
    first one that loads ``model_path`` without raising. ``predict`` then
    dispatches on the runtime that was selected.
    """

    class MODEL_TYPES(enum.Enum):
        """Tags identifying which runtime holds the loaded model."""

        TENSORFLOW = enum.auto()
        COREML = enum.auto()
        TFLITE = enum.auto()
        ONNX = enum.auto()

    def __init__(self, model_path: Union[pathlib.Path, str]):
        """Try each available runtime in turn until one loads ``model_path``.

        Args:
            model_path: Path to the serialized model (a directory or a file,
                depending on the format).

        Raises:
            ValueError: If no available runtime could load the model. The
                message lists the runtimes that were detected on this system.
        """
        # Names of the runtimes that were attempted; only used in the final error message.
        present = []
        if TF_PRESENT:
            present.append("TensorFlow")
            try:
                # model_type is assigned before loading; if loading fails it is
                # overwritten by the next attempt (or the constructor raises).
                self.model_type = Model.MODEL_TYPES.TENSORFLOW
                self.model = tf.saved_model.load(str(model_path))
                return
            except Exception as e:
                # Warn only when the path is a directory containing at least one of
                # "saved_model.pb" / "variables" (set intersection); otherwise fall
                # through silently to the next runtime.
                if os.path.isdir(model_path) and {"saved_model.pb", "variables"} & set(os.listdir(model_path)):
                    logging.warning(
                        "Could not load TensorFlow saved model %s even "
                        "though it looks like a saved model file with error %s. "
                        "Are you sure it's a TensorFlow saved model?",
                        model_path,
                        e.__repr__(),
                    )

        if CT_PRESENT:
            present.append("CoreML")
            try:
                self.model_type = Model.MODEL_TYPES.COREML
                self.model = ct.models.MLModel(str(model_path), compute_units=ct.ComputeUnit.CPU_ONLY)
                return
            except Exception as e:
                # Warn only when the file extension suggests this was meant to be CoreML.
                if str(model_path).endswith(".mlpackage"):
                    logging.warning(
                        "Could not load CoreML file %s even "
                        "though it looks like a CoreML file with error %s. "
                        "Are you sure it's a CoreML file?",
                        model_path,
                        e.__repr__(),
                    )

        # The module-level import falls back to tensorflow.lite when tflite_runtime
        # is missing, so TensorFlow alone is enough to attempt a TFLite load.
        if TFLITE_PRESENT or TF_PRESENT:
            present.append("TensorFlowLite")
            try:
                self.model_type = Model.MODEL_TYPES.TFLITE
                self.interpreter = tflite.Interpreter(str(model_path))
                # The signature runner is what predict() calls for TFLite models.
                self.model = self.interpreter.get_signature_runner()
                return
            except Exception as e:
                # Warn only when the file extension suggests this was meant to be TFLite.
                if str(model_path).endswith(".tflite"):
                    logging.warning(
                        "Could not load TensorFlowLite file %s even "
                        "though it looks like a TFLite file with error %s. "
                        "Are you sure it's a TFLite file?",
                        model_path,
                        e.__repr__(),
                    )

        if ONNX_PRESENT:
            present.append("ONNX")
            try:
                self.model_type = Model.MODEL_TYPES.ONNX
                # CPU is always listed; CUDA is put first when onnxruntime reports it.
                providers = ["CPUExecutionProvider"]
                if "CUDAExecutionProvider" in ort.get_available_providers():
                    providers.insert(0, "CUDAExecutionProvider")
                self.model = ort.InferenceSession(str(model_path), providers=providers)
                return
            except Exception as e:
                # Warn only when the file extension suggests this was meant to be ONNX.
                if str(model_path).endswith(".onnx"):
                    logging.warning(
                        "Could not load ONNX file %s even "
                        "though it looks like a ONNX file with error %s. "
                        "Are you sure it's a ONNX file?",
                        model_path,
                        e.__repr__(),
                    )

        # Reached only when every attempted runtime raised (or none was present).
        raise ValueError(
            f"File {model_path} cannot be loaded into either "
            "TensorFlow, CoreML, TFLite or ONNX. "
            "Please check if it is a supported and valid serialized model "
            "and that one of these packages are installed. On this system, "
            f"{present} is installed."
        )

    def predict(self, x: npt.NDArray[np.float32]) -> Dict[str, npt.NDArray[np.float32]]:
        """Run the loaded model on ``x`` using the runtime chosen in ``__init__``.

        Args:
            x: Model input. NOTE(review): ``run_inference`` passes one window of
                shape (1, AUDIO_N_SAMPLES, 1) per call — confirm other callers.

        Returns:
            A dict of model outputs. The CoreML and ONNX branches build the keys
            "note", "onset" and "contour" explicitly; the TensorFlow and TFLite
            branches return whatever keys the model itself emits (presumably the
            same three — ``run_inference`` indexes by them).
            Falls through and returns None if ``model_type`` matches no branch.
        """
        if self.model_type == Model.MODEL_TYPES.TENSORFLOW:
            # Convert every output tensor to a NumPy array.
            return {k: v.numpy() for k, v in cast(tf.keras.Model, self.model(x)).items()}
        elif self.model_type == Model.MODEL_TYPES.COREML:
            result = cast(ct.models.MLModel, self.model).predict({"input_2": x})
            # Rename the CoreML output names to the keys used by the other branches.
            return {
                "note": result["Identity_1"],
                "onset": result["Identity_2"],
                "contour": result["Identity"],
            }
        elif self.model_type == Model.MODEL_TYPES.TFLITE:
            return self.model(input_2=x)  # type: ignore
        elif self.model_type == Model.MODEL_TYPES.ONNX:
            # Outputs are requested in the order :1, :2, :0 so that zipping them
            # with ["note", "onset", "contour"] pairs each key with its tensor.
            return {
                k: v
                for k, v in zip(
                    ["note", "onset", "contour"],
                    cast(ort.InferenceSession, self.model).run(
                        [
                            "StatefulPartitionedCall:1",
                            "StatefulPartitionedCall:2",
                            "StatefulPartitionedCall:0",
                        ],
                        {"serving_default_input_2:0": x},
                    ),
                )
            }


def window_audio_file(
    audio_original: npt.NDArray[np.float32], hop_size: int
) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float]]]:
    """
    Lazily slice an audio signal into fixed-length windows.

    Windows start every ``hop_size`` samples and are AUDIO_N_SAMPLES long; a
    window that runs past the end of the signal is zero-padded on the right.

    Args:
        audio_original: The audio samples to window.
        hop_size: Number of samples between the starts of consecutive windows.

    Yields:
        (window, window_time) pairs, where ``window`` has a trailing singleton
        axis added (shape (AUDIO_N_SAMPLES, 1) for 1-D input) and
        ``window_time`` is a {'start': ..., 'end': ...} dict in seconds.

    """
    total_samples = audio_original.shape[0]
    for offset in range(0, total_samples, hop_size):
        chunk = audio_original[offset : offset + AUDIO_N_SAMPLES]
        n_missing = AUDIO_N_SAMPLES - len(chunk)
        if n_missing > 0:
            # Only the final window(s) can be short; fill the tail with zeros.
            chunk = np.pad(chunk, pad_width=[[0, n_missing]])

        start_s = float(offset) / AUDIO_SAMPLE_RATE
        end_s = start_s + (AUDIO_N_SAMPLES / AUDIO_SAMPLE_RATE)
        yield np.expand_dims(chunk, axis=-1), {"start": start_s, "end": end_s}


def get_audio_input(
    audio_path: Union[pathlib.Path, str], overlap_len: int, hop_size: int
) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float], int]]:
    """
    Load an audio file as mono at AUDIO_SAMPLE_RATE and yield it window by window.

    Half of ``overlap_len`` zeros are prepended to the signal before it is
    windowed with ``window_audio_file``.

    Args:
        audio_path: Path of the audio file to load.
        overlap_len: Overlap between consecutive windows, in samples. Must be even.
        hop_size: Number of samples between the starts of consecutive windows.

    Yields:
        (window, window_time, original_length) triples:
        window: one window with a leading batch axis added
            (shape (1, AUDIO_N_SAMPLES, 1) for mono audio).
        window_time: {'start': ..., 'end': ...} dict (times in seconds).
        original_length: number of samples in the loaded audio BEFORE padding.

    """
    assert overlap_len % 2 == 0, f"overlap_length must be even, got {overlap_len}"

    samples, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
    original_length = samples.shape[0]

    # Lead with half an overlap of silence before windowing.
    leading_zeros = np.zeros((int(overlap_len / 2),), dtype=np.float32)
    padded = np.concatenate([leading_zeros, samples])

    for window, window_time in window_audio_file(padded, hop_size):
        yield window[np.newaxis, ...], window_time, original_length


def unwrap_output(
    output: npt.NDArray[np.float32],
    audio_original_length: int,
    n_overlapping_frames: int,
) -> np.array:
    """Flatten per-window model predictions into one time-major matrix.

    Half of the overlapping frames are dropped from both ends of every window,
    the windows are laid end to end, and the result is cut to the number of
    frames that corresponds to the original audio length.

    Args:
        output: array (n_batches, n_times_short, n_freqs)
        audio_original_length: length of original audio signal (in samples)
        n_overlapping_frames: number of overlapping frames in the output

    Returns:
        array (n_times, n_freqs), or None if ``output`` is not 3-dimensional.
    """
    if len(output.shape) != 3:
        return None

    half_overlap = int(0.5 * n_overlapping_frames)
    if half_overlap > 0:
        # Discard half of the overlap at the start and at the end of each window.
        output = output[:, half_overlap:-half_overlap, :]

    n_batches, n_frames, n_freqs = output.shape
    flattened = output.reshape(n_batches * n_frames, n_freqs)

    # Number of output frames covered by the unpadded audio.
    n_valid_frames = int(np.floor(audio_original_length * (ANNOTATIONS_FPS / AUDIO_SAMPLE_RATE)))
    return flattened[:n_valid_frames, :]


def run_inference(
    audio_path: Union[pathlib.Path, str],
    model_or_model_path: Union[Model, pathlib.Path, str],
    debug_file: Optional[pathlib.Path] = None,
) -> Dict[str, np.array]:
    """Run the model on the input audio path.

    The audio is processed window by window; the per-window predictions are
    concatenated and unwrapped into one matrix per output key.

    Args:
        audio_path: The audio to run inference on.
        model_or_model_path: A loaded Model or path to a serialized model to load.
        debug_file: An optional path to output debug data to. Useful for testing/verification.

    Returns:
       A dictionary with the notes, onsets and contours from model inference.
    """
    if isinstance(model_or_model_path, Model):
        model = model_or_model_path
    else:
        model = Model(model_or_model_path)

    # overlap 30 frames
    n_overlapping_frames = 30
    overlap_len = n_overlapping_frames * FFT_HOP
    hop_size = AUDIO_N_SAMPLES - overlap_len

    output: Dict[str, Any] = {"note": [], "onset": [], "contour": []}
    for audio_windowed, _, audio_original_length in get_audio_input(audio_path, overlap_len, hop_size):
        for k, v in model.predict(audio_windowed).items():
            output[k].append(v)

    unwrapped_output = {
        k: unwrap_output(np.concatenate(output[k]), audio_original_length, n_overlapping_frames) for k in output
    }

    if debug_file:
        with open(debug_file, "w") as f:
            json.dump(
                {
                    # get_audio_input yields NumPy arrays, which have no .numpy()
                    # method; np.asarray also tolerates tensor-like inputs.
                    # Note that this is the last window produced by the loop above.
                    "audio_windowed": np.asarray(audio_windowed).tolist(),
                    "audio_original_length": audio_original_length,
                    "hop_size_samples": hop_size,
                    "overlap_length_samples": overlap_len,
                    "unwrapped_output": {k: v.tolist() for k, v in unwrapped_output.items()},
                },
                f,
            )

    return unwrapped_output


class OutputExtensions(enum.Enum):
    # Kinds of output file; each value is the file extension (without the dot)
    # that build_output_path appends to the generated file name.
    MIDI = "mid"
    MODEL_OUTPUT_NPZ = "npz"
    MIDI_SONIFICATION = "wav"
    NOTE_EVENTS = "csv"


def verify_input_path(audio_path: Union[pathlib.Path, str]) -> None:
    """Verify that an input path is valid and can be processed

    Args:
        audio_path: Path to an audio file.

    Raises:
        ValueError: If the path does not exist, or exists but is not a regular file.
    """
    # Existence is checked first: os.path.isfile() is also False for a missing
    # path, which would otherwise hide the more precise "does not exist" error.
    if not os.path.exists(audio_path):
        raise ValueError(f"🚨 {audio_path} does not exist.")

    if not os.path.isfile(audio_path):
        raise ValueError(f"🚨 {audio_path} is not a file path.")


def verify_output_dir(output_dir: Union[pathlib.Path, str]) -> None:
    """Verify that an output directory is valid and can be processed

    Args:
        output_dir: Path to an output directory.

    Raises:
        ValueError: If the path does not exist, or exists but is not a directory.
    """
    # Existence is checked first: os.path.isdir() is also False for a missing
    # path, which would otherwise hide the more precise "does not exist" error.
    if not os.path.exists(output_dir):
        raise ValueError(f"🚨 {output_dir} does not exist.")

    if not os.path.isdir(output_dir):
        raise ValueError(f"🚨 {output_dir} is not a directory.")


def build_output_path(
    audio_path: Union[pathlib.Path, str],
    output_directory: Union[pathlib.Path, str],
    output_type: OutputExtensions,
) -> pathlib.Path:
    """Derive the output file path for an audio file and refuse to overwrite.

    The file name is ``<audio file name without extension>_basic_pitch.<ext>``,
    where ``<ext>`` is the value of ``output_type``.

    Args:
        audio_path: The original file path.
        output_directory: The directory we will output to.
        output_type: The type of output file we are creating.

    Raises:
        IOError: If the generated path already exists.

    Returns:
        A new path in the output_directory with the stem audio_path and an extension
        based on output_type.
    """
    audio_path = str(audio_path)
    target_dir = (
        output_directory if isinstance(output_directory, pathlib.Path) else pathlib.Path(output_directory)
    )

    stem = os.path.splitext(os.path.basename(audio_path))[0]
    output_path = target_dir / f"{stem}_basic_pitch.{output_type.value}"

    # The progress message is printed before the existence check.
    generating_file_message(output_type.name)

    if not output_path.exists():
        return output_path

    raise IOError(
        f"  🚨 {str(output_path)} already exists and would be overwritten. Skipping output files for {audio_path}."
    )


def save_note_events(
    note_events: List[Tuple[float, float, int, float, Optional[List[int]]]],
    save_path: Union[pathlib.Path, str],
) -> None:
    """Save note events to a CSV file

    A header row is written first. Each note becomes one row; its amplitude is
    scaled to an integer velocity (``round(127 * amplitude)``) and any pitch
    bend values are appended as extra trailing columns, so rows may be longer
    than the header.

    Args:
        note_events: A list of note event tuples to save. Tuples have the format
            ("start_time_s", "end_time_s", "pitch_midi", "velocity", "list of pitch bend values")
        save_path: The location we're saving it
    """

    # newline="" lets the csv module control line endings itself; without it,
    # platforms that translate "\n" (Windows) end up with "\r\r\n" blank rows.
    with open(save_path, "w", newline="") as fhandle:
        writer = csv.writer(fhandle, delimiter=",")
        writer.writerow(["start_time_s", "end_time_s", "pitch_midi", "velocity", "pitch_bend"])
        for start_time, end_time, note_number, amplitude, pitch_bend in note_events:
            row = [start_time, end_time, note_number, int(np.round(127 * amplitude))]
            if pitch_bend:
                row.extend(pitch_bend)
            writer.writerow(row)


def predict(
    audio_path: Union[pathlib.Path, str],
    model_or_model_path: Union[Model, pathlib.Path, str] = ICASSP_2022_MODEL_PATH,
    onset_threshold: float = 0.5,
    frame_threshold: float = 0.3,
    minimum_note_length: float = 127.70,
    minimum_frequency: Optional[float] = None,
    maximum_frequency: Optional[float] = None,
    multiple_pitch_bends: bool = False,
    melodia_trick: bool = True,
    debug_file: Optional[pathlib.Path] = None,
    midi_tempo: float = 120,
) -> Tuple[
    Dict[str, np.array],
    pretty_midi.PrettyMIDI,
    List[Tuple[float, float, int, float, Optional[List[int]]]],
]:
    """Run inference on one audio file and convert the model output to notes.

    Args:
        audio_path: File path for the audio to run inference on.
        model_or_model_path: A loaded Model or path to a serialized model to load.
        onset_threshold: Minimum energy required for an onset to be considered present.
        frame_threshold: Minimum energy requirement for a frame to be considered present.
        minimum_note_length: The minimum allowed note length in milliseconds.
        minimum_frequency: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
        maximum_frequency: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
        multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
        melodia_trick: Use the melodia post-processing step.
        debug_file: An optional path to output debug data to. Useful for testing/verification.
        midi_tempo: Passed through to the note-creation step as ``midi_tempo``
            (presumably the tempo of the generated MIDI — confirm in note_creation).
    Returns:
        The model output, midi data and note events from a single prediction
    """

    with no_tf_warnings():
        print(f"Predicting MIDI for {audio_path}...")

        model_output = run_inference(audio_path, model_or_model_path, debug_file)

        # Convert the minimum note length from milliseconds to frames.
        min_note_len = int(np.round(minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP)))

        midi_data, note_events = infer.model_output_to_notes(
            model_output,
            onset_thresh=onset_threshold,
            frame_thresh=frame_threshold,
            min_note_len=min_note_len,
            min_freq=minimum_frequency,
            max_freq=maximum_frequency,
            multiple_pitch_bends=multiple_pitch_bends,
            melodia_trick=melodia_trick,
            midi_tempo=midi_tempo,
        )

    if not debug_file:
        return model_output, midi_data, note_events

    # run_inference has already written its part of the debug data; read it
    # back and rewrite the file with the note-level results merged in.
    with open(debug_file) as f:
        debug_data = json.load(f)

    with open(debug_file, "w") as f:
        estimated_notes = []
        for start_time, end_time, pitch, amplitude, pitch_bends in note_events:
            bends = [int(b) for b in pitch_bends] if pitch_bends else None
            # Cast to built-in types so the values are JSON serializable.
            estimated_notes.append((float(start_time), float(end_time), int(pitch), float(amplitude), bends))

        debug_data.update(
            {
                "min_note_length": min_note_len,
                "onset_thresh": onset_threshold,
                "frame_thresh": frame_threshold,
                "estimated_notes": estimated_notes,
            }
        )
        json.dump(debug_data, f)

    return model_output, midi_data, note_events


def predict_and_save(
    audio_path_list: Sequence[Union[pathlib.Path, str]],
    output_directory: Union[pathlib.Path, str],
    save_midi: bool,
    sonify_midi: bool,
    save_model_outputs: bool,
    save_notes: bool,
    model_or_model_path: Union[Model, str, pathlib.Path],
    onset_threshold: float = 0.5,
    frame_threshold: float = 0.3,
    minimum_note_length: float = 127.70,
    minimum_frequency: Optional[float] = None,
    maximum_frequency: Optional[float] = None,
    multiple_pitch_bends: bool = False,
    melodia_trick: bool = True,
    debug_file: Optional[pathlib.Path] = None,
    sonification_samplerate: int = 44100,
    midi_tempo: float = 120,
) -> None:
    """Make a prediction and save the results to file.

    Files are processed in order. For each file the requested outputs are
    written in the order: model outputs, MIDI, sonification, note events.
    Any exception (including the IOError raised when an output file already
    exists) propagates to the caller and stops processing of the remaining
    files and outputs.

    Args:
        audio_path_list: List of file paths for the audio to run inference on.
        output_directory: Directory to output MIDI and all other outputs derived from the model to.
        save_midi: True to save midi.
        sonify_midi: Whether or not to render audio from the MIDI and output it to a file.
        save_model_outputs: True to save contours, onsets and notes from the model prediction.
        save_notes: True to save note events.
        model_or_model_path: A loaded Model or path to a serialized model to load.
        onset_threshold: Minimum energy required for an onset to be considered present.
        frame_threshold: Minimum energy requirement for a frame to be considered present.
        minimum_note_length: The minimum allowed note length in milliseconds.
        minimum_frequency: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
        maximum_frequency: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
        multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
        melodia_trick: Use the melodia post-processing step.
        debug_file: An optional path to output debug data to. Useful for testing/verification.
        sonification_samplerate: Sample rate for rendering audio from MIDI.
        midi_tempo: Passed through to ``predict`` as ``midi_tempo``.

    Raises:
        IOError: If an output file that would be written already exists.
        Exception: Any error raised by prediction or by writing an output is
            re-raised unchanged (after a "failed to save" message for write errors).
    """
    for audio_path in audio_path_list:
        print("")
        model_output, midi_data, note_events = predict(
            pathlib.Path(audio_path),
            model_or_model_path,
            onset_threshold,
            frame_threshold,
            minimum_note_length,
            minimum_frequency,
            maximum_frequency,
            multiple_pitch_bends,
            melodia_trick,
            debug_file,
            midi_tempo,
        )

        if save_model_outputs:
            model_output_path = build_output_path(audio_path, output_directory, OutputExtensions.MODEL_OUTPUT_NPZ)
            try:
                np.savez(model_output_path, basic_pitch_model_output=model_output)
                file_saved_confirmation(OutputExtensions.MODEL_OUTPUT_NPZ.name, model_output_path)
            except Exception:
                # Report the failure to the user, then let the error propagate.
                failed_to_save(OutputExtensions.MODEL_OUTPUT_NPZ.name, model_output_path)
                raise

        if save_midi:
            midi_path = build_output_path(audio_path, output_directory, OutputExtensions.MIDI)
            try:
                midi_data.write(str(midi_path))
                file_saved_confirmation(OutputExtensions.MIDI.name, midi_path)
            except Exception:
                failed_to_save(OutputExtensions.MIDI.name, midi_path)
                raise

        if sonify_midi:
            midi_sonify_path = build_output_path(audio_path, output_directory, OutputExtensions.MIDI_SONIFICATION)
            try:
                infer.sonify_midi(midi_data, midi_sonify_path, sr=sonification_samplerate)
                file_saved_confirmation(OutputExtensions.MIDI_SONIFICATION.name, midi_sonify_path)
            except Exception:
                failed_to_save(OutputExtensions.MIDI_SONIFICATION.name, midi_sonify_path)
                raise

        if save_notes:
            note_events_path = build_output_path(audio_path, output_directory, OutputExtensions.NOTE_EVENTS)
            try:
                save_note_events(note_events, note_events_path)
                file_saved_confirmation(OutputExtensions.NOTE_EVENTS.name, note_events_path)
            except Exception:
                failed_to_save(OutputExtensions.NOTE_EVENTS.name, note_events_path)
                raise
