# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import math
import os
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, Union

import lightning as L
import numpy as np
import nvtx
import torch
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor, measure_flops
from torch.utils.data import DataLoader, IterableDataset

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_gpt import Config
from lit_gpt.args import EvalArgs, IOArgs, TrainArgs
from lit_gpt.model import GPT, Block
from lit_gpt.utils import chunked_cross_entropy, estimate_flops, get_default_supported_precision, num_parameters

from utilities.nsight_callbacks import NsightCallback

use_nsight = os.getenv("COLLECT_NSYS_PROFILE") == "yes"
if use_nsight:
    import utilities.monitor_collectives
    utilities.monitor_collectives.shunt_torch_communication()
    print("Enabling nsight profiling.")


def setup(
    model_name: str = os.getenv("MODEL_NAME", "Llama-2-70b-hf"),
    data_dir: Path = Path("/data"),
    out_dir: Path = Path(os.getenv("EXPERIMENT_LOCAL_DIR", "")) / "out",
    precision: Optional[str] = None,
    resume: Union[bool, Path] = False,
    eval_interval: int = 1000,
    save_interval: int = 1000,
    eval_iters: int = 100,
    log_interval: int = 1,
    devices: int = 4,
    learning_rate: float = 6e-4,
    weight_decay: float = 1e-1,
    beta1: float = 0.9,
    beta2: float = 0.95,
    lr_warmup_steps: int = 100,
    min_lr: float = 6e-5,
    global_batch_size: int = (int(os.getenv("NNODES", "1")) * 8 * int(os.getenv("BATCH_SIZE", "6"))),  # assumes 8 devices per node
    micro_batch_size: int = int(os.getenv("MICRO_BATCH_SIZE", "6")),
    max_norm: float = 1.0,
    epochs: int = int(os.getenv("NUMBER_OF_EPOCHS", "2")),
    train_epoch_size: int = 8 * int(os.getenv("MICRO_BATCH_SIZE", "6")) * int(os.getenv("STEPS_PER_EPOCH", "30")),
) -> None:
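    """Build the parallelization strategy and Fabric, then launch `main` on every process.

    Most defaults are read from environment variables (MODEL_NAME, NNODES, BATCH_SIZE,
    MICRO_BATCH_SIZE, NUMBER_OF_EPOCHS, STEPS_PER_EPOCH, EXPERIMENT_LOCAL_DIR) so the
    job launcher can configure the run without passing CLI flags.
    """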
    print(locals())
    precision = precision or get_default_supported_precision(training=True)

    if devices > 1:
        strategy = FSDPStrategy(
            auto_wrap_policy={Block},
            activation_checkpointing_policy={Block},
            state_dict_type="full",
            limit_all_gathers=True,
            cpu_offload=False,
        )
    else:
        strategy = "auto"

    logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
    callbacks = []
    if use_nsight:
        callbacks.append(NsightCallback())
    fabric = L.Fabric(
        devices=devices, strategy=strategy, precision=precision,
        loggers=logger, callbacks=callbacks, num_nodes=int(os.getenv("NNODES", "1")))

    fabric.launch(
        main,
        devices,
        resume,
        Config.from_name(name=model_name),
        IOArgs(train_data_dir=data_dir, val_data_dir=data_dir, out_dir=out_dir),
        TrainArgs(
            save_interval=save_interval,
            log_interval=log_interval,
            global_batch_size=global_batch_size,
            micro_batch_size=micro_batch_size,
            lr_warmup_steps=lr_warmup_steps,
            epochs=epochs,
            epoch_size=train_epoch_size,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            beta1=beta1,
            beta2=beta2,
            max_norm=max_norm,
            min_lr=min_lr,
        ),
        EvalArgs(interval=eval_interval, max_iters=eval_iters),
    )


def main(
    fabric: L.Fabric,
    devices: int,
    resume: Union[bool, Path],
    config: Config,
    io_args: IOArgs,
    train_args: TrainArgs,
    eval_args: EvalArgs,
) -> None:
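    """Instantiate the GPT model and AdamW optimizer, build the dataloaders, optionally
    resume from the latest checkpoint (by iteration number) in `out_dir`, and run training."""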
    validate_args(io_args, train_args, eval_args)

    if fabric.global_rank == 0:
        io_args.out_dir.mkdir(parents=True, exist_ok=True)

    fabric.seed_everything(1337, workers=True)  # same seed for every process to init model (FSDP)

    fabric.print(f"Loading model with {config.__dict__}")
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=(fabric.world_size > 1)):
        model = GPT(config)
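    # with `empty_init=True` (the multi-process case) parameters are allocated without being
    # initialized, so the weight init is applied explicitly here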
    model.apply(model._init_weights)

    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
    fabric.print(f"Total parameters {num_parameters(model):,}")

    model = fabric.setup(model)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=train_args.learning_rate,
        weight_decay=train_args.weight_decay,
        betas=(train_args.beta1, train_args.beta2),
        foreach=False,
    )
    optimizer = fabric.setup_optimizers(optimizer)

    train_data, val_data = load_datasets(io_args, max_seq_length=model.max_seq_length)
    train_dataloader = DataLoader(train_data, batch_size=train_args.micro_batch_size, num_workers=2)
    val_dataloader = DataLoader(val_data, batch_size=train_args.micro_batch_size, num_workers=2)
    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

    state = {"model": model, "optimizer": optimizer, "iter_num": 0, "step_count": 0}

    if resume is True:
        resume = max(io_args.out_dir.glob("*.pth"), key=lambda p: int(p.name.split("-")[1]))
    if resume:
        fabric.print(f"Resuming training from {resume}")
        fabric.load(resume, state)

    train_time = time.perf_counter()
    train(fabric, devices, state, train_dataloader, val_dataloader, io_args, train_args, eval_args)
    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


def train(
    fabric: L.Fabric,
    devices: int,
    state: dict,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    io_args: IOArgs,
    train_args: TrainArgs,
    eval_args: EvalArgs,
) -> None:
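    """Main training loop: per-iteration LR scheduling, gradient accumulation with
    cross-rank gradient sync skipped while accumulating, gradient clipping, throughput
    logging, periodic validation, and checkpointing.
    """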
    model = state["model"]
    optimizer = state["optimizer"]

    validate(fabric, model, val_dataloader, max_iters=2)  # sanity check

    with torch.device("meta"):
        meta_model = GPT(model.config)
        # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
        # consider passing `flops_per_batch=estimated_flops` instead
        estimated_flops = estimate_flops(meta_model, training=True) * train_args.micro_batch_size
        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
        x = torch.randint(0, 1, (train_args.micro_batch_size, model.max_seq_length))
        forward_fn = lambda: meta_model(x)
        loss_fn = lambda y: chunked_cross_entropy(y, x, chunk_size=0)
        measured_flops = measure_flops(meta_model, forward_fn, loss_fn)
        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
        del meta_model, x

    throughput = ThroughputMonitor(fabric, window_size=int(os.getenv("WINDOW_SIZE", "10")))
    total_t0 = time.perf_counter()

    train_iter = iter(train_dataloader)

    fabric.call("on_train_epoch_start")

    # lr_warmup_steps counts optimizer steps; convert it to micro-batch iterations for the per-iteration schedule
    lr_warmup_iters = train_args.lr_warmup_steps * train_args.gradient_accumulation_iters(devices)
    for state["iter_num"] in range(state["iter_num"], train_args.max_iters(devices)):
        # determine and set the learning rate for this iteration
        lr = get_lr(
            train_args.learning_rate,
            state["iter_num"],
            lr_warmup_iters,
            train_args.max_iters(devices),
            min_lr=train_args.min_lr,
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        iter_num = state["iter_num"] + 1
        iter_t0 = time.perf_counter()

        input_ids, targets = next(train_iter)

        fabric.call("on_train_batch_start", iter_num, train_args.gradient_accumulation_iters(devices))

        is_accumulating = iter_num % train_args.gradient_accumulation_iters(devices) != 0
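        # while accumulating, skip the cross-rank gradient synchronization; it only needs
        # to happen on the micro-batch right before optimizer.step()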
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            # Forward pass
            logits = model(input_ids)
            loss = chunked_cross_entropy(logits, targets, chunk_size=0)
            
            # Backward pass
            fabric.call("on_before_backward")
            fabric.backward(loss / train_args.gradient_accumulation_iters(devices))
            fabric.call("on_after_backward")

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=train_args.max_norm)
            with nvtx.annotate(color="orange"):
                optimizer.step()
                optimizer.zero_grad()
            state["step_count"] += 1

        if iter_num % train_args.log_interval == 0:
            loss_item = loss.item()  # expensive device-to-host synchronization
            t1 = time.perf_counter()
            throughput.update(
                time=t1 - total_t0,
                batches=iter_num,
                samples=iter_num * train_args.micro_batch_size,
                lengths=iter_num * train_args.micro_batch_size * model.max_seq_length,
                flops=measured_flops * train_args.log_interval,
            )
            throughput.compute_and_log(step=iter_num)
            fabric.print(
                f"iter {iter_num} step {state['step_count']}: loss {loss_item:.4f}, iter time:"
                f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
            )

        if not is_accumulating and state["step_count"] % eval_args.interval == 0:
            t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_dataloader, max_iters=eval_args.max_iters)
            t1 = time.perf_counter() - t0
            fabric.print(f"step {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms")
            fabric.barrier()

        fabric.call("on_train_batch_end", iter_num, train_args.gradient_accumulation_iters(devices))

        if not is_accumulating and state["step_count"] % train_args.save_interval == 0:
            checkpoint_path = io_args.out_dir / f"iter-{iter_num:06d}-ckpt.pth"
            fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}")
            fabric.save(checkpoint_path, state)


# FSDP has issues with `inference_mode`
@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader, max_iters: int) -> torch.Tensor:
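    """Return the mean loss over `max_iters` validation micro-batches."""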
    fabric.print("Validating ...")
    model.eval()
    val_iter = iter(val_dataloader)

    losses = torch.zeros(max_iters, device=fabric.device)
    for k in range(max_iters):
        input_ids, targets = next(val_iter)
        logits = model(input_ids)
        losses[k] = chunked_cross_entropy(logits, targets, chunk_size=0)
    out = losses.mean()

    model.train()
    return out


def load_datasets(io_args: IOArgs, max_seq_length: int) -> Tuple["Dataset", "Dataset"]:
    train_data = Dataset(io_args.train_data_dir / "train.bin", max_seq_length)
    val_data = Dataset(io_args.val_data_dir / "val.bin", max_seq_length)
    return train_data, val_data


class Dataset(IterableDataset):
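    """Infinite stream of (input, target) pairs: fixed-length windows sampled uniformly at
    random from a memory-mapped file of uint16 token ids, with targets shifted by one token."""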
    def __init__(self, data_file: Path, max_seq_length: int):
        super().__init__()
        self.data_file = data_file
        self.max_seq_length = max_seq_length

    def __iter__(self):
        data = np.memmap(self.data_file, dtype=np.uint16, mode="r")
        while True:
            i = torch.randint(len(data) - self.max_seq_length, (1,)).item()
            x = torch.from_numpy((data[i : i + self.max_seq_length]).astype(np.int64))
            y = torch.from_numpy((data[i + 1 : i + 1 + self.max_seq_length]).astype(np.int64))
            yield x, y


# learning rate decay scheduler (cosine with linear warmup)
def get_lr(learning_rate: float, it: int, warmup_iters: int, max_iters: int, min_lr: float) -> float:
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > max_iters, return min learning rate
    if it > max_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)


def validate_args(io_args: IOArgs, train_args: TrainArgs, eval_args: EvalArgs) -> None:
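    """Raise ValueError for arguments this script does not support and for required arguments that are missing."""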
    unsupported = [(io_args, ["checkpoint_dir"]), (train_args, ["max_tokens"]), (eval_args, ["max_new_tokens"])]
    for args, names in unsupported:
        for name in names:
            if getattr(args, name) is not None:
                raise ValueError(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
    required = [(io_args, ["train_data_dir", "val_data_dir"]), (train_args, ["epoch_size", "epochs", "max_norm"])]
    for args, names in required:
        for name in names:
            if getattr(args, name) is None:
                raise ValueError(f"{__file__} requires the {name!r} argument. This is set in {args}")


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    from jsonargparse import CLI
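    # jsonargparse exposes setup()'s keyword arguments as CLI flags; for example
    # (the script name below is illustrative):
    #   python pretrain.py --devices 8 --micro_batch_size 4 --model_name Llama-2-7b-hf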

    CLI(setup)