# Copyright 2020 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import io
import os
import pickle

import apache_beam as beam
from apache_beam import pvalue

import librosa
import numpy as np

from klio_audio import decorators
from klio_core.proto import klio_pb2
from klio.transforms import decorators as tfm_decorators



####
# Helper funcs for handling klio & numpy de/serialization when working
# with pcolls that are grouped by key
# TODO: it would probably be helpful to provide utilities for working
# with beam.(Co)GroupByKey
####
def _load_from_msg(item):
    """Parse a serialized KlioMessage and unpickle its data payload.

    Args:
        item (bytes): a serialized ``KlioMessage``.

    Returns:
        The object obtained by unpickling the message's ``data.payload``.
    """
    message = klio_pb2.KlioMessage()
    message.ParseFromString(item)
    pickled = message.data.payload
    # NOTE: pickle must only ever be fed payloads produced by this pipeline.
    return pickle.loads(pickled)


def _dump_to_klio_message(key, payload):
    """Build a serialized KlioMessage carrying ``payload`` in ``np.save`` form.

    Args:
        key: value assigned to the message's ``data.element``.
        payload: object written with ``np.save`` and stored as the
            message's ``data.payload``.

    Returns:
        bytes: the serialized ``KlioMessage``.
    """
    message = klio_pb2.KlioMessage()
    message.data.element = key

    # np.save needs a file-like target, so serialize into memory first.
    with io.BytesIO() as buffer:
        np.save(buffer, payload)
        message.data.payload = buffer.getvalue()

    return message.SerializeToString()


# Transforms
class GetMagnitude(beam.DoFn):
    """Separate an STFT payload into its magnitude and phase components.

    For each input, two tagged outputs are emitted: ``"phase"`` carrying
    the phase component and ``"spectrogram"`` carrying the magnitude
    component, both as returned by ``librosa.magphase``.
    """

    @tfm_decorators._handle_klio
    @decorators.handle_binary(load_with_numpy=True)
    def process(self, item):
        """Yield the phase and the magnitude spectrogram of ``item.payload``.

        Args:
            item: object exposing ``element`` (UTF-8 encoded bytes) and
                ``payload`` (the STFT handed to ``librosa.magphase``).

        Yields:
            pvalue.TaggedOutput: first tagged ``"phase"``, then tagged
            ``"spectrogram"``.
        """
        element = item.element.decode("utf-8")
        self._klio.logger.debug(
            "Computing the magnitude spectrogram for {}".format(element)
        )
        stft = item.payload
        spectrogram, phase = librosa.magphase(stft)
        # yield "phase" to show multi-yields w/ tagged outputs work, but
        # we're only concerned about the spectrogram in our integration
        # test pipeline
        yield pvalue.TaggedOutput("phase", phase)
        yield pvalue.TaggedOutput("spectrogram", spectrogram)


class FilterNearestNeighbors(beam.DoFn):
    """Apply librosa's nearest-neighbor filter to a spectrogram payload."""

    @tfm_decorators._handle_klio
    @decorators.handle_binary
    def process(self, item):
        """Yield the nn-filtered spectrogram, capped by the input values.

        Args:
            item: object exposing ``element`` (UTF-8 encoded bytes) and
                ``payload`` (the spectrogram to filter).

        Yields:
            The element-wise minimum of the input spectrogram and the
            output of ``librosa.decompose.nn_filter``.
        """
        name = item.element.decode("utf-8")
        self._klio.logger.debug(
            "Filtering nearest neighbors for {}".format(name)
        )
        source = item.payload

        # 2 seconds expressed as a frame count, using librosa's defaults.
        width_frames = int(librosa.time_to_frames(2))
        filtered = librosa.decompose.nn_filter(
            source, aggregate=np.median, metric="cosine", width=width_frames
        )

        # Assuming signals are additive, the filter output should never
        # exceed the input; the pointwise minimum with the input spectrum
        # enforces that.
        yield np.minimum(source, filtered)


# TODO: this could be useful enough to make a generic "group by klio element"
def create_key_from_element(item):
    """Pair a serialized KlioMessage with the element it carries.

    Args:
        item (bytes): a serialized ``KlioMessage``.

    Returns:
        tuple: ``(element, item)`` where ``element`` is the parsed
        message's ``data.element`` and ``item`` is the untouched input.
    """
    parsed = klio_pb2.KlioMessage()
    parsed.ParseFromString(item)
    element = parsed.data.element
    return element, item


def subtract_filter_from_full(key_pair):
    """Subtract the "nnfilter" array from the "full" array for one key.

    Args:
        key_pair (tuple): ``(element, groups)`` where ``groups`` maps
            ``"full"`` and ``"nnfilter"`` to sequences of serialized
            KlioMessages. Only the first entry of each sequence is used.

    Returns:
        tuple: ``(element, serialized_message)`` where the message's
        payload is the pickled difference ``full - nnfilter``.
    """
    element, groups = key_pair
    full_spec = _load_from_msg(groups["full"][0])
    filtered_spec = _load_from_msg(groups["nnfilter"][0])

    pickled_difference = pickle.dumps(full_spec - filtered_spec)

    message = klio_pb2.KlioMessage()
    message.data.element = element
    message.data.payload = pickled_difference
    return element, message.SerializeToString()


class GetSoftMask(beam.DoFn):
    """Compute a librosa softmask and apply it to the "full" array.

    Args:
        margin: multiplier applied to the "second" array before it is
            handed to ``librosa.util.softmask``.
        power: forwarded as ``power`` to ``librosa.util.softmask``.
    """

    def __init__(self, margin=1, power=2):
        self.margin = margin
        self.power = power

    @tfm_decorators._set_klio_context
    def process(self, item):
        """Yield a serialized KlioMessage holding ``softmask * full``.

        Args:
            item (tuple): ``(key, groups)`` where ``groups`` maps
                ``"first"``, ``"second"`` and ``"full"`` to sequences of
                serialized KlioMessages. Only the first entry of each
                sequence is used.
        """
        key, groups = item

        # Pull out all three raw messages before deserializing any of them.
        raw_messages = [groups[tag][0] for tag in ("first", "second", "full")]
        first, second, full = map(_load_from_msg, raw_messages)

        self._klio.logger.debug("Getting softmask for {}".format(key))
        scaled_second = self.margin * second
        mask = librosa.util.softmask(first, scaled_second, power=self.power)
        yield _dump_to_klio_message(key, mask * full)
