import argparse
import os
import torch
import json
from tensorrt_llm.runtime import ModelRunnerCpp
from tensorrt_llm.bindings import GptJsonConfig
import numpy as np
from collections import OrderedDict
from pathlib import Path
from whisper_utils import log_mel_spectrogram, get_tokenizer
import evaluate
from normalizer import data_utils
import time
from tqdm import tqdm
from pathlib import Path
import re
from concurrent.futures import ThreadPoolExecutor

# Word-error-rate metric from the Hugging Face `evaluate` library. It is loaded
# once at import time and reused by main() to score predictions against references.
wer_metric = evaluate.load("wer")

def read_config(component, engine_dir):
    """Load one engine component's ``config.json`` as a single flat mapping.

    Args:
        component: sub-directory of ``engine_dir`` to read (e.g. ``'encoder'``).
        engine_dir: root engine directory (str or path-like).

    Returns:
        OrderedDict of the ``pretrained_config`` entries, then the
        ``build_config`` entries. On a key clash, the ``build_config`` value wins.
    """
    config_file = Path(engine_dir).joinpath(component, 'config.json')
    with config_file.open('r') as handle:
        raw = json.load(handle)
    flattened = OrderedDict(raw['pretrained_config'])
    flattened.update(raw['build_config'])
    return flattened

class WhisperTRTLLM:
    """Whisper speech recognizer backed by TensorRT-LLM encoder/decoder engines.

    Wraps a ``ModelRunnerCpp`` built from ``engine_dir``, which must contain
    ``encoder/config.json`` and ``decoder/config.json``. Also holds the matching
    tiktoken tokenizer loaded from ``assets_dir``.
    """

    # Matches Whisper special tokens such as ``<|startoftranscript|>`` or
    # ``<|endoftext|>`` so they can be removed from decoded text.
    _SPECIAL_TOKEN_RE = re.compile(r'<\|.*?\|>')

    def __init__(self,
                 engine_dir,
                 assets_dir="assets",
                 batch_size=64):
        """Load engine configs, the tokenizer and the C++ model runner.

        Args:
            engine_dir: directory containing the ``encoder/`` and ``decoder/``
                engine sub-directories.
            assets_dir: directory holding ``multilingual.tiktoken`` or
                ``gpt2.tiktoken``.
            batch_size: ``max_batch_size`` the runner is created with. Sub-batches
                sent by :meth:`process_batch` never exceed it.

        Raises:
            AssertionError: if the required tokenizer file is missing or the
                decoder engine does not support in-flight batching.
        """
        encoder_config = read_config('encoder', engine_dir)
        decoder_config = read_config('decoder', engine_dir)
        self.n_mels = encoder_config['n_mels']
        self.num_languages = encoder_config['num_languages']
        # Vocabularies of 51865+ tokens are treated as multilingual; anything
        # smaller uses the English-only "gpt2" tokenizer and prompt.
        is_multilingual = (decoder_config['vocab_size'] >= 51865)
        if is_multilingual:
            tokenizer_name = "multilingual"
            assert (Path(assets_dir) / "multilingual.tiktoken").exists(
            ), "multilingual.tiktoken file is not existed in assets_dir"
        else:
            tokenizer_name = "gpt2"
            assert (Path(assets_dir) / "gpt2.tiktoken").exists(
            ), "gpt2.tiktoken file is not existed in assets_dir"
        # Decoder prompt: English transcription without timestamps.
        self.text_prefix = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" if is_multilingual else "<|startoftranscript|><|notimestamps|>"
        self.tokenizer = get_tokenizer(name=tokenizer_name,
                                       num_languages=self.num_languages,
                                       tokenizer_dir=assets_dir)
        self.eot_id = self.tokenizer.encode(
            "<|endoftext|>",
            allowed_special=self.tokenizer.special_tokens_set)[0]
        json_config = GptJsonConfig.parse_file(Path(engine_dir) / 'decoder' / 'config.json')
        assert json_config.model_config.supports_inflight_batching
        # Remembered so process_batch never builds a sub-batch larger than the
        # runner was configured for.
        self.max_batch_size = batch_size
        runner_kwargs = dict(engine_dir=engine_dir,
                             is_enc_dec=True,
                             max_batch_size=batch_size,
                             max_input_len=3000,
                             max_output_len=96,
                             max_beam_width=1,
                             debug_mode=False,
                             kv_cache_free_gpu_memory_fraction=0.9)
        self.model_runner_cpp = ModelRunnerCpp.from_dir(**runner_kwargs)

    @classmethod
    def _strip_special_tokens(cls, text):
        """Remove every ``<|...|>`` token from ``text``, then trim whitespace.

        Trimming must come *after* token removal. The decoded string starts
        with the prompt's special tokens, so trimming first leaves the space
        that precedes the first transcribed word.
        """
        return cls._SPECIAL_TOKEN_RE.sub('', text).strip()

    def process_single_batch(self, mel_batch, decoder_input_ids, mel_input_lengths, max_new_tokens):
        """Run one greedy ``generate`` call and decode the results to text.

        Args:
            mel_batch: encoder input features for this sub-batch.
            decoder_input_ids: prompt token ids, one row per sample.
            mel_input_lengths: per-sample feature lengths. They are halved to
                give the encoder output lengths passed to the runner.
            max_new_tokens: generation budget per sample.

        Returns:
            One cleaned transcript per sample, in input order.
        """
        outputs = self.model_runner_cpp.generate(
            batch_input_ids=decoder_input_ids,
            encoder_input_features=mel_batch,
            encoder_output_lengths=mel_input_lengths // 2,
            max_new_tokens=max_new_tokens,
            end_id=self.eot_id,
            pad_id=self.eot_id,
            num_beams=1,
            output_sequence_lengths=True,
            return_dict=True
        )

        output_ids = outputs['output_ids'].cpu().numpy().tolist()
        # Presumably indexed [sample][beam][token]; with num_beams=1 only beam 0
        # exists. Padding uses eot_id, which _strip_special_tokens removes.
        return [self._strip_special_tokens(self.tokenizer.decode(sample_ids[0]))
                for sample_ids in output_ids]

    def process_batch(self, mel, mel_input_lengths, num_threads=4, max_new_tokens=96):
        """Transcribe a batch of log-mel features, fanning sub-batches out over threads.

        Args:
            mel: either a list of per-utterance feature tensors (presumably
                ``(1, n_mels, T)`` each — TODO confirm), which are transposed,
                cast to float16 and stacked; or an already stacked tensor, which
                is only transposed on dims 1 and 2.
            mel_input_lengths: 1-D tensor of per-utterance feature lengths.
            num_threads: maximum number of concurrent ``generate`` calls.
            max_new_tokens: generation budget per utterance.

        Returns:
            List of transcripts in the same order as ``mel`` (empty for an
            empty batch).
        """
        batch_size = len(mel)
        if batch_size == 0:
            # Nothing to do. This also avoids ThreadPoolExecutor(max_workers=0)
            # and a zero split size below, both of which raise.
            return []

        prompt_id = torch.tensor(self.tokenizer.encode(
            self.text_prefix, allowed_special=self.tokenizer.special_tokens_set))
        decoder_input_ids = prompt_id.repeat(batch_size, 1)

        with torch.no_grad():
            if isinstance(mel, list):
                mel = torch.stack([m.transpose(1, 2).type(torch.float16).squeeze(0) for m in mel])
            else:
                mel = mel.transpose(1, 2)

            num_threads = max(1, min(num_threads, batch_size))
            # Ceil division yields at most `num_threads` sub-batches; floor
            # division added an extra small remainder batch. The cap keeps each
            # sub-batch within the runner's configured max_batch_size.
            split_size = min(-(-batch_size // num_threads), self.max_batch_size)
            mel_batches = torch.split(mel, split_size)
            length_batches = torch.split(mel_input_lengths, split_size)

            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                futures = [
                    executor.submit(
                        self.process_single_batch,
                        mel_batch,
                        decoder_input_ids[:mel_batch.size(0)],
                        lengths,
                        max_new_tokens,
                    )
                    for mel_batch, lengths in zip(mel_batches, length_batches)
                ]
                # Collect in submission order so output i matches input i.
                texts_list = []
                for future in futures:
                    texts_list.extend(future.result())

        return texts_list

def longest_common_substring(s1, s2):
    """Return the longest contiguous run shared by ``s1`` and ``s2``.

    Standard dynamic programme over suffix-match lengths. It keeps only the
    previous row, so memory is O(len(s2)) instead of a full table; time is
    O(len(s1) * len(s2)).

    Args:
        s1, s2: sliceable sequences (typically strings) whose elements compare
            with ``==``.

    Returns:
        A slice of ``s1``. On ties, the occurrence ending earliest in ``s1`` is
        returned. If there is no common element, the result is an empty slice.
    """
    len1, len2 = len(s1), len(s2)
    prev_row = [0] * (len2 + 1)

    longest_length = 0
    end_index_s1 = 0

    for i in range(1, len1 + 1):
        current = s1[i - 1]
        row = [0] * (len2 + 1)
        for j in range(1, len2 + 1):
            if current == s2[j - 1]:
                row[j] = prev_row[j - 1] + 1
                # Strict '>' keeps the first maximal match found (earliest end in s1).
                if row[j] > longest_length:
                    longest_length = row[j]
                    end_index_s1 = i
        prev_row = row

    return s1[end_index_s1 - longest_length:end_index_s1]

def chunk_audio(audio, chunk_length, overlap_length, sample_rate):
    """Split ``audio`` into fixed-length windows that overlap by a fixed amount.

    Args:
        audio: 1-D sliceable sequence of samples (e.g. a NumPy array).
        chunk_length: window length in seconds.
        overlap_length: overlap between consecutive windows in seconds.
        sample_rate: samples per second.

    Returns:
        List of slices of ``audio``. Every window is ``chunk_length`` long except
        possibly the last, which ends exactly at the end of ``audio``. Empty
        input yields an empty list.

    Raises:
        ValueError: if the overlap is not shorter than the chunk, since the
            window would never advance.
    """
    chunk_size = int(chunk_length * sample_rate)
    overlap_size = int(overlap_length * sample_rate)
    step = chunk_size - overlap_size
    if step <= 0:
        raise ValueError(
            f"overlap ({overlap_size} samples) must be shorter than the chunk "
            f"({chunk_size} samples)"
        )

    total = len(audio)
    chunks = []
    for start in range(0, total, step):
        end = min(start + chunk_size, total)
        chunks.append(audio[start:end])
        if end == total:
            # This window already reaches the end. Any further window would lie
            # entirely inside it and only duplicate its transcript.
            break

    return chunks

def _merge_chunk_texts(text_chunks):
    """Stitch the transcripts of consecutive, overlapping audio chunks together.

    Consecutive chunks share some audio, so the end of the text merged so far
    and the start of the next transcript usually repeat the same words. That
    shared span is located with ``longest_common_substring`` and kept only once.

    Args:
        text_chunks: transcripts in chunk order.

    Returns:
        The merged transcript ("" if every chunk is empty).
    """
    merged = ""
    for text in text_chunks:
        if not text:
            continue
        if not merged:
            merged = text
            continue
        # Only the tail of `merged` can overlap the new chunk. Bounding the
        # search stops the quadratic LCS cost from growing with total length.
        tail_start = max(0, len(merged) - len(text))
        tail = merged[tail_start:]
        overlap = longest_common_substring(tail, text)
        if overlap.strip():
            # Keep `merged` up to the end of the shared span, then continue
            # with whatever follows that span in the new chunk.
            cut = tail_start + tail.rfind(overlap) + len(overlap)
            merged = merged[:cut] + text[text.find(overlap) + len(overlap):]
        else:
            merged = f"{merged} {text}"
    return merged


def main(args):
    """Benchmark the TensorRT-LLM Whisper engine on a dataset and print WER and RTFx.

    Loads the engine from ``args.model_id``. If ``args.warmup_steps`` is set, it
    first runs untimed warm-up batches. It then transcribes the (optionally
    subsampled) dataset in batches of ``args.batch_size``, writes a manifest via
    ``data_utils.write_manifest`` and prints the word error rate and the
    real-time factor.
    """
    asr_model = WhisperTRTLLM(engine_dir=args.model_id)

    def benchmark(batch, min_new_tokens=None):
        """Transcribe one dataset batch and attach timing, predictions and references.

        ``min_new_tokens`` is accepted so it can be passed via ``fn_kwargs``,
        but it is not used.
        """
        # Load audio inputs
        max_duration, sample_rate = 30, 16000
        audios_origin = [audio["array"].astype(np.float32) for audio in batch["audio"]]
        minibatch_size = len(audios_origin)
        audios, audio_index = [], []

        # Audio longer than max_duration is cut into overlapping windows.
        # audio_index[k] records which sample window k came from.
        chunk_length = 25
        overlap_length = 5
        for i, audio in enumerate(audios_origin):
            if len(audio) > max_duration * sample_rate:
                audio_chunks = chunk_audio(audio, chunk_length, overlap_length, sample_rate)
            else:
                audio_chunks = [audio]
            audios.extend(torch.from_numpy(chunk) for chunk in audio_chunks)
            audio_index.extend([i] * len(audio_chunks))

        # START TIMING (perf_counter is monotonic and meant for durations)
        start_time = time.perf_counter()
        longest_duration = int(sample_rate * max_duration)

        # Every window is padded to max_duration seconds before feature extraction.
        features = [
            log_mel_spectrogram(wave,
                                asr_model.n_mels,
                                padding=longest_duration - wave.shape[-1],
                                device='cuda').unsqueeze(0)
            for wave in audios
        ]

        features_input_lengths = torch.tensor([f.shape[2] for f in features],
                                              dtype=torch.int32,
                                              device='cuda')

        texts_origin = asr_model.process_batch(features, features_input_lengths, num_threads=4)

        # Regroup window transcripts by source sample (single pass), then merge overlaps.
        grouped = [[] for _ in range(minibatch_size)]
        for sample_idx, text in zip(audio_index, texts_origin):
            grouped[sample_idx].append(text)
        texts = [_merge_chunk_texts(chunks) for chunks in grouped]
        # END TIMING
        runtime = time.perf_counter() - start_time

        print(f"Batch size: {minibatch_size}, Time taken: {runtime:.2f} s, texts_origin_len: {len(texts_origin)}, texts_len: {len(texts)}")
        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in texts]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = dataset.take(num_warmup_samples)
        else:
            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"min_new_tokens": args.max_new_tokens}))

        # Consume the lazy map purely for its side effect of running the model.
        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            pass

    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    for result in tqdm(iter(dataset), desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Path to the TensorRT-LLM Whisper engine directory (containing `encoder/` and `decoder/` sub-directories).",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `'librispeech_asr` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'validation`' for the dev split, or `'test'` for the test split.",
    )
    # NOTE(review): main() never reads args.device; the engine always runs on CUDA.
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    # Streaming is on by default; pass --no-streaming to download the full dataset.
    # (A set_defaults() call placed after parse_args() has no effect, so the
    # dead call that used to follow parse_args() has been dropped.)
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=None,
        help="Maximum number of tokens to generate (for auto-regressive models).",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=10,
        help="Number of warm-up steps to run before launching the timed runs.",
    )
    args = parser.parse_args()

    main(args)
