#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import logging
from datetime import datetime, timezone
from typing import List

import numpy as np
import tensorflow as tf

from basic_pitch import models
from basic_pitch.callbacks import VisualizeCallback
from basic_pitch.constants import DATASET_SAMPLING_FREQUENCY
from basic_pitch.data import tf_example_deserialization

# Configure the root logger at INFO level when this module is imported, so the
# logging.info(...) parameter/shape reports emitted by main() are visible.
logging.basicConfig(level=logging.INFO)


def _fill_batch_dim(shape: tuple, batch_size: int) -> List:
    """Return ``shape`` as a new list whose unknown (``None``) leading dimension is set to ``batch_size``.

    The input object is never modified.
    """
    filled = list(shape)
    if filled[0] is None:
        filled[0] = batch_size
    return filled


def main(
    source: str,
    output: str,
    batch_size: int,
    shuffle_size: int,
    learning_rate: float,
    epochs: int,
    steps_per_epoch: int,
    validation_steps: int,
    size_evaluation_callback_datasets: int,
    datasets_to_use: List[str],
    dataset_sampling_frequency: np.ndarray,
    no_sonify: bool,
    no_contours: bool,
    weighted_onset_loss: bool,
    positive_onset_weight: float,
) -> None:
    """Log the configuration, build the model and datasets, and train the model.

    Args:
        source: source directory for data
        output: output directory for trained model / checkpoints / tensorboard. Each run writes
            into a ``<output>/<UTC timestamp>/`` subdirectory.
        batch_size: batch size for data.
        shuffle_size: size of shuffle buffer (only for training set) for the data shuffling mechanism
        learning_rate: learning_rate for the Adam optimizer
        epochs: number of epochs to train for
        steps_per_epoch: the number of batches to process per epoch during training
        validation_steps: the number of validation batches to evaluate on per epoch
        size_evaluation_callback_datasets: approximate number of examples used by the visualization
            callback; they are batched in groups of at most 4.
        datasets_to_use: which datasets to train / evaluate on e.g. guitarset, medleydb_pitch, slakh
        dataset_sampling_frequency: distribution weighting vector corresponding to datasets determining how they
            are sampled from during training / validation dataset creation.
        no_sonify: if True, exclude sonifications from tensorboard.
        no_contours: if True, build the model without contour supervision.
        weighted_onset_loss: whether or not to use a weighted cross entropy loss.
        positive_onset_weight: weighting factor for the positive labels.
    """
    logging.info(f"source directory: {source}")
    logging.info(f"output directory: {output}")
    logging.info(f"tensorflow version: {tf.__version__}")
    logging.info("parameters to train.main() function:")
    logging.info(f"batch_size: {batch_size}")
    logging.info(f"shuffle_size: {shuffle_size}")
    logging.info(f"learning_rate: {learning_rate}")
    logging.info(f"epochs: {epochs}")
    logging.info(f"steps_per_epoch: {steps_per_epoch}")
    logging.info(f"validation_steps: {validation_steps}")
    logging.info(f"size_evaluation_callback_datasets: {size_evaluation_callback_datasets}")
    logging.info(f"using datasets: {datasets_to_use} with frequencies {dataset_sampling_frequency}")
    logging.info(f"no_sonify: {no_sonify}")
    logging.info(f"no_contours: {no_contours}")
    logging.info(f"weighted_onset_loss: {weighted_onset_loss}")
    logging.info(f"positive_onset_weight: {positive_onset_weight}")

    # model
    model = models.model(no_contours=no_contours)
    input_shape = _fill_batch_dim(model.input_shape, batch_size)
    logging.info("input_shape" + str(input_shape))

    # Build a fresh dict instead of mutating the object returned by model.output_shape in place.
    output_shape = {name: _fill_batch_dim(shape, batch_size) for name, shape in model.output_shape.items()}
    logging.info("output_shape" + str(output_shape))

    # data loaders
    train_ds, validation_ds = tf_example_deserialization.prepare_datasets(
        source,
        shuffle_size,
        batch_size,
        validation_steps,
        datasets_to_use,
        dataset_sampling_frequency,
    )

    # The visualization datasets are batched in groups of at most this many examples.
    MAX_EVAL_CBF_BATCH_SIZE = 4
    (
        train_visualization_ds,
        validation_visualization_ds,
    ) = tf_example_deserialization.prepare_visualization_datasets(
        source,
        batch_size=min(size_evaluation_callback_datasets, MAX_EVAL_CBF_BATCH_SIZE),
        validation_steps=max(1, size_evaluation_callback_datasets // MAX_EVAL_CBF_BATCH_SIZE),
        datasets_to_use=datasets_to_use,
        dataset_sampling_frequency=dataset_sampling_frequency,
    )

    # All artifacts of this run go under <output>/<UTC timestamp>/.
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M")
    tensorboard_log_dir = os.path.join(output, timestamp, "tensorboard")
    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=tensorboard_log_dir, histogram_freq=1),
        tf.keras.callbacks.EarlyStopping(patience=25, verbose=2),
        tf.keras.callbacks.ReduceLROnPlateau(verbose=1, patience=10, factor=0.5),
        tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(output, timestamp, "model.best"), save_best_only=True),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(output, timestamp, "checkpoints", "model.{epoch:02d}")
        ),
        VisualizeCallback(
            train_visualization_ds,
            validation_visualization_ds,
            tensorboard_log_dir,
            not no_sonify,
            not no_contours,
        ),
    ]

    # NOTE(review): an earlier variant of this script selected models.loss_no_contour when
    # no_contours was set; the same loss is currently used in both cases — confirm intended.
    loss = models.loss(weighted=weighted_onset_loss, positive_weight=positive_onset_weight)

    # train
    model.compile(
        loss=loss,
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        sample_weight_mode={"contour": None, "note": None, "onset": None},
    )

    logging.info("--- Model Training specs ---")
    logging.info(f"  train_ds: {train_ds}")
    logging.info(f"  validation_ds: {validation_ds}")
    model.summary()

    model.fit(
        train_ds,
        epochs=epochs,
        callbacks=callbacks,
        steps_per_epoch=steps_per_epoch,
        validation_data=validation_ds,
        validation_steps=validation_steps,
    )


def console_entry_point() -> None:
    """Parse command-line arguments and run training (entry point of the pip-installed script).

    Exits via ``parser.error`` (status 2) when no dataset flag is given or when
    ``--steps-per-epoch`` / ``--validation-steps`` is not positive.
    """
    parser = argparse.ArgumentParser(description="Train the basic-pitch model.")
    parser.add_argument("--source", required=True, help="Path to directory containing train/validation splits.")
    parser.add_argument("--output", required=True, help="Directory to save the model in.")
    parser.add_argument("-e", "--epochs", type=int, default=500, help="Number of training epochs.")
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=16,
        help="batch size of training. Unlike Estimator API, this specifies the batch size per-GPU.",
    )
    parser.add_argument(
        "-l",
        "--learning-rate",
        type=float,
        default=0.001,
        help="ADAM optimizer learning rate",
    )
    parser.add_argument(
        "-s",
        "--steps-per-epoch",
        type=int,
        default=100,
        help="steps_per_epoch (batch) of each training loop",
    )
    parser.add_argument(
        "-v",
        "--validation-steps",
        type=int,
        default=10,
        help="validation steps (number of BATCHES) for each validation run. MUST be a positive integer",
    )
    parser.add_argument(
        "-z",
        "--training-shuffle-size",
        type=int,
        default=100,
        help="training dataset shuffle size",
    )
    parser.add_argument(
        "--size-evaluation-callback-datasets",
        type=int,
        default=4,
        help="number of elements in the dataset used by the evaluation callback function",
    )
    # One opt-in boolean flag per known dataset, e.g. --guitarset.
    for dataset in DATASET_SAMPLING_FREQUENCY.keys():
        parser.add_argument(
            f"--{dataset.lower()}",
            action="store_true",
            default=False,
            help=f"Use {dataset} dataset in training",
        )
    parser.add_argument(
        "--no-sonify",
        action="store_true",
        default=False,
        help="if given, exclude sonifications from the tensorboard / data visualization",
    )
    parser.add_argument(
        "--no-contours",
        action="store_true",
        default=False,
        help="if given, trains without supervising the contour layer",
    )
    parser.add_argument(
        "--weighted-onset-loss",
        action="store_true",
        default=False,
        help="if given, trains onsets with a class-weighted loss",
    )
    parser.add_argument(
        "--positive-onset-weight",
        type=float,
        default=0.5,
        help="Positive class onset weight. Only applies when weighted onset loss is true.",
    )

    args = parser.parse_args()

    # argparse stores "--some-flag" as args.some_flag, so normalize dataset names the same way.
    selected = [
        (dataset, frequency)
        for dataset, frequency in DATASET_SAMPLING_FREQUENCY.items()
        if getattr(args, dataset.lower().replace("-", "_"))
    ]
    # Without at least one dataset the normalization below would divide by zero.
    if not selected:
        flags = ", ".join(f"--{dataset.lower()}" for dataset in DATASET_SAMPLING_FREQUENCY.keys())
        parser.error(f"select at least one dataset to train on: {flags}")
    # Validate with parser.error rather than assert, which is stripped under `python -O`.
    if args.steps_per_epoch <= 0:
        parser.error("--steps-per-epoch must be a positive integer")
    if args.validation_steps <= 0:
        parser.error("--validation-steps must be a positive integer")

    datasets_to_use = [dataset.lower() for dataset, _ in selected]
    dataset_sampling_frequency = np.array([frequency for _, frequency in selected])
    # Normalize the selected datasets' weights into a probability distribution.
    dataset_sampling_frequency = dataset_sampling_frequency / np.sum(dataset_sampling_frequency)

    # Keyword arguments prevent mis-ordering (batch_size and shuffle_size were previously swapped).
    main(
        source=args.source,
        output=args.output,
        batch_size=args.batch_size,
        shuffle_size=args.training_shuffle_size,
        learning_rate=args.learning_rate,
        epochs=args.epochs,
        steps_per_epoch=args.steps_per_epoch,
        validation_steps=args.validation_steps,
        size_evaluation_callback_datasets=args.size_evaluation_callback_datasets,
        datasets_to_use=datasets_to_use,
        dataset_sampling_frequency=dataset_sampling_frequency,
        no_sonify=args.no_sonify,
        no_contours=args.no_contours,
        weighted_onset_loss=args.weighted_onset_loss,
        positive_onset_weight=args.positive_onset_weight,
    )


# When this file is executed directly (not imported), behave like the pip-installed console script.
if __name__ == "__main__":
    console_entry_point()
