sample_workloads/lit-gpt-demo/openwebtext.py (269 lines of code) (raw):

# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Pretrain a lit_gpt ``GPT`` model with Lightning Fabric on pre-tokenized data.

Token data is read from ``train.bin`` / ``val.bin`` (memory-mapped ``uint16``
arrays). Several defaults of :func:`setup` are taken from environment
variables, and optional Nsight profiling hooks are enabled when
``COLLECT_NSYS_PROFILE=yes``.
"""
import math
import os
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, Union

import lightning as L
import numpy as np
import nvtx
import torch
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities import ThroughputMonitor, measure_flops
from torch.utils.data import DataLoader, IterableDataset

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_gpt import Config
from lit_gpt.args import EvalArgs, IOArgs, TrainArgs
from lit_gpt.model import GPT, Block
from lit_gpt.utils import chunked_cross_entropy, estimate_flops, get_default_supported_precision, num_parameters
from utilities.nsight_callbacks import NsightCallback

# Profiling is opt-in through the environment; the collectives monitor must be
# installed before any training code runs, hence the import-time side effect.
use_nsight = os.getenv("COLLECT_NSYS_PROFILE") == "yes"
if use_nsight:
    import utilities.monitor_collectives

    utilities.monitor_collectives.shunt_torch_communication()
    print("Enabling nsight profiling.")


def setup(
    model_name: str = os.getenv("MODEL_NAME", "Llama-2-70b-hf"),
    data_dir: Path = Path("/data"),
    out_dir: Path = Path(os.getenv("EXPERIMENT_LOCAL_DIR", "")) / "out",
    precision: Optional[str] = None,
    resume: Union[bool, Path] = False,
    eval_interval: int = 1000,
    save_interval: int = 1000,
    eval_iters: int = 100,
    log_interval: int = 1,
    devices: int = 4,
    learning_rate: float = 6e-4,
    weight_decay: float = 1e-1,
    beta1: float = 0.9,
    beta2: float = 0.95,
    lr_warmup_steps: int = 100,
    min_lr: float = 6e-5,
    global_batch_size: int = (int(os.getenv("NNODES", "1")) * 8 * int(os.getenv("BATCH_SIZE", "6"))),
    micro_batch_size: int = int(os.getenv("MICRO_BATCH_SIZE", "6")),
    max_norm: float = 1.0,
    epochs: int = int(os.getenv("NUMBER_OF_EPOCHS", "2")),
    train_epoch_size: int = 8 * int(os.getenv("MICRO_BATCH_SIZE", "6")) * int(os.getenv("STEPS_PER_EPOCH", "30")),
) -> None:
    """Build the Fabric runtime and launch :func:`main` with the given settings.

    This is the CLI entry point (see the ``__main__`` guard). Environment-based
    defaults (``MODEL_NAME``, ``EXPERIMENT_LOCAL_DIR``, ``NNODES``,
    ``BATCH_SIZE``, ``MICRO_BATCH_SIZE``, ``NUMBER_OF_EPOCHS``,
    ``STEPS_PER_EPOCH``) are evaluated once, when this module is imported.

    Args:
        model_name: Name passed to ``Config.from_name`` to select the model.
        data_dir: Directory used for both training and validation data.
        out_dir: Output directory; also determines the CSV logger location.
        precision: Fabric precision; when falsy,
            ``get_default_supported_precision(training=True)`` is used.
        resume: ``True`` to resume from the newest checkpoint in ``out_dir``,
            a path to resume from that checkpoint, or ``False`` to start fresh.
        eval_interval: Forwarded as ``EvalArgs.interval``.
        save_interval: Forwarded as ``TrainArgs.save_interval``.
        eval_iters: Forwarded as ``EvalArgs.max_iters``.
        log_interval: Logging period; also the CSV logger flush period.
        devices: Devices per node. More than one selects ``FSDPStrategy``.
        learning_rate: Peak learning rate for AdamW and the LR schedule.
        weight_decay: AdamW weight decay.
        beta1: First AdamW beta.
        beta2: Second AdamW beta.
        lr_warmup_steps: Forwarded as ``TrainArgs.lr_warmup_steps``.
        min_lr: Floor of the learning-rate schedule.
        global_batch_size: Forwarded as ``TrainArgs.global_batch_size``. The
            factor 8 in the default is presumably GPUs per node - confirm.
        micro_batch_size: Per-dataloader batch size.
        max_norm: Gradient clipping norm.
        epochs: Forwarded as ``TrainArgs.epochs``.
        train_epoch_size: Forwarded as ``TrainArgs.epoch_size``.
    """
    # Must stay the first statement so that only the arguments are printed.
    print(locals())
    precision = precision or get_default_supported_precision(training=True)

    if devices > 1:
        strategy = FSDPStrategy(
            auto_wrap_policy={Block},
            activation_checkpointing_policy={Block},
            state_dict_type="full",
            limit_all_gathers=True,
            cpu_offload=False,
        )
    else:
        strategy = "auto"

    logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)

    callbacks = []
    if use_nsight:
        callbacks.append(NsightCallback())

    fabric = L.Fabric(
        devices=devices,
        strategy=strategy,
        precision=precision,
        loggers=logger,
        callbacks=callbacks,
        num_nodes=int(os.getenv("NNODES", "1")))

    fabric.launch(
        main,
        devices,
        resume,
        Config.from_name(name=model_name),
        IOArgs(train_data_dir=data_dir, val_data_dir=data_dir, out_dir=out_dir),
        TrainArgs(
            save_interval=save_interval,
            log_interval=log_interval,
            global_batch_size=global_batch_size,
            micro_batch_size=micro_batch_size,
            lr_warmup_steps=lr_warmup_steps,
            epochs=epochs,
            epoch_size=train_epoch_size,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            beta1=beta1,
            beta2=beta2,
            max_norm=max_norm,
            min_lr=min_lr,
        ),
        EvalArgs(interval=eval_interval, max_iters=eval_iters),
    )


def main(
    fabric: L.Fabric,
    devices: int,
    resume: Union[bool, Path],
    config: Config,
    io_args: IOArgs,
    train_args: TrainArgs,
    eval_args: EvalArgs,
) -> None:
    """Create model, optimizer and dataloaders, optionally resume, then train.

    Args:
        fabric: The Fabric instance that launched this function.
        devices: Devices per node, forwarded to :func:`train`.
        resume: ``True`` to resume from the newest ``*.pth`` in
            ``io_args.out_dir``, a checkpoint path, or a falsy value.
        config: Model configuration used to build ``GPT``.
        io_args: Data and output directories.
        train_args: Training hyper-parameters.
        eval_args: Validation settings.

    Raises:
        ValueError: If the arguments fail :func:`validate_args`, or if
            ``resume`` is ``True`` but ``io_args.out_dir`` holds no ``*.pth``
            checkpoint.
    """
    validate_args(io_args, train_args, eval_args)

    if fabric.global_rank == 0:
        io_args.out_dir.mkdir(parents=True, exist_ok=True)

    fabric.seed_everything(1337, workers=True)  # same seed for every process to init model (FSDP)

    fabric.print(f"Loading model with {config.__dict__}")
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=(fabric.world_size > 1)):
        model = GPT(config)
    # NOTE(review): assumed to run outside the `init_module` context - confirm.
    model.apply(model._init_weights)

    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
    fabric.print(f"Total parameters {num_parameters(model):,}")

    model = fabric.setup(model)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=train_args.learning_rate,
        weight_decay=train_args.weight_decay,
        betas=(train_args.beta1, train_args.beta2),
        foreach=False,
    )
    optimizer = fabric.setup_optimizers(optimizer)

    train_data, val_data = load_datasets(io_args, max_seq_length=model.max_seq_length)
    train_dataloader = DataLoader(train_data, batch_size=train_args.micro_batch_size, num_workers=2)
    val_dataloader = DataLoader(val_data, batch_size=train_args.micro_batch_size, num_workers=2)
    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

    state = {"model": model, "optimizer": optimizer, "iter_num": 0, "step_count": 0}

    if resume is True:
        # Pick the checkpoint with the highest iteration number encoded in its
        # file name (`iter-<iter_num>-ckpt.pth`, as written by `train`).
        checkpoints = list(io_args.out_dir.glob("*.pth"))
        if not checkpoints:
            # Previously this surfaced as "max() arg is an empty sequence".
            raise ValueError(f"Cannot resume: no '*.pth' checkpoints found in {str(io_args.out_dir)!r}")
        resume = max(checkpoints, key=lambda p: int(p.name.split("-")[1]))
    if resume:
        fabric.print(f"Resuming training from {resume}")
        fabric.load(resume, state)

    train_time = time.perf_counter()
    train(fabric, devices, state, train_dataloader, val_dataloader, io_args, train_args, eval_args)
    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")


def train(
    fabric: L.Fabric,
    devices: int,
    state: dict,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    io_args: IOArgs,
    train_args: TrainArgs,
    eval_args: EvalArgs,
) -> None:
    """Run the training loop, with periodic logging, validation and checkpoints.

    ``state`` is updated in place: ``state["iter_num"]`` is the loop variable
    and ``state["step_count"]`` counts optimizer steps. Fabric callbacks
    (``on_train_epoch_start``, ``on_train_batch_start``, ``on_before_backward``,
    ``on_after_backward``, ``on_train_batch_end``) are invoked around the
    corresponding phases.

    Args:
        fabric: Fabric instance driving the run.
        devices: Devices per node; passed to
            ``train_args.gradient_accumulation_iters`` and
            ``train_args.max_iters``.
        state: Dict holding ``"model"``, ``"optimizer"``, ``"iter_num"`` and
            ``"step_count"``.
        train_dataloader: Source of ``(input_ids, targets)`` training batches.
        val_dataloader: Source of validation batches.
        io_args: Provides ``out_dir`` for checkpoints.
        train_args: Training hyper-parameters.
        eval_args: Validation interval and iteration count.
    """
    model = state["model"]
    optimizer = state["optimizer"]

    validate(fabric, model, val_dataloader, max_iters=2)  # sanity check

    # FLOP accounting is done on a meta-device copy of the model.
    with torch.device("meta"):
        meta_model = GPT(model.config)
        # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
        # consider passing `flops_per_batch=estimated_flops` instead
        estimated_flops = estimate_flops(meta_model, training=True) * train_args.micro_batch_size
        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
        x = torch.randint(0, 1, (train_args.micro_batch_size, model.max_seq_length))
        forward_fn = lambda: meta_model(x)
        loss_fn = lambda y: chunked_cross_entropy(y, x, chunk_size=0)
        measured_flops = measure_flops(meta_model, forward_fn, loss_fn)
        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
        del meta_model, x

    throughput = ThroughputMonitor(fabric, window_size=int(os.getenv("WINDOW_SIZE", "10")))
    total_t0 = time.perf_counter()

    train_iter = iter(train_dataloader)

    fabric.call("on_train_epoch_start")

    lr_warmup_iters = train_args.lr_warmup_steps * train_args.gradient_accumulation_iters(devices)
    # The loop variable is stored directly in `state` so checkpoints carry it.
    for state["iter_num"] in range(state["iter_num"], train_args.max_iters(devices)):
        # determine and set the learning rate for this iteration
        lr = get_lr(
            train_args.learning_rate,
            state["iter_num"],
            lr_warmup_iters,
            train_args.max_iters(devices),
            min_lr=train_args.min_lr,
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        iter_num = state["iter_num"] + 1  # 1-based counter used for logging and intervals
        iter_t0 = time.perf_counter()

        input_ids, targets = next(train_iter)

        fabric.call("on_train_batch_start", iter_num, train_args.gradient_accumulation_iters(devices))

        is_accumulating = iter_num % train_args.gradient_accumulation_iters(devices) != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            # Forward pass
            logits = model(input_ids)
            loss = chunked_cross_entropy(logits, targets, chunk_size=0)

            # Backward pass
            # NOTE(review): both backward hooks are assumed to sit inside the
            # `no_backward_sync` context - confirm.
            fabric.call("on_before_backward")
            fabric.backward(loss / train_args.gradient_accumulation_iters(devices))
            fabric.call("on_after_backward")

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=train_args.max_norm)
            # NOTE(review): the NVTX range is assumed to cover only
            # `optimizer.step()` - confirm.
            with nvtx.annotate(color="orange"):
                optimizer.step()
            optimizer.zero_grad()
            state["step_count"] += 1

        if iter_num % train_args.log_interval == 0:
            loss_item = loss.item()  # expensive device-to-host synchronization
            t1 = time.perf_counter()
            # Time, batches, samples and lengths are passed as running totals
            # since the start of training.
            throughput.update(
                time=t1 - total_t0,
                batches=iter_num,
                samples=iter_num * train_args.micro_batch_size,
                lengths=iter_num * train_args.micro_batch_size * model.max_seq_length,
                flops=measured_flops * train_args.log_interval,
            )
            throughput.compute_and_log(step=iter_num)
            fabric.print(
                f"iter {iter_num} step {state['step_count']}: loss {loss_item:.4f}, iter time:"
                f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
            )

        if not is_accumulating and state["step_count"] % eval_args.interval == 0:
            t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_dataloader, max_iters=eval_args.max_iters)
            t1 = time.perf_counter() - t0
            fabric.print(f"step {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms")
            fabric.barrier()

        fabric.call("on_train_batch_end", iter_num, train_args.gradient_accumulation_iters(devices))

        if not is_accumulating and state["step_count"] % train_args.save_interval == 0:
            checkpoint_path = io_args.out_dir / f"iter-{iter_num:06d}-ckpt.pth"
            fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}")
            fabric.save(checkpoint_path, state)


# FSDP has issues with `inference_mode`
@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader, max_iters: int) -> torch.Tensor:
    """Return the mean loss over ``max_iters`` validation batches.

    The model is put in eval mode for the duration of the call and switched
    back to train mode before returning.

    Args:
        fabric: Fabric instance; supplies the device for the loss buffer.
        model: Model to evaluate.
        val_dataloader: Source of ``(input_ids, targets)`` batches.
        max_iters: Number of batches to evaluate.

    Returns:
        A scalar tensor with the mean of the per-batch losses.
    """
    fabric.print("Validating ...")
    model.eval()

    # A fresh iterator is created on every call.
    val_iter = iter(val_dataloader)

    losses = torch.zeros(max_iters, device=fabric.device)
    for k in range(max_iters):
        input_ids, targets = next(val_iter)
        logits = model(input_ids)
        losses[k] = chunked_cross_entropy(logits, targets, chunk_size=0)
    out = losses.mean()

    model.train()
    return out


def load_datasets(io_args: IOArgs, max_seq_length: int) -> Tuple["Dataset", "Dataset"]:
    """Create the training and validation datasets.

    Args:
        io_args: Provides ``train_data_dir`` (expects ``train.bin``) and
            ``val_data_dir`` (expects ``val.bin``).
        max_seq_length: Length of each sampled token window.

    Returns:
        ``(train_data, val_data)``.
    """
    train_data = Dataset(io_args.train_data_dir / "train.bin", max_seq_length)
    val_data = Dataset(io_args.val_data_dir / "val.bin", max_seq_length)
    return train_data, val_data


class Dataset(IterableDataset):
    """Endless stream of random fixed-length windows from a token file."""

    def __init__(self, data_file: Path, max_seq_length: int):
        """Store the file location and window length; the file is opened lazily.

        Args:
            data_file: Path to a binary file read as a flat ``uint16`` array.
            max_seq_length: Number of tokens in every yielded sample.
        """
        super().__init__()
        self.data_file = data_file
        self.max_seq_length = max_seq_length

    def __iter__(self):
        """Yield ``(x, y)`` pairs forever; ``y`` is ``x`` shifted by one token."""
        # Memory-map inside `__iter__` so each iterator opens its own view.
        data = np.memmap(self.data_file, dtype=np.uint16, mode="r")
        while True:
            # The exclusive upper bound keeps the shifted target window in range.
            i = torch.randint(len(data) - self.max_seq_length, (1,)).item()
            x = torch.from_numpy((data[i : i + self.max_seq_length]).astype(np.int64))
            y = torch.from_numpy((data[i + 1 : i + 1 + self.max_seq_length]).astype(np.int64))
            yield x, y


# learning rate decay scheduler (cosine with linear warmup)
def get_lr(learning_rate: float, it: int, warmup_iters: int, max_iters: int, min_lr: float) -> float:
    """Return the learning rate for iteration ``it``.

    The schedule is a linear warmup from 0 to ``learning_rate`` over
    ``warmup_iters`` iterations, followed by a cosine decay that reaches
    ``min_lr`` at ``max_iters`` and stays there afterwards.

    Args:
        learning_rate: Peak learning rate reached at the end of warmup.
        it: Current iteration (0-based).
        warmup_iters: Number of warmup iterations.
        max_iters: Iteration at which the decay reaches ``min_lr``.
        min_lr: Final learning rate.

    Returns:
        The learning rate to use at iteration ``it``.
    """
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it >= max_iters, return min learning rate
    # `>=` rather than `>`: the cosine term is exactly 0 at `it == max_iters`
    # anyway, and this avoids a division by zero when
    # `max_iters == warmup_iters`.
    if it >= max_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)


def validate_args(io_args: IOArgs, train_args: TrainArgs, eval_args: EvalArgs) -> None:
    """Reject argument combinations this script does not handle.

    Args:
        io_args: Must have ``train_data_dir`` and ``val_data_dir`` set and
            ``checkpoint_dir`` unset.
        train_args: Must have ``epoch_size``, ``epochs`` and ``max_norm`` set
            and ``max_tokens`` unset.
        eval_args: Must have ``max_new_tokens`` unset.

    Raises:
        ValueError: On the first unsupported argument that is set, or the
            first required argument that is ``None``.
    """
    unsupported = [(io_args, ["checkpoint_dir"]), (train_args, ["max_tokens"]), (eval_args, ["max_new_tokens"])]
    for args, names in unsupported:
        for name in names:
            if getattr(args, name) is not None:
                raise ValueError(f"{__file__} doesn't support the {name!r} argument. This is set in {args}")
    required = [(io_args, ["train_data_dir", "val_data_dir"]), (train_args, ["epoch_size", "epochs", "max_norm"])]
    for args, names in required:
        for name in names:
            if getattr(args, name) is None:
                raise ValueError(f"{__file__} requires the {name!r} argument. This is set in {args}")


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    from jsonargparse import CLI

    CLI(setup)