"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT-style license found in the
LICENSE file in the root directory of this source tree.
"""


from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import os
import random
from collections import namedtuple

import sox

# A speaker, identified by `id`, together with a `gender` label.
Speaker = namedtuple("Speaker", "id gender")
# One list-file entry: its file id, its length, and the Speaker it belongs to.
FileRecord = namedtuple("FileRecord", "fid length speaker")


def split_audio(line):
    """
    Cut one utterance out of a headset recording and store it as FLAC.

    ``line`` holds seven space-separated fields:
    ``apath meetid hset spk start end transcript`` (the transcript may itself
    contain spaces). The segment between ``start`` and ``end`` is read from
    ``<apath>/<meetid>/<meetid>.Headset-<last char of hset>.wav`` and written
    to ``<apath>/segments/<meetid>/<key>.flac`` as 16-bit signed-integer FLAC
    at a 16000 sample rate.

    Returns ``[meetid, key, new_path, duration_ms, transcript_lowercased]``
    (duration as a string, in milliseconds, rounded to 2 decimals) when the
    written file's duration is within 0.5 of ``end - start``; otherwise
    returns ``None``.

    Raises ``AssertionError`` if the source wav file does not exist.
    """
    fields = line.strip().split(" ", 6)
    apath, meetid, hset, spk, start, end, transcript = fields
    key = "_".join((meetid, hset, spk, start, end))

    segment_dir = os.path.join(apath, "segments", meetid)
    os.makedirs(segment_dir, exist_ok=True)

    # The headset index is the last character of the headset field.
    infile = os.path.join(apath, meetid, f"{meetid}.Headset-{hset[-1]}.wav")
    assert os.path.exists(infile), f"{infile} doesn't exist"
    new_path = os.path.join(segment_dir, key + ".flac")

    transformer = sox.Transformer()
    transformer.set_output_format(
        file_type="flac", encoding="signed-integer", bits=16, rate=16000
    )
    start_time = float(start)
    end_time = float(end)
    transformer.trim(start_time, end_time)
    transformer.build(infile, new_path)

    duration = sox.file_info.duration(new_path)
    # Written as `not (... < 0.5)` so a NaN duration is rejected too.
    if duration is None or not (abs(duration - end_time + start_time) < 0.5):
        return None
    duration_ms = str(round(duration * 1000, 2))
    return [meetid, key, new_path, duration_ms, transcript.lower()]


def do_split(all_records, spkrs, total_seconds, handles_chosen=None):
    """
    Greedily select records speaker by speaker until the time budget is met.

    Speakers are visited in round-robin order. On each visit the speaker's
    next record whose ``fid`` is not in ``handles_chosen`` is taken. A speaker
    with no records left is dropped from the rotation.

    Args:
        all_records: mapping speaker -> list of records; each record exposes
            ``fid`` and ``length`` attributes.
        spkrs: sequence of speakers (keys of ``all_records``) to draw from.
            It is copied, not modified.
        total_seconds: target amount of audio, in the unit of ``length``.
        handles_chosen: optional container of already-used ``fid`` values
            that must be skipped.

    Returns:
        ``(records_filtered, time_taken)``: the selected records in selection
        order and the sum of their lengths.

    Selection stops as soon as the accumulated time is within 10 of the
    target or beyond it, or when every speaker is exhausted.
    """
    time_taken = 0.0
    records_filtered = []
    speakers = list(spkrs)
    next_record_idx = {spk: 0 for spk in speakers}
    idx = 0

    while speakers:
        # NOTE: the counter is not adjusted when a speaker is removed, so the
        # rotation may jump after a removal; kept as-is to preserve ordering.
        speaker = speakers[idx % len(speakers)]
        idx += 1

        # Advance to this speaker's next record that was not chosen before.
        speaker_records = all_records[speaker]
        cur_record = None
        while next_record_idx[speaker] < len(speaker_records):
            candidate = speaker_records[next_record_idx[speaker]]
            next_record_idx[speaker] += 1
            if handles_chosen is None or candidate.fid not in handles_chosen:
                cur_record = candidate
                break

        if cur_record is None:
            speakers.remove(speaker)
            continue

        records_filtered.append(cur_record)
        time_taken += cur_record.length

        # One-sided check: the previous `abs(time_taken - total_seconds) < 10`
        # was never true if a single record jumped over the 20-wide window,
        # which made the loop consume every remaining record.
        if time_taken > total_seconds - 10:
            break

    return records_filtered, time_taken


def get_speakers(train_file):
    """
    Collect the distinct speakers listed in ``train_file``.

    The speaker id is the third ``_``-separated token of the first
    whitespace-separated column of each line, and its first character is
    used as the gender. Lines whose gender is neither ``"M"`` nor ``"F"`` are
    ignored. Speakers are returned as ``Speaker`` tuples in order of first
    appearance.
    """
    seen_ids = set()
    found = []
    with open(train_file) as f:
        for line in f:
            fid = line.split()[0]
            speaker_id = fid.split("_")[2]
            gender = speaker_id[0]
            if gender not in ("M", "F") or speaker_id in seen_ids:
                continue
            seen_ids.add(speaker_id)
            found.append(Speaker(id=speaker_id, gender=gender))
    return found


def get_fid2length(train_file):
    """
    Read ``(fid, length)`` pairs from ``train_file``.

    For every line, ``fid`` is the first whitespace-separated column and
    ``length`` is the third column converted to float and divided by 1000.
    Pairs are returned as a list in file order.
    """
    with open(train_file) as f:
        columns_per_line = (line.split() for line in f)
        return [(cols[0], float(cols[2]) / 1000) for cols in columns_per_line]


def full_records(speakers, fid2length, subset_name=None):
    """
    Build one ``FileRecord`` per ``(fid, length)`` pair.

    Args:
        speakers: iterable of speaker objects exposing an ``id`` attribute.
        fid2length: iterable of ``(fid, length)`` pairs; the speaker id is
            the third ``_``-separated token of ``fid``.
        subset_name: if given, every matched speaker's ``subset`` attribute
            must equal it.

    Returns:
        List of ``FileRecord`` in the order of ``fid2length``.

    Raises:
        AssertionError: if a fid refers to a speaker not in ``speakers``, or
            if the ``subset_name`` check fails.
    """
    # Must be a dict keyed by speaker id. The previous set of (id, speaker)
    # tuples made the membership test below fail for every record and could
    # not be indexed by id.
    speakers_by_id = {speaker.id: speaker for speaker in speakers}

    all_records = []
    for fid, length in fid2length:
        speaker_id = fid.split("_")[2]
        assert speaker_id in speakers_by_id, f"Unknown speaker! {speaker_id}"
        speaker = speakers_by_id[speaker_id]

        if subset_name is not None:
            # NOTE(review): the Speaker namedtuple in this file has only `id`
            # and `gender`; passing subset_name with those objects raises
            # AttributeError. Confirm whether this check is still needed.
            assert subset_name == speaker.subset
        all_records.append(FileRecord(speaker=speaker, length=length, fid=fid))
    return all_records


def get_speaker2time(records, lambda_key, lambda_value):
    """
    Sum ``lambda_value(record)`` over ``records``, grouped by
    ``lambda_key(record)``.

    Returns a ``defaultdict(int)`` mapping each key to its accumulated value,
    so looking up an unseen key yields ``0``.
    """
    from collections import defaultdict

    totals = defaultdict(int)
    for rec in records:
        bucket, amount = lambda_key(rec), lambda_value(rec)
        totals[bucket] += amount
    return totals


def create_limited_sup(list_dir):
    """
    Write limited-supervision subsets of ``<list_dir>/train.lst``.

    For each gender ("M", then "F"), up to 15 speakers with at least 15
    minutes of audio are picked at random (seeded with 0). From them:

    * six "10min" splits are drawn, each from 3 sampled speakers with a
      budget of 5 minutes per gender, never reusing a record already placed
      in an earlier 10min split;
    * one "9hr" split is drawn from all picked speakers with a budget of
      4.5 hours per gender, skipping records used by the 10min splits.

    Output files, written into ``list_dir`` with one ``train.lst`` line per
    record: ``train_10min_0.lst`` ... ``train_10min_5.lst`` and
    ``train_9hr.lst``.

    Raises:
        AssertionError: if ``train.lst`` is missing.
    """
    random.seed(0)
    train_file = os.path.join(list_dir, "train.lst")
    assert os.path.exists(train_file)

    speakers = get_speakers(train_file)
    print("Found speakers", len(speakers))

    write_records = {}
    # fids already assigned to a 10min split (used as a membership container).
    chosen_records = {}

    fid2length = get_fid2length(train_file)
    all_records = full_records(speakers, fid2length)

    for gender in ["M", "F"]:
        print(f"Selecting from gender {gender}")
        records = [rec for rec in all_records if rec.speaker.gender == gender]

        speaker2time = get_speaker2time(
            records, lambda_key=lambda r: r.speaker.id, lambda_value=lambda r: r.length
        )

        # Select up to 15 random speakers having enough audio. Sorting before
        # shuffling keeps the result reproducible under the fixed seed.
        min_minutes_per_speaker = 15
        speakers_10hr = sorted(
            spk
            for spk, seconds in speaker2time.items()
            if seconds >= min_minutes_per_speaker * 60
        )
        random.shuffle(speakers_10hr)
        speakers_10hr = speakers_10hr[:15]

        print(f"Selected speakers from gender {gender} ", speakers_10hr)

        # Group this gender's records by selected speaker in one pass, then
        # shuffle each group in speaker order.
        cur_records = {speaker: [] for speaker in speakers_10hr}
        for rec in records:
            if rec.speaker.id in cur_records:
                cur_records[rec.speaker.id].append(rec)
        for speaker in speakers_10hr:
            random.shuffle(cur_records[speaker])

        # 1 hr as 6 x 10min splits
        key = "10min_" + gender
        write_records[key] = {}
        for i in range(6):
            # Sample from the list itself: random.sample() rejects sets on
            # Python 3.11+, and set ordering would defeat the fixed seed.
            speakers_10min = random.sample(speakers_10hr, 3)
            write_records[key][i], _ = do_split(
                cur_records, speakers_10min, 10 * 60 / 2, chosen_records
            )
            for kk in write_records[key][i]:
                chosen_records[kk.fid] = 1

        # 9 hr
        key = "9hr_" + gender
        write_records[key], _ = do_split(
            cur_records, speakers_10hr, (9 * 60 * 60) / 2, chosen_records
        )

    # Map each fid to its original (stripped) list-file line.
    train_lines = {}
    with open(train_file) as f:
        for line in f:
            train_lines[line.split()[0]] = line.strip()

    print("Writing 6 x 10min list files...")
    for i in range(6):
        with open(os.path.join(list_dir, f"train_10min_{i}.lst"), "w") as fo:
            for record in write_records["10min_M"][i] + write_records["10min_F"][i]:
                # Lines were stripped above, so the newline must be restored.
                fo.write(train_lines[record.fid] + "\n")

    print("Writing 9hr list file...")
    with open(os.path.join(list_dir, "train_9hr.lst"), "w") as fo:
        for record in write_records["9hr_M"] + write_records["9hr_F"]:
            fo.write(train_lines[record.fid] + "\n")
