#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import uuid
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

import numpy as np
import tensorflow as tf

# import tensorflow_addons as tfa

from basic_pitch.constants import (
    ANNOTATIONS_FPS,
    ANNOT_N_FRAMES,
    AUDIO_N_CHANNELS,
    AUDIO_N_SAMPLES,
    AUDIO_SAMPLE_RATE,
    AUDIO_WINDOW_LENGTH,
    N_FREQ_BINS_NOTES,
    N_FREQ_BINS_CONTOURS,
    Split,
)

# Default number of random fixed-length windows drawn from each track by `sample_datasets`.
# For the validation split `sample_datasets` overrides any value other than 1 with 5.
N_SAMPLES_PER_TRACK = 20


def prepare_datasets(
    datasets_base_path: str,
    training_shuffle_buffer_size: int,
    batch_size: int,
    validation_steps: int,
    datasets_to_use: List[str],
    dataset_sampling_frequency: np.ndarray,
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    """
    Return a training and a validation dataset.

    Both datasets are infinite (repeated), batched and prefetched. The validation dataset consists of
    the first `validation_steps` batches, cached to files named `validation_set_cache_<uuid>` (relative
    path) and then repeated, so every validation pass sees the same batches.

    Args:
        datasets_base_path: path to tfrecords for input data
        training_shuffle_buffer_size: size of shuffle buffer (only for training set); 0 disables the
            extra shuffle but the training set is still repeated, batched and prefetched
        batch_size: batch size for training and validation
        validation_steps: number of batches to use for validation
        datasets_to_use: the underlying datasets to use for creating training and validation sets e.g. guitarset
        dataset_sampling_frequency: distribution weighting vector corresponding to datasets determining how they
            are sampled from during training / validation dataset creation.

    Returns:
        training and validation datasets derived from the underlying tfrecord data
    """
    assert batch_size > 0
    assert validation_steps is not None and validation_steps > 0
    assert training_shuffle_buffer_size is not None

    # init both
    ds_train = sample_datasets(
        Split.train,
        datasets_base_path,
        datasets=datasets_to_use,
        dataset_sampling_frequency=dataset_sampling_frequency,
    )
    ds_validation = sample_datasets(
        Split.validation,
        datasets_base_path,
        datasets=datasets_to_use,
        dataset_sampling_frequency=dataset_sampling_frequency,
    )

    # check that the base datasets returned by sample_datasets are FINITE before we .repeat() them
    for ds in [ds_train, ds_validation]:
        assert tf.cast(tf.data.experimental.cardinality(ds), tf.int32) != tf.data.experimental.INFINITE_CARDINALITY

    # training dataset: the shuffle is optional, but repeat/batch/prefetch must always be applied,
    # otherwise a buffer size of 0 would hand the model an unbatched, finite dataset.
    if training_shuffle_buffer_size > 0:
        ds_train = ds_train.shuffle(training_shuffle_buffer_size, reshuffle_each_iteration=True)
    ds_train = ds_train.repeat().batch(batch_size).prefetch(tf.data.AUTOTUNE)

    # validation dataset: freeze the first `validation_steps` batches in a cache, then loop over them
    ds_validation = (
        ds_validation.repeat()
        .batch(batch_size)
        .take(validation_steps)
        .cache(f"validation_set_cache_{str(uuid.uuid4())}")
        .repeat()
        .prefetch(tf.data.AUTOTUNE)
    )

    return ds_train, ds_validation


def prepare_visualization_datasets(
    datasets_base_path: str,
    batch_size: int,
    validation_steps: int,
    datasets_to_use: List[str],
    dataset_sampling_frequency: np.ndarray,
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    """
    Build training and validation datasets for visualization, one window per track.

    Unlike `prepare_datasets`, each track contributes a single sampled window
    (`n_samples_per_track=1`) and the training set gets no additional shuffle here.

    Args:
        datasets_base_path: path to tfrecord datasets for input data
        batch_size: batch size for training and validation
        validation_steps: number of batches kept (and cached) for validation
        datasets_to_use: names of the underlying datasets, e.g. guitarset
        dataset_sampling_frequency: distribution weighting vector corresponding to datasets determining how they
            are sampled from during training / validation dataset creation.

    Returns:
        (training dataset, validation dataset), both repeated, batched and prefetched
    """
    assert batch_size > 0
    assert validation_steps is not None and validation_steps > 0

    def _one_window_per_track(split: Split) -> tf.data.Dataset:
        # identical sampling settings for both splits; only the split differs
        return sample_datasets(
            split,
            datasets_base_path,
            datasets=datasets_to_use,
            dataset_sampling_frequency=dataset_sampling_frequency,
            n_samples_per_track=1,
        )

    train_base = _one_window_per_track(Split.train)
    validation_base = _one_window_per_track(Split.validation)

    # the base datasets must be finite, since they are repeated below
    for base in (train_base, validation_base):
        assert tf.cast(tf.data.experimental.cardinality(base), tf.int32) != tf.data.experimental.INFINITE_CARDINALITY

    train_batches = train_base.repeat().batch(batch_size).prefetch(tf.data.AUTOTUNE)

    # keep a fixed set of `validation_steps` batches in an on-disk cache and loop over it
    cache_name = f"validation_set_cache_{str(uuid.uuid4())}"
    validation_batches = validation_base.repeat().batch(batch_size).take(validation_steps)
    validation_batches = validation_batches.cache(cache_name).repeat().prefetch(tf.data.AUTOTUNE)

    return train_batches, validation_batches


def sample_datasets(
    split: Split,
    datasets_base_path: str,
    datasets: List[str],
    dataset_sampling_frequency: np.ndarray,
    n_shuffle: int = 1000,
    n_samples_per_track: int = N_SAMPLES_PER_TRACK,
    pairs: bool = False,
) -> tf.data.Dataset:
    """Sample tfrecord data to create a dataset of transcription examples.

    For the validation split, shuffling and pairing are always disabled and any
    `n_samples_per_track` other than 1 is replaced by 5.

    Args:
        split: whether to use training or validation data
        datasets_base_path: directory storing source datasets as tfrecord files
        datasets: names of datasets to sample from e.g. guitarset
        dataset_sampling_frequency: distribution weighting vector corresponding to datasets determining how they
            are sampled from during training / validation dataset creation.
        n_shuffle: size of shuffle buffer (only used for training ds); 0 disables shuffling
        n_samples_per_track: the number of samples to take from a track
        pairs: also generate mixed pairs of samples (alternating with single samples)

    Returns:
        dataset of samples
    """
    if split == Split.validation:
        n_shuffle, pairs = 0, False
        n_samples_per_track = n_samples_per_track if n_samples_per_track == 1 else 5

    generator_fn, use_seed = transcription_file_generator(
        split,
        datasets,
        datasets_base_path,
        dataset_sampling_frequency,
    )
    single_examples = transcription_dataset(generator_fn, n_samples_per_track, use_seed)
    if n_shuffle > 0:
        single_examples = single_examples.shuffle(n_shuffle)
    sources = [single_examples]

    if pairs:
        pair_generator_fn, pair_use_seed = transcription_file_generator(
            split,
            datasets,
            datasets_base_path,
            dataset_sampling_frequency,
        )
        mixed_pairs = (
            transcription_dataset(pair_generator_fn, n_samples_per_track, pair_use_seed)
            # shuffle so that windows from different tracks end up in the same pair
            .shuffle(n_samples_per_track * 10)
            .batch(2)
            .map(combine_transcription_examples)
        )
        sources.append(mixed_pairs)

    # The selector must be repeated: a finite range would stop sampling after len(sources) elements.
    selector = tf.data.Dataset.range(len(sources)).repeat()
    return tf.data.Dataset.choose_from_datasets(sources, selector)


def transcription_file_generator(
    split: Split,
    dataset_names: List[str],
    datasets_base_path: str,
    sample_weights: np.ndarray,
) -> Tuple[Callable[[], Iterator[tf.Tensor]], bool]:
    """Build a zero-argument callable that yields tfrecord file paths for one split.

    Files are matched by ``<datasets_base_path>/<dataset_name>/splits/<split.name>/*tfrecord``.

    Args:
        split: data split to build generator from
        dataset_names: list of dataset_names to use
        datasets_base_path: directory storing source datasets as tfrecord files
        sample_weights: distribution weighting vector corresponding to datasets (in ``dataset_names`` order)
            determining how they are sampled from; only used for the training split.

    Returns:
        (generator_fn, random_seed): ``generator_fn()`` returns a fresh iterator over file paths
        (weighted random order for the training split, round-robin across datasets otherwise);
        ``random_seed`` is False for the training split and True otherwise, and is meant to be
        forwarded to ``transcription_dataset``.
    """
    # shuffle=False: by default list_files returns files in a non-deterministic shuffled order, which
    # made the validation file order (and hence the cached validation subset) change from run to run.
    # Training order is still randomized by _train_file_generator's np.random.shuffle.
    file_dict = {
        dataset_name: tf.data.Dataset.list_files(
            os.path.join(datasets_base_path, dataset_name, "splits", split.name, "*tfrecord"),
            shuffle=False,
        )
        for dataset_name in dataset_names
    }

    if split == Split.train:
        return lambda: _train_file_generator(file_dict, sample_weights), False
    return lambda: _validation_file_generator(file_dict), True


def _train_file_generator(x: Dict[str, tf.data.Dataset], weights: np.ndarray) -> Iterator[tf.Tensor]:
    """Yield training file paths by repeatedly picking a dataset at random.

    Each dataset's files are materialized and shuffled once. At every step a dataset
    name is drawn with probabilities ``weights`` (aligned with the dict's key order)
    and that dataset's last remaining file is popped and yielded. Iteration stops as
    soon as any single dataset has no files left.
    """
    remaining = {name: list(files) for name, files in x.items()}
    names = list(remaining)
    for name in names:
        np.random.shuffle(remaining[name])

    while all(remaining.values()):
        chosen = np.random.choice(names, p=weights)
        yield remaining[chosen].pop()


def _validation_file_generator(x: Dict[str, tf.data.Dataset]) -> Iterator[tf.Tensor]:
    """Yield validation file paths round-robin across datasets until every dataset is exhausted.

    Within a dataset, files are taken from the end of its list.
    """
    pending = {name: list(files) for name, files in x.items()}
    while any(pending.values()):
        # one file per dataset per pass, so the first validation examples span several
        # datasets instead of all coming from the same one; exhausted datasets are skipped
        for files in pending.values():
            if files:
                yield files.pop()


def combine_transcription_examples(
    a: tf.Tensor, target: Dict[str, tf.Tensor], w: Dict[str, tf.Tensor]
) -> Tuple[tf.Tensor, Dict[str, tf.Tensor], Dict[str, tf.Tensor]]:
    """Merge a batch of examples (stacked on axis 0) into a single mixed example.

    Args:
        a: audio data, examples stacked along axis 0
        target: target data keyed by "onset", "contour", "note"
        w: weights keyed by "onset", "contour", "note"

    Returns:
        (audio averaged over the examples,
         targets: element-wise max over the examples, so a bin active in any example stays active,
         weights: element-wise min over the examples, so a target missing in any example gets weight 0)
    """
    target_keys = ("onset", "contour", "note")
    mixed_audio = tf.math.reduce_mean(a, axis=0)
    merged_targets = {key: tf.math.reduce_max(target[key], axis=0) for key in target_keys}
    merged_weights = {key: tf.math.reduce_min(w[key], axis=0) for key in target_keys}
    return mixed_audio, merged_targets, merged_weights


def _decode_transcription_example(
    file_id: tf.Tensor,
    source: tf.Tensor,
    audio_wav: tf.Tensor,
    notes_indices: tf.Tensor,
    notes_values: tf.Tensor,
    onsets_indices: tf.Tensor,
    onsets_values: tf.Tensor,
    contours_indices: tf.Tensor,
    contours_values: tf.Tensor,
    notes_onsets_shape: tf.Tensor,
    contours_shape: tf.Tensor,
) -> Tuple[tf.Tensor, tf.Tensor, Any, tf.Tensor, tf.Tensor, tf.Tensor]:
    """Decode the WAV bytes and densify the sparse annotations of one parsed record.

    Takes the 11 fields produced by `parse_transcription_tfexample` and returns
    (file_id, source, decode_wav result (audio, sample_rate), notes, onsets, contours).
    """
    return (
        file_id,
        source,
        tf.audio.decode_wav(
            audio_wav,
            desired_channels=AUDIO_N_CHANNELS,
            desired_samples=-1,
            name=None,
        ),
        # notes and onsets share one dense shape; contours carry their own
        sparse2dense(notes_values, notes_indices, notes_onsets_shape),
        sparse2dense(onsets_values, onsets_indices, notes_onsets_shape),
        sparse2dense(contours_values, contours_indices, contours_shape),
    )


def transcription_dataset(
    file_generator: Callable[[], Iterator[str]], n_samples_per_track: int, random_seed: bool
) -> tf.data.Dataset:
    """Build a dataset of fixed-length transcription training windows from tfrecord files.

    Args:
        file_generator: zero-argument callable returning an iterator over tfrecord file paths
        n_samples_per_track: number of random windows drawn from each track
        random_seed: if True, each window's start time is drawn with an op-level seed (0, 1, ...)

    Returns:
        dataset whose elements are (audio, targets, weights):
            - audio: mono audio window of shape (n_samples, 1)
            - targets: {"onset": (ANNOT_N_FRAMES, N_FREQ_BINS_NOTES),
                        "contour": (ANNOT_N_FRAMES, N_FREQ_BINS_CONTOURS),
                        "note": (ANNOT_N_FRAMES, N_FREQ_BINS_NOTES)}
            - weights: {"onset", "contour", "note"} -> float32 0.0 / 1.0 (0.0 when that target was missing)
        Windows whose contour and note targets are all silent are filtered out, and elements that raise
        an error anywhere in the pipeline are dropped with a logged warning instead of stopping iteration.
    """
    # output_signature replaces the deprecated output_types/output_shapes pair (same scalar-string spec)
    ds = tf.data.Dataset.from_generator(file_generator, output_signature=tf.TensorSpec(shape=(), dtype=tf.string))
    ds = tf.data.TFRecordDataset(ds)
    ds = ds.map(parse_transcription_tfexample, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.filter(is_not_bad_shape)
    ds = ds.map(_decode_transcription_example, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(reduce_transcription_inputs)
    ds = ds.map(get_sample_weights, num_parallel_calls=tf.data.AUTOTUNE)
    # the metadata element `m` is dropped here: chunks only carry audio, targets and weights
    ds = ds.flat_map(
        lambda a, o, c, n, ow, cw, nw, m: get_transcription_chunks(
            a, o, c, n, ow, cw, nw, n_samples_per_track, random_seed
        )
    )
    ds = ds.filter(is_not_all_silent_annotations)  # remove examples where all annotations are zero
    ds = ds.map(to_transcription_training_input)
    ds = ds.apply(tf.data.experimental.ignore_errors(log_warning=True))  # failsafe so training doesn't stop
    return ds


def parse_transcription_tfexample(
    serialized_example: tf.train.Example,
) -> Tuple[
    tf.Tensor,
    tf.Tensor,
    tf.Tensor,
    tf.Tensor,
    tf.Tensor,
    tf.Tensor,
    tf.Tensor,
    tf.Tensor,
    tf.Tensor,
    tf.Tensor,
    tf.Tensor,
]:
    """
    Parse one serialized tf.train.Example into 11 tensors, in this order:
     - file_id, source, audio_wav (raw byte strings, returned as stored)
     - notes_indices (int64), notes_values (float32)
     - onsets_indices (int64), onsets_values (float32)
     - contours_indices (int64), contours_values (float32)
     - notes_onsets_shape (int64), contours_shape (int64)
    Every feature is stored as a scalar string; all but the first three hold a
    serialized tensor that is decoded with `tf.io.parse_tensor`.
    NB.: notes, onsets and contours are represented as sparse matrices
    (to be reconstructed using `tf.SparseTensor(...)`). They share the
    time dimension, while contours have a frequency dimension that is
    a multiple (`ANNOTATIONS_BINS_PER_SEMITONE`) of that of
    notes/onsets.
    """
    raw_fields = ("file_id", "source", "audio_wav")
    serialized_tensor_fields = (
        ("notes_indices", tf.int64),
        ("notes_values", tf.float32),
        ("onsets_indices", tf.int64),
        ("onsets_values", tf.float32),
        ("contours_indices", tf.int64),
        ("contours_values", tf.float32),
        ("notes_onsets_shape", tf.int64),
        ("contours_shape", tf.int64),
    )
    all_fields = raw_fields + tuple(name for name, _ in serialized_tensor_fields)
    features = tf.io.parse_single_example(
        serialized_example,
        {name: tf.io.FixedLenFeature((), tf.string) for name in all_fields},
    )
    raw_values = tuple(features[name] for name in raw_fields)
    decoded_values = tuple(
        tf.io.parse_tensor(features[name], out_type=dtype) for name, dtype in serialized_tensor_fields
    )
    return raw_values + decoded_values


def is_not_bad_shape(
    _file_id: tf.Tensor,
    _source: tf.Tensor,
    _audio_wav: tf.Tensor,
    _notes_indices: tf.Tensor,
    notes_values: tf.Tensor,
    _onsets_indices: tf.Tensor,
    _onsets_values: tf.Tensor,
    _contours_indices: tf.Tensor,
    _contours_values: tf.Tensor,
    notes_onsets_shape: tf.Tensor,
    _contours_shape: tf.Tensor,
) -> tf.Tensor:
    """Filter predicate: reject records whose notes_values is empty while notes_onsets_shape has 2 entries.

    NOTE(review): presumably such records cannot be densified by `sparse2dense`; confirm.
    """
    has_note_values = tf.math.not_equal(tf.shape(notes_values)[0], 0)
    shape_is_not_two_entries = tf.math.not_equal(tf.shape(notes_onsets_shape)[0], 2)
    # De Morgan form of: not (no note values AND two-entry shape)
    return tf.math.logical_or(has_note_values, shape_is_not_two_entries)


def sparse2dense(values: tf.Tensor, indices: tf.Tensor, dense_shape: tf.Tensor) -> tf.Tensor:
    """Rebuild a dense tensor from its serialized sparse (COO) components.

    Args:
        values: 1-D tensor of the stored entries.
        indices: 2-D tensor with one index row per entry of `values`. An empty tensor of any
            other rank is replaced by an empty (0, 1) index matrix.
        dense_shape: 1-D tensor giving the shape of the dense output.

    Returns:
        dense tensor of shape `dense_shape`; positions not listed in `indices` are zero.
    """
    # Empty annotations may be serialized with a non-2-D `indices`; SparseTensor needs rank-2 indices.
    # This tensor-valued Python `if` relies on AutoGraph when traced inside Dataset.map.
    # NOTE(review): the substitute has a single index column, which presumably only matches a
    # 1-element `dense_shape`; records that don't match will error (and be dropped downstream) — confirm.
    if tf.rank(indices) != 2 and tf.size(indices) == 0:
        indices = tf.zeros([0, 1], dtype=indices.dtype)
    tf.assert_rank(indices, 2)
    tf.assert_rank(values, 1)
    tf.assert_rank(dense_shape, 1)
    sp = tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
    # validate_indices=False skips the check that indices are sorted and unique
    return tf.sparse.to_dense(sp, validate_indices=False)


def reduce_transcription_inputs(
    file_id: str,
    src: str,
    wav: Tuple[tf.Tensor, int],
    notes: tf.Tensor,
    onsets: tf.Tensor,
    contour: tf.Tensor,
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, Dict[str, str]]:
    """Downmix decoded audio to mono, check its sample rate, and reorder the targets.

    Args:
        file_id: file id (string tensor)
        src: name of dataset (string tensor)
        wav: (audio, sample_rate) as returned by `tf.audio.decode_wav`; audio covers the
            whole file as (n_samples, n_channels)
        notes: matrix of note frames (n_frames x N_FREQ_BINS_NOTES), possibly size 0
        onsets: matrix of note onsets (n_frames x N_FREQ_BINS_NOTES), possibly size 0
        contour: matrix of contour frames (n_frames x N_FREQ_BINS_CONTOURS), possibly size 0

    Returns:
        (mono_audio, onsets, contour, notes, metadata). Note the targets come back in
        onsets / contour / notes order, not the notes / onsets / contour input order.
        mono_audio has shape (n_samples, 1); metadata is {"fid": file_id, "src": src}.

    Raises:
        tf.errors.InvalidArgumentError: at run time, if the decoded sample rate differs from
            AUDIO_SAMPLE_RATE (within `transcription_dataset` such elements are dropped by ignore_errors).
    """
    audio, sample_rate = wav
    # Formatting `sample_rate` into the message only captured the symbolic tensor's repr while the
    # map function was traced; assert_equal already reports the actual x and y values on failure.
    tf.debugging.assert_equal(
        sample_rate,
        AUDIO_SAMPLE_RATE,
        message=f"audio sample rate is inconsistent with AUDIO_SAMPLE_RATE ({AUDIO_SAMPLE_RATE})",
    )
    return (
        tf.math.reduce_mean(audio, axis=1, keepdims=True),  # manually mixdown to mono
        onsets,
        contour,
        notes,
        {"fid": file_id, "src": src},
    )


def _infer_time_size(onsets: tf.Tensor, contour: tf.Tensor, notes: tf.Tensor) -> tf.Tensor:
    """Return the number of time frames of the longest target.

    Some targets may be empty (0 frames); taking the maximum of the three leading
    dimensions recovers the frame count of a non-empty one.

    Returns
        int32 scalar tensor: number of time frames in the targets
    """
    frame_counts = tf.stack([tf.shape(onsets)[0], tf.shape(contour)[0], tf.shape(notes)[0]])
    return tf.cast(tf.math.reduce_max(frame_counts), dtype=tf.int32)


def _zeros_if_empty(target: tf.Tensor, time_size: tf.Tensor, n_bins: int) -> tf.Tensor:
    """Return `target`, or a float32 zero matrix of shape (time_size, n_bins) if it has 0 time frames."""
    return tf.cond(
        tf.shape(target)[0] == 0,
        lambda: tf.zeros(
            tf.stack([time_size, tf.constant(n_bins, dtype=tf.int32)], axis=0),
            dtype=tf.float32,
        ),
        lambda: target,
    )


def get_sample_weights(
    audio: tf.Tensor, onsets: np.ndarray, contour: np.ndarray, notes: np.ndarray, metadata: Dict[Any, Any]
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, Dict[Any, Any]]:
    """Give each target a 0/1 sample weight and replace empty targets with zeros.

    A target with 0 time frames is treated as missing: its weight is 0.0 and it is
    replaced by a float32 zero matrix with as many frames as the longest target.
    Non-empty targets are passed through unchanged with weight 1.0.

    Args:
        audio : audio signal (full length), passed through unchanged
        onsets : matrix of note onsets (n_frames x N_FREQ_BINS_NOTES), possibly size 0
        contour : matrix of contour frames (n_frames x N_FREQ_BINS_CONTOURS), possibly size 0
        notes : matrix of note frames (n_frames x N_FREQ_BINS_NOTES), possibly size 0
        metadata : dictionary of metadata, passed through unchanged

    Returns:
        (audio, onsets, contour, notes, onsets_weight, contour_weight, note_weight, metadata)
        where each weight is a float32 scalar, 0.0 or 1.0.
    """
    time_size = _infer_time_size(onsets, contour, notes)

    # Weights must be computed before the empty targets are filled in below.
    # TODO - if we dont want to worry about batches with no examples for a task
    # we can add a tiny constant here, but training will be unstable
    onsets_weight = tf.cast(tf.shape(onsets)[0] != 0, tf.float32)
    contour_weight = tf.cast(tf.shape(contour)[0] != 0, tf.float32)
    note_weight = tf.cast(tf.shape(notes)[0] != 0, tf.float32)

    onsets = _zeros_if_empty(onsets, time_size, N_FREQ_BINS_NOTES)
    contour = _zeros_if_empty(contour, time_size, N_FREQ_BINS_CONTOURS)
    notes = _zeros_if_empty(notes, time_size, N_FREQ_BINS_NOTES)

    return (
        audio,
        onsets,
        contour,
        notes,
        onsets_weight,
        contour_weight,
        note_weight,
        metadata,
    )


def trim_time(data: np.ndarray, start: int, duration: int, sr: int) -> tf.Tensor:
    """
    Cut a time window out of 2-D data.

    Args:
        data: 2D data as (n_time_samples, n_channels); audio or a time-frequency matrix
        start: window start, in seconds
        duration: window length, in seconds
        sr: data sample rate (samples or frames per second)

    Returns:
        sliced_data (tf.Tensor): (ceil(sr * duration), n_channels), starting at row round(sr * start)
    """
    first_row = tf.cast(tf.math.round(sr * start), dtype=tf.int32)
    window_rows = tf.cast(tf.math.ceil(tf.cast(sr * duration, tf.float32)), dtype=tf.int32)
    # slice along time only; a size of -1 keeps every column
    return tf.slice(data, begin=(first_row, 0), size=(window_rows, -1))


def extract_window(
    audio: tf.Tensor, onsets: np.ndarray, contour: np.ndarray, notes: np.ndarray, t_start: int
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
    """Cut the same time window out of the audio and its three annotation matrices.

    Args:
        audio: audio signal, sampled at AUDIO_SAMPLE_RATE
        onsets: note onsets of audio signal, at ANNOTATIONS_FPS
        contour: pitch contour of audio signal, at ANNOTATIONS_FPS
        notes: note on/off of audio signal, at ANNOTATIONS_FPS
        t_start: window start time, in seconds

    Returns:
        (audio window of AUDIO_N_SAMPLES / AUDIO_SAMPLE_RATE seconds,
         onsets, contour and notes windows of AUDIO_WINDOW_LENGTH seconds)
    """
    # the audio window length is derived from AUDIO_N_SAMPLES rather than AUDIO_WINDOW_LENGTH
    audio_seconds = tf.cast(AUDIO_N_SAMPLES / AUDIO_SAMPLE_RATE, dtype=tf.dtypes.float32)
    audio_window = trim_time(audio, t_start, audio_seconds, AUDIO_SAMPLE_RATE)
    onsets_window, contour_window, notes_window = (
        trim_time(matrix, t_start, AUDIO_WINDOW_LENGTH, ANNOTATIONS_FPS) for matrix in (onsets, contour, notes)
    )
    return (audio_window, onsets_window, contour_window, notes_window)


def extract_random_window(
    audio: tf.Tensor, onsets: np.ndarray, contour: np.ndarray, notes: np.ndarray, seed: Optional[int]
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
    """Cut a fixed-length window starting at a uniformly random time.

    The start time is drawn from [0, track_length - AUDIO_N_SAMPLES / AUDIO_SAMPLE_RATE) seconds.

    Args:
        audio : audio signal (full length)
        onsets : matrix of note onsets (n_frames x N_FREQ_BINS_NOTES)
        contour : matrix of contour frames (n_frames x N_FREQ_BINS_CONTOURS)
        notes : matrix of note frames (n_frames x N_FREQ_BINS_NOTES)
        seed : op-level seed for the random start time, or None

    Returns:
        (audio, onsets, contour, notes) windows as produced by `extract_window`:
        audio covers AUDIO_N_SAMPLES / AUDIO_SAMPLE_RATE seconds, the annotation
        matrices AUDIO_WINDOW_LENGTH * ANNOTATIONS_FPS frames.
    """
    track_seconds = tf.math.divide(
        tf.cast(tf.shape(audio)[0], dtype=tf.float32),
        tf.cast(AUDIO_SAMPLE_RATE, dtype=tf.float32),
    )
    latest_start = track_seconds - (AUDIO_N_SAMPLES / AUDIO_SAMPLE_RATE)
    start_seconds = tf.random.uniform(
        (),
        minval=0.0,
        maxval=latest_start,
        dtype=tf.dtypes.float32,
        seed=seed,
    )
    return extract_window(audio, onsets, contour, notes, start_seconds)


def get_transcription_chunks(
    audio: tf.Tensor,
    onsets: np.ndarray,
    contour: np.ndarray,
    notes: np.ndarray,
    onset_weight: int,
    contour_weight: int,
    note_weight: int,
    n_samples_per_track: int,
    seed: bool,
) -> tf.data.Dataset:
    """Sample `n_samples_per_track` random fixed-length windows from one track.

    Args:
        audio : audio signal (full length)
        onsets : matrix of note onsets (n_frames x N_FREQ_BINS_NOTES)
        contour : matrix of contour frames (n_frames x N_FREQ_BINS_CONTOURS)
        notes : matrix of note frames (n_frames x N_FREQ_BINS_NOTES)
        onset_weight, contour_weight, note_weight : per-track weights, copied onto every window
        n_samples_per_track : how many windows to yield for this track
        seed : if True, window i uses op-level seed i; otherwise the start times are unseeded

    Returns:
        dataset of n_samples_per_track elements
        (audio, onsets, contour, notes, onset_weight, contour_weight, note_weight),
        each window as produced by `extract_random_window`.
    """
    windows = [
        extract_random_window(audio, onsets, contour, notes, i if seed else None)
        for i in range(n_samples_per_track)
    ]
    # Components must stay Python lists (not tuples) so from_tensor_slices stacks them
    # into tensors instead of treating them as nested structures.
    audio_chunks = [window[0] for window in windows]
    onset_chunks = [window[1] for window in windows]
    contour_chunks = [window[2] for window in windows]
    note_chunks = [window[3] for window in windows]
    onset_weights = [onset_weight] * n_samples_per_track
    contour_weights = [contour_weight] * n_samples_per_track
    note_weights = [note_weight] * n_samples_per_track
    return tf.data.Dataset.from_tensor_slices(
        (audio_chunks, onset_chunks, contour_chunks, note_chunks, onset_weights, contour_weights, note_weights)
    )


def is_not_all_silent_annotations(
    audio: tf.Tensor,
    onsets: np.ndarray,
    contour: np.ndarray,
    notes: np.ndarray,
    onset_weight: int,
    contour_weight: int,
    note_weight: int,
) -> tf.Tensor:
    """Filter predicate: keep a window unless both its contour and its notes have a mean of exactly zero.

    Onsets are not inspected. (For non-negative annotations, a zero mean means all-silent.)
    """
    notes_nonsilent = tf.math.not_equal(tf.math.reduce_mean(notes), 0)
    contours_nonsilent = tf.math.not_equal(tf.math.reduce_mean(contour), 0)
    return tf.math.logical_or(contours_nonsilent, notes_nonsilent)


def to_transcription_training_input(
    audio: tf.Tensor,
    onsets: np.ndarray,
    contour: np.ndarray,
    notes: np.ndarray,
    onset_weight: int,
    contour_weight: int,
    note_weight: int,
) -> Tuple[tf.Tensor, Dict[str, tf.Tensor], Dict[str, int]]:
    """Pack a window into the (inputs, targets, sample_weights) triple expected by the model.

    Target shapes are pinned with `tf.ensure_shape`, which errors if a window has a different shape.
    """
    note_grid = (ANNOT_N_FRAMES, N_FREQ_BINS_NOTES)
    contour_grid = (ANNOT_N_FRAMES, N_FREQ_BINS_CONTOURS)
    targets = {
        "onset": tf.ensure_shape(onsets, note_grid),
        "contour": tf.ensure_shape(contour, contour_grid),
        "note": tf.ensure_shape(notes, note_grid),
    }
    sample_weights = {"onset": onset_weight, "contour": contour_weight, "note": note_weight}
    return audio, targets, sample_weights
