"""
Nanotron training script.

Usage:
```
export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
torchrun --nproc_per_node=8 run_train.py --config-file examples/config_tiny_llama.yaml
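# optionally add --sanity-check-dataloader-interval N to periodically print dataloader samples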
```
"""
import argparse
import time
from pprint import pformat
from typing import Dict, Optional, cast

import nanotron.distributed as dist
from nanotron import logging
from nanotron.config import (
    DataArgs,
    DatasetStageArgs,
    NanosetDatasetsArgs,
    PretrainDatasetsArgs,
    Qwen2Config,
    SFTDatasetsArgs,
)
from nanotron.data.dataloader import (
    dummy_infinite_data_generator,
    get_train_dataloader,
)
from nanotron.data.processing import (
    clm_process,
    get_datasets,
)
from nanotron.data.sft_processing import prepare_sft_dataset
from nanotron.helpers import (
    compute_remain_train_steps_of_a_data_stage_from_ckp,
    get_consumed_train_samples_of_a_data_stage_from_ckp,
)
from nanotron.logging import log_rank
from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
from nanotron.sanity_checks import sanity_check_dataloader
from nanotron.trainer import DataStageMetadata, DistributedTrainer
from nanotron.utils import main_rank_first
from torch.utils.data import DataLoader
try:
    from huggingface_hub import __version__ as hf_hub_version
    from transformers import AutoTokenizer
    from transformers import __version__ as tf_version
except ImportError:
    hf_hub_version = None
    tf_version = None

logger = logging.get_logger(__name__)

# import lovely_tensors as lt

# lt.monkey_patch()


def get_dataloader_from_data_stage(
    trainer: DistributedTrainer,
    data: DataArgs,
    consumed_train_samples_stage: int,
    consumed_tokens_per_dataset_folder: Dict[str, int],
    last_stages_consumed_tokens_per_dataset_folder: Dict[str, int],
    num_remaining_train_steps: int,
    sanity_check_dataloader_interval: Optional[int] = None,
):
    """
    Returns a dataloader for a given data stage.

    data: The data configuration for the current stage.
    consumed_train_samples_stage: The number of samples consumed by the model in this stage (each stage starts from zero).
    consumed_tokens_per_dataset_folder: The number of tokens already consumed per dataset folder in this stage.
    num_remaining_train_steps: The number of remaining training steps for this stage.
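    last_stages_consumed_tokens_per_dataset_folder: The number of tokens consumed per dataset folder in previous stages, used to offset the datasets so those tokens are not re-seen even though the sampler has restarted for this stage.
    sanity_check_dataloader_interval: If set, periodically print dataloader samples for inspection.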
    """
    assert consumed_train_samples_stage >= 0, "consumed_train_samples_stage should be non-negative"
    assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be non-negative"

    # First, we need to know which ranks to feed the dataloader to
    input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)

    # Case 1: Dummy data generator
    if data.dataset is None:
        log_rank("Using dummy data generator", logger=logger, level=logging.INFO, rank=0)
        dataloader = dummy_infinite_data_generator(
            micro_batch_size=trainer.micro_batch_size,
            sequence_length=trainer.sequence_length,
            input_pp_rank=input_pp_rank,
            output_pp_rank=output_pp_rank,
            vocab_size=trainer.model_config.vocab_size,
            seed=data.seed,
            parallel_context=trainer.parallel_context,
            use_position_ids=isinstance(
                trainer.model_config, Qwen2Config
            ),  # Simulate packed sequences to test SFT or inference
            cp_pg=trainer.parallel_context.cp_pg,
        )()

    # Case 2: HuggingFace datasets
    elif isinstance(data.dataset, (PretrainDatasetsArgs, SFTDatasetsArgs)):
        log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0)
        tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
        log_rank(
            f"Loading tokenizer from {tokenizer_path} and transformers/hf_hub versions {tf_version, hf_hub_version}",
            logger=logger,
            level=logging.INFO,
            rank=0,
        )

        # We need the 1st device to process the dataset and cache it, then the other devices load it from the cache
        with main_rank_first(trainer.parallel_context.world_pg):
            # TODO @nouamanetazi: this may timeout before the 1st device finishes processing the dataset. Can we have a ctxmanager to modify the timeout?
            # TODO: generalise to handle validation/test splits

            # We load the raw dataset
            raw_dataset = get_datasets(
                hf_dataset_or_datasets=data.dataset.hf_dataset_or_datasets,
                hf_dataset_config_name=data.dataset.hf_dataset_config_name,
                splits=data.dataset.hf_dataset_splits,
            )["train"]

            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            tokenizer.padding_side = "left"
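            # Special tokens that can mark sequence boundaries when generating position ids; entries may be None if the tokenizer does not define them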
            sequence_sep_tokens = [tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token, tokenizer.unk_token]
            # assert bos or eos are present
            assert (
                tokenizer.bos_token is not None or tokenizer.eos_token is not None
            ), f"Tokenizer must have either bos or eos token, but found none for {tokenizer_path}"

            # Check that the tokenizer's vocab size does not exceed the model's vocab size
            assert (
                tokenizer.vocab_size <= trainer.model_config.vocab_size
            ), f"Tokenizer's vocab size ({tokenizer.vocab_size}) is larger than the model's vocab size ({trainer.model_config.vocab_size})"

            # Different processing for SFT vs pretraining
            if isinstance(data.dataset, SFTDatasetsArgs):
                # For SFT, use the dedicated prepare_sft_dataset function
                # Get optional debug parameter to limit dataset size (for faster development)
                debug_max_samples = getattr(data.dataset, "debug_max_samples", None)

                # Process the dataset using our dedicated SFT processing module
                train_dataset = prepare_sft_dataset(
                    raw_dataset=raw_dataset,
                    tokenizer=tokenizer,
                    trainer_sequence_length=trainer.sequence_length,
                    debug_max_samples=debug_max_samples,
                    num_proc=data.dataset.dataset_processing_num_proc_per_process,
                )
            else:
                # For pretraining, use existing CLM processing
                train_dataset = clm_process(
                    raw_dataset=raw_dataset,
                    tokenizer=tokenizer,
                    text_column_name=data.dataset.text_column_name,
                    dataset_processing_num_proc_per_process=data.dataset.dataset_processing_num_proc_per_process,
                    dataset_overwrite_cache=data.dataset.dataset_overwrite_cache,
                    sequence_length=trainer.sequence_length,
                )

            # We load the processed dataset on the ranks requiring it
            dataloader = get_train_dataloader(
                train_dataset=train_dataset,
                sequence_length=trainer.sequence_length,
                parallel_context=trainer.parallel_context,
                input_pp_rank=input_pp_rank,
                output_pp_rank=output_pp_rank,
                micro_batch_size=trainer.micro_batch_size,
                consumed_train_samples_stage=consumed_train_samples_stage,
                dataloader_num_workers=data.num_loading_workers,
                seed_worker=data.seed,
                dataloader_drop_last=True,
                use_position_ids=isinstance(trainer.model_config, Qwen2Config),
                sequence_sep_tokens=sequence_sep_tokens,  # Used to generate position ids
            )

            # Check if we have enough samples for train_steps
            total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length
            num_tokens_needed_for_training = (
                num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length
            )
            assert num_tokens_needed_for_training <= total_tokens_dataset, (
                f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
                f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
            )

    # Case 3: Nanosets
    elif isinstance(data.dataset, NanosetDatasetsArgs):
        log_rank("Using TokenizedBytes Dataloader", logger=logger, level=logging.INFO, rank=0)
        from nanotron.data.tokenized_bytes import get_tb_dataloader, get_tb_datasets

        tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        assert (
            len(tokenizer) == trainer.model_config.vocab_size
        ), f"Tokenizer vocab size ({len(tokenizer)}) does not match model config vocab size ({trainer.model_config.vocab_size}). "
        log_rank(
            f"[TokenizedBytes] Creating TokenizedBytes with {len(data.dataset.dataset_folder)} dataset folders and {trainer.config.tokens.train_steps * trainer.global_batch_size} train samples",
            logger=logger,
            level=logging.INFO,
            rank=0,
        )
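        # Fall back to the EOS token for padding if the tokenizer does not define a pad token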
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        start_time = time.time()
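        # Build the TokenizedBytes datasets, skipping samples and tokens already consumed in this and previous stages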
        train_dataset, data_log = get_tb_datasets(
            config=data.dataset,
            global_batch_size=trainer.global_batch_size,
            sequence_length=trainer.sequence_length,
            train_steps=trainer.config.tokens.train_steps,
            current_iteration=trainer.iteration_step,
            parallel_context=trainer.parallel_context,
            shuffle=data.dataset.shuffle_files,
            eos_token_id=tokenizer.eos_token_id,
            seed=data.seed,
            consumed_samples=consumed_train_samples_stage,
            consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder,
            last_stages_consumed_tokens_per_dataset_folder=last_stages_consumed_tokens_per_dataset_folder,
        )
        dataloader = get_tb_dataloader(
            dataset=train_dataset,
            sequence_length=trainer.sequence_length,
            micro_batch_size=trainer.micro_batch_size,
            global_batch_size=trainer.global_batch_size,
            num_workers=data.num_loading_workers,
            cfg=data.dataset,
            consumed_samples=consumed_train_samples_stage,
            num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,  # TODO: this overshoots what's needed by the current stage, but it doesn't matter?
            parallel_context=trainer.parallel_context,
            input_pp_rank=input_pp_rank,
            output_pp_rank=output_pp_rank,
            dataloader_drop_last=True,
            dataloader_pin_memory=True,
            use_position_ids=isinstance(trainer.model_config, Qwen2Config),
            use_doc_masking=getattr(trainer.model_config, "_use_doc_masking", None),
        )
        log_rank(
            f"[TokenizedBytes] Time taken to create TokenizedBytes: {time.strftime('%M:%S', time.gmtime(time.time() - start_time))} (MM:SS)",
            logger=logger,
            level=logging.INFO,
            rank=0,
        )
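        # Make sure all ranks have finished building their dataloader before continuing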
        dist.barrier()

        # Create Nanoset
        # from nanotron.data.nanoset import Nanoset

        # with main_rank_first(trainer.parallel_context.world_pg):
        #     tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
        #     tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        #     eos_token_id = tokenizer.eos_token_id
        #     assert (
        #         eos_token_id is not None or data.dataset.return_positions is False
        #     ), "Tokenizer must have an eos token if return_positions is True"
        #     log_rank(
        #         f"[Nanoset] Creating Nanoset with {len(data.dataset.dataset_folder)} dataset folders and {trainer.config.tokens.train_steps * trainer.global_batch_size} train samples",
        #         logger=logger,
        #         level=logging.INFO,
        #         rank=0,
        #     )
        #     start_time = time.time()
        #     train_dataset = Nanoset(
        #         dataset_folders=data.dataset.dataset_folder,
        #         sequence_length=trainer.sequence_length,
        #         token_size=data.dataset.token_size_in_bytes,
        #         train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
        #         dataset_weights=data.dataset.dataset_weights,
        #         random_seed=data.seed,
        #         return_positions=data.dataset.return_positions,
        #         eos_token_id=eos_token_id,
        #     )
        #     end_time = time.time()
        #     log_rank(
        #         f"[Nanoset] Time taken to create Nanoset: {time.strftime('%M:%S', time.gmtime(end_time - start_time))} (MM:SS)",
        #         logger=logger,
        #         level=logging.INFO,
        #         rank=0,
        #     )
        # # Prepare dataloader
        # train_dataloader = build_nanoset_dataloader(
        #     train_dataset,
        #     trainer.sequence_length,
        #     parallel_context=trainer.parallel_context,
        #     input_pp_rank=input_pp_rank,
        #     output_pp_rank=output_pp_rank,
        #     micro_batch_size=trainer.micro_batch_size,
        #     consumed_train_samples=consumed_train_samples,
        #     dataloader_num_workers=data.num_loading_workers,
        #     dataloader_drop_last=True,
        #     use_position_ids=isinstance(trainer.model_config, Qwen2Config),
        #     use_doc_masking=False,
        #     dataloader_pin_memory=True,
        # )
        # dist.barrier()

    else:
        raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")

    if sanity_check_dataloader_interval is not None:
        sanity_check_dataloader(
            dataloader,
            tokenizer_path=trainer.config.tokenizer.tokenizer_name_or_path,
            sanity_check_dataloader_interval=sanity_check_dataloader_interval,
        )

    return dataloader


def get_dataloader(
    trainer: DistributedTrainer, sanity_check_dataloader_interval: Optional[int] = None
) -> Dict[str, DataLoader]:
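    """
    Build the dataloader for the current (last) data stage, registering the stage's metadata if it is new.

    Returns a dict mapping the stage name to its dataloader.
    """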
    dataloaders = {}

    # Print training plan
    log_rank("Training plan", logger=logger, level=logging.INFO, rank=0, is_separator=True)
    stages_info = "".join(
        f"[Stage {stage.name}] start from step {stage.start_training_step} \n" for stage in trainer.config.data_stages
    )
    full_log_message = f"There are {len(trainer.config.data_stages)} training stages \n{stages_info}"
    log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0)

    current_stage = None
    # WARNING: we assume we always train on the last configured stage
    stage_idx = len(trainer.config.data_stages) - 1
    stage_args = trainer.config.data_stages[stage_idx]
    if trainer.iteration_step + 1 == stage_args.start_training_step:
        log_rank(f"Starting new stage {stage_args.name}", logger=logger, level=logging.INFO, rank=0)
        # we start a new stage
        if stage_idx >= len(trainer.metadata.data_stages):
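            # The checkpoint metadata doesn't know about this stage yet, so register it with zero consumed samples/tokens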
            trainer.metadata.data_stages.append(DataStageMetadata(
                name=stage_args.name,
                start_training_step=stage_args.start_training_step,
                consumed_train_samples=0,
                consumed_tokens_per_dataset_folder={},
                sequence_length=trainer.sequence_length,
            ))
    elif len(trainer.metadata.data_stages) < len(trainer.config.data_stages):
        raise ValueError(f"If you're trying to start a new stage, you need to set `start_training_step` to the step after the last stage's: {trainer.iteration_step+1}")
    current_stage = trainer.metadata.data_stages[stage_idx]
    cur_stage_consumed_train_samples = current_stage.consumed_train_samples
    consumed_tokens_per_dataset_folder = current_stage.consumed_tokens_per_dataset_folder
    stage_args_data = trainer.config.data_stages[stage_idx].data

    num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
        current_stage, trainer.config, trainer.metadata
    ) # TODO: check this
    log_rank(
        f"Current stage: {current_stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {cur_stage_consumed_train_samples} samples"
        f"Consumed tokens per dataset folder: {pformat(consumed_tokens_per_dataset_folder)}",
        logger=logger,
        level=logging.INFO,
        rank=0,
    )

    # If the sequence length changed compared to the previous stage, consumed_train_samples restarts at 0, so we would assume we're reading from the beginning of the dataset folders (so that training can still resume)
    if current_stage.sequence_length != trainer.metadata.data_stages[-1].sequence_length:
        raise NotImplementedError("We don't support changing sequence length between stages yet")
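        # NOTE: the code below is unreachable until changing the sequence length between stages is supported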
        if current_stage.consumed_train_samples == 0:
            log_rank(
                f"Warning: The sequence length of the last stage has changed from {trainer.metadata.data_stages[-1].sequence_length} to {current_stage.sequence_length}. We'll assume we're reading from the beginning of the dataset folders.",
                logger=logger,
                level=logging.WARNING,
                rank=0,
            )
        else:
            # we're resuming training, so that's fine
            pass
        cur_stage_consumed_train_samples = current_stage.consumed_train_samples

    else:
        # Prepare last_stages_consumed_tokens_per_dataset_folder which will be used to offset BlendableDataset to avoid reseeing consumed tokens even when sampler has restarted for this stage
        last_stages_consumed_tokens_per_dataset_folder = {}
        for stage in trainer.metadata.data_stages[:-1]:
            for folder_path, consumed_tokens in stage.consumed_tokens_per_dataset_folder.items():
                last_stages_consumed_tokens_per_dataset_folder[folder_path] = (
                    last_stages_consumed_tokens_per_dataset_folder.get(folder_path, 0) + consumed_tokens
                )

    dataloaders[current_stage.name] = get_dataloader_from_data_stage(
        trainer,
        stage_args_data,
        consumed_train_samples_stage=cur_stage_consumed_train_samples,
        consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder,
        last_stages_consumed_tokens_per_dataset_folder=last_stages_consumed_tokens_per_dataset_folder,
        num_remaining_train_steps=num_remaining_train_steps,
        sanity_check_dataloader_interval=sanity_check_dataloader_interval,
    )
    return dataloaders


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
    parser.add_argument(
        "--sanity-check-dataloader-interval",
        type=int,
        default=None,
        help="Optional interval to print dataloader samples",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    config_file = args.config_file

    # Load trainer and data
    trainer = DistributedTrainer(config_file)
    dataloader = get_dataloader(trainer, args.sanity_check_dataloader_interval)

    # Train
    trainer.train(dataloader)
