import os
from pathlib import Path
import subprocess
import sys
import tempfile
from datetime import datetime

from nanotron.logging import human_format
from nanotron.models.llama import LlamaConfig

from datatrove.io import get_datafolder
from nanotron.config import DatasetStageArgs, NanosetDatasetsArgs, S3UploadArgs


# Paths
LOCAL_TMP_PATH_ON_NODE = f"/scratch/{os.environ.get('USER')}"
LAUNCH_CONFIGS_PATH = "path/to/launch-configs"

# Executables
NANOTRON_RUN_TRAIN_SCRIPT = "path/to/run_train.py"
S5CMD_PATH = "path/to/s5cmd"

S3_CHECKPOINTS_PREFIX = "path/to/where_to_save_checkpoints"

# Logging parameters
LOGS_PATH = "path/to/slurm-logs"
REPO_ID = "id of the repo to use for logging"
PROJECT = "name of the project"
EMAIL = "email to send notifications to"

# Resources parameters
NUM_GPUS = 8
NUM_CPUS_IN_NODE = 88
CPUS_PER_GPU = NUM_CPUS_IN_NODE // NUM_GPUS


model_config = LlamaConfig(
    # Config for a 1.46B model
    bos_token_id=1,
    eos_token_id=2,
    hidden_act="silu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=2048,
    num_attention_heads=32,
    num_hidden_layers=14,
    num_key_value_heads=32,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    tie_word_embeddings=True,
    use_cache=True,
    vocab_size=256008,  # gemma tokenizer + some room
)


# Rough parameter count: token embeddings (tied, so counted once) plus, per layer,
# the SwiGLU MLP (3 * hidden * intermediate) and the attention projections (4 * hidden^2).
num_params = human_format(
    model_config.vocab_size * model_config.hidden_size
    + model_config.num_hidden_layers
    * (
        3 * model_config.hidden_size * model_config.intermediate_size
        + 4 * model_config.hidden_size * model_config.hidden_size
    )
).replace(".", "p")

print(f"Model has {num_params} parameters")


def launch_slurm_job(launch_file_contents, *args):
    """
    Small helper that writes an sbatch script to a temporary file and submits it.

    Args:
        launch_file_contents: contents of the sbatch script
        *args: any other arguments to pass to the sbatch command

    Returns:
        The id of the launched Slurm job.
    """
    with tempfile.NamedTemporaryFile("w") as f:
        f.write(launch_file_contents)
        f.flush()
        return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[-1]


if __name__ == "__main__":
    import argparse
    from dataclasses import fields, is_dataclass

    from nanotron.config import get_config_from_file

    parser = argparse.ArgumentParser()
    parser.add_argument("data", help="dataset folder", type=str)
    parser.add_argument("run_name", help="run name", type=str)
    parser.add_argument("language", help="language", type=str)
    parser.add_argument("-d", help="dependency job", type=str, default=None)
    parser.add_argument("--seed", help="seed", type=int, default=6)
    parser.add_argument("--train_steps", "-ts", help="training steps. Total_toks=seq_len*steps*micro_bs*batch_accum_per_replica*dp_size", type=int, default=14000)
    parser.add_argument("--priority", "--qos", "-p", help="qos to use", type=str, default="normal")
    args = parser.parse_args()
    SEED = args.seed


    dataset_name = run_name = args.run_name.replace(" ", "_")

    # Specific name for this run (checkpoints/logs/tensorboard)
    RUN = f"{num_params}-{dataset_name}-seed-{SEED}"

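    # Skip the launch if a checkpoint for the requested number of steps is already on S3
    # (latest.txt records the step of the most recent uploaded checkpoint).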
    df = get_datafolder(f"{S3_CHECKPOINTS_PREFIX}/{RUN}")
    if df.exists("latest.txt") and df.cat_file("latest.txt") == bytes(str(args.train_steps), "utf-8"):
        print(f"Not launching as latest checkpoint is already {args.train_steps} steps")
        sys.exit(0)

    import torch

    from nanotron.config import (
        CheckpointsArgs,
        Config,
        DataArgs,
        GeneralArgs,
        LoggingArgs,
        LRSchedulerArgs,
        ModelArgs,
        OptimizerArgs,
        ParallelismArgs,
        RandomInit,
        TokenizerArgs,
        TokensArgs,
        AdamWOptimizerArgs,
    )

    def print_differences(target, updates):
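        """Recursively print the fields that differ between two config dataclass instances."""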
        if not is_dataclass(target) or not is_dataclass(updates):
            raise ValueError("Both target and updates should be dataclass instances")

        for field in fields(target):
            update_value = getattr(updates, field.name)

            if update_value is not None:
                if is_dataclass(update_value):
                    print_differences(getattr(target, field.name), update_value)
                else:
                    target_value = getattr(target, field.name)
                    if update_value != target_value:
                        if update_value.__class__.__module__ != "builtins":
                            continue
                        print(f"{field.name}: {target_value} -> {update_value}")

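    # A single training stage; the Nanoset dataloader reads pre-tokenized data either from
    # the local folder given on the command line or from the per-node scratch copy of an
    # S3 dataset (downloaded by the sbatch script below).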
    data = [
        DatasetStageArgs(
            name="Training Stage",
            start_training_step=1,
            data=DataArgs(
                seed=SEED,
                num_loading_workers=0,
                dataset=NanosetDatasetsArgs(
                    dataset_folder=args.data if not args.data.startswith("s3://") else f"{LOCAL_TMP_PATH_ON_NODE}/dataset/{RUN}/",
                    dataset_weights=None,
                )
            )
        ),
    ]

    general = GeneralArgs(
        project=PROJECT,
        run=RUN,
        ignore_sanity_checks=True,
        seed=SEED,
    )

    checkpoints = CheckpointsArgs(
        checkpoints_path=Path(f"{LOCAL_TMP_PATH_ON_NODE}/checkpoints/{RUN}"),
        checkpoints_path_is_shared_file_system=False,
        checkpoint_interval=500,
        save_initial_state=True,
    )

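    # dp * pp * tp must match the total number of GPUs: 64 = 8 nodes * 8 GPUs per node.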
    parallelism = ParallelismArgs(
        dp=64,
        pp=1,
        tp=1,
        pp_engine="1f1b",
        tp_mode="REDUCE_SCATTER",
        tp_linear_async_communication=True,
    )
    # num_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1))
    # parallelism.dp=int(num_nodes*8//parallelism.pp//parallelism.tp),  # How many remaining GPU when taking into account PP, TP and 8 GPUs per node

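    # Global batch size = dp * micro_batch_size * batch_accumulation_per_replica
    # = 64 * 4 * 4 = 1024 sequences of 2048 tokens (~2.1M tokens per step),
    # so the default 14000 steps correspond to roughly 29B tokens.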
    tokens = TokensArgs(
        batch_accumulation_per_replica=4,
        micro_batch_size=4,
        sequence_length=2048,
        train_steps=args.train_steps,
        val_check_interval=-1,
    )

    model = ModelArgs(
        model_config=model_config,
        make_vocab_size_divisible_by=1,
        init_method=RandomInit(
            std=0.02
        ),
        dtype=torch.bfloat16,
    )

    logging = LoggingArgs(
        # 'debug', 'info', 'warning', 'error', 'critical' and 'passive'
        log_level="info",
        log_level_replica="info",
        iteration_step_info_interval=1,
    )

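    # AdamW with a 500-step linear warmup to 3e-4, then cosine decay to 3e-5;
    # gradients are accumulated in fp32 and clipped to a norm of 1.0.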
    optimizer = OptimizerArgs(
        accumulate_grad_in_fp32=True,
        clip_grad=1.0,
        weight_decay=0.1,
        zero_stage=0,
        learning_rate_scheduler=LRSchedulerArgs(
            learning_rate=3e-4,
            lr_warmup_steps=500,
            lr_warmup_style="linear",
            lr_decay_style="cosine",
            min_decay_lr=3.0e-5
        ),
        optimizer_factory=AdamWOptimizerArgs(
            adam_beta1=0.9,
            adam_beta2=0.95,
            adam_eps=1.0e-8,
            torch_adam_is_fused=True,
        ),
    )

    tokenizer = TokenizerArgs(
        tokenizer_name_or_path="google/gemma-7b",
    )

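    # Checkpoints are written to node-local scratch, then uploaded to S3 with s5cmd and
    # removed locally after a successful upload.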
    s3_upload = S3UploadArgs(
        upload_s3_path=f"{S3_CHECKPOINTS_PREFIX}/{RUN}",
        remove_after_upload=True,
        s5cmd_numworkers=16,
        s5cmd_concurrency=5,
        s5cmd_path=S5CMD_PATH,
    )

    config = Config(
        general=general,
        checkpoints=checkpoints,
        parallelism=parallelism,
        model=model,
        tokenizer=tokenizer,
        logging=logging,
        tokens=tokens,
        optimizer=optimizer,
        data_stages=data,
        profiler=None,
        s3_upload=s3_upload,
        lighteval=None,
    )

    NODES = 8
    #### DEBUG MODE
    if os.environ.get("DEBUG_MODE", "0") != "0":
        print("##### WARNING DEBUG MODE #####")
        config.parallelism.dp = 2
        config.parallelism.pp = 2
        config.parallelism.tp = 2
        config.tokens.micro_batch_size = 3
        config.tokens.batch_accumulation_per_replica = 2
        config.checkpoints.save_initial_state = True
        NODES = 1

    # Sanity check: save the config to YAML, reload it, and print any differences
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    os.makedirs(f"{LAUNCH_CONFIGS_PATH}/{run_name}", exist_ok=True)
    config_path_yaml = f"{LAUNCH_CONFIGS_PATH}/{run_name}/{timestamp}.yaml"
    config.save_as_yaml(config_path_yaml)
    config2 = get_config_from_file(config_path_yaml, config_class=Config)
    print_differences(config, config2)

    os.makedirs(f"{LOGS_PATH}/{run_name}", exist_ok=True)

    # If the dataset lives on S3, wipe any stale local copy and mirror it to node-local scratch with s5cmd.
    dataset_download_cmd = "" if not args.data.startswith("s3://") else f"srun --ntasks-per-node=1 rm -rf {LOCAL_TMP_PATH_ON_NODE}/dataset\nsrun --ntasks-per-node=1 {S5CMD_PATH} cp '{args.data.removesuffix('/')}/*' {LOCAL_TMP_PATH_ON_NODE}/dataset/{RUN}/"
    job_name = f"{run_name}-{SEED}"

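    # Note: inside the f-string below, literal bash braces are written doubled ({{ }}) so they survive formatting.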
    sbatch_script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --nodes={NODES}
#SBATCH --ntasks-per-node=1          # crucial - only 1 task per node; torchrun spawns the per-GPU processes
#SBATCH --cpus-per-task={NUM_CPUS_IN_NODE}
#SBATCH --gres=gpu:{NUM_GPUS}
#SBATCH --partition=hopper-prod
#SBATCH --output={LOGS_PATH}/{run_name}/train-{timestamp}-%x-%j
# #SBATCH --array=1-1%1
#SBATCH --qos={args.priority}
#SBATCH --begin=now+0minutes
#SBATCH --mail-type=ALL
#SBATCH --mail-user={EMAIL}
#SBATCH --requeue
{"#SBATCH --dependency=afterok:" + args.d if args.d else ""}

###########################################
# [BEGINNING] ADAPT TO YOUR ENVIRONMENT


# [END] ADAPT TO YOUR ENVIRONMENT
###########################################


set -x -e

##### TO UPDATE #####


##### END TO UPDATE ######

echo "START TIME: $(date)"
secs_to_human(){{
    echo "$(( ${{1}} / 3600 )):$(( (${{1}} / 60) % 60 )):$(( ${{1}} % 60 ))"
}}
start=$(date +%s)
echo -e "$(date -d @${{start}} "+%Y-%m-%d %H:%M:%S"): ${{SLURM_JOB_NAME}} start id=${{SLURM_JOB_ID}}\\n"

{dataset_download_cmd}

# SLURM stuff
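# The first node in the allocation acts as the torch.distributed rendezvous master,
# reachable on a randomly chosen unprivileged port.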
export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=$((1024 + RANDOM % 64511))
export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)

export TMPDIR={LOCAL_TMP_PATH_ON_NODE}
# nanotron expects CUDA_DEVICE_MAX_CONNECTIONS=1 so that tensor-parallel communication can overlap with compute
export CUDA_DEVICE_MAX_CONNECTIONS="1"

module load cuda/12.1

echo go $COUNT_NODE
echo $HOSTNAMES

##### MOVE TO YAML ######

CMD=" \
    {NANOTRON_RUN_TRAIN_SCRIPT} \
    --config-file {config_path_yaml}
    "

export LAUNCHER="python -u -m torch.distributed.run \
    --nproc_per_node {NUM_GPUS} \
    --nnodes $COUNT_NODE \
    --rdzv-backend c10d \
    --rdzv-endpoint $MASTER_ADDR:$MASTER_PORT \
    --rdzv-id $SLURM_JOB_ID \
    --node_rank $SLURM_PROCID \
    --role $SLURMD_NODENAME: \
    --max_restarts 0 \
    --tee 3 \
    "

# Sleep for a random duration between 0 and 1000 milliseconds to avoid too many concurrent requests to the Hub
random_milliseconds=$(( RANDOM % 1001 ))
sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000")
echo "Sleeping for $sleep_time seconds..."
sleep $sleep_time

# srun starts one torchrun per node; torchrun spawns one training process per GPU

srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD"


echo "END TIME: $(date)"

{
    "" if not args.data.startswith("s3://") else f"srun --ntasks-per-node=1 rm -rf {LOCAL_TMP_PATH_ON_NODE}/dataset/{RUN}/"
}
"""
    # Submit the sbatch script and report the Slurm job id and log location
    job_id = launch_slurm_job(sbatch_script)
    log_path = f"{LOGS_PATH}/{run_name}/train-{timestamp}-{job_name}-{job_id}"
    print(f"Launched with Slurm job id={job_id}")
    print(f"To view the logs, use the command: tail -f {log_path}")
