step1_modeling/train.py (59 lines of code) (raw):

""" torchrun --nproc_per_node 1 train.py """ import os import datetime import torch import torch.nn.functional as F import torch.distributed as dist import argparse from torch.optim import AdamW from transformers import AutoConfig from model import Llama from utils import set_all_seed, print if __name__ == "__main__": parser = argparse.ArgumentParser(description="Training script for LLaMA model") # Environment arguments parser.add_argument("--omp_num_threads", type=str, default="1") parser.add_argument("--tokenizers_parallelism", type=str, default="false") # Model arguments parser.add_argument("--model_name", type=str, default="HuggingFaceTB/SmolLM-360M-Instruct") parser.add_argument("--num_hidden_layers", type=int, default=32) parser.add_argument("--num_attention_heads", type=int, default=16) parser.add_argument("--num_key_value_heads", type=int, default=4) # Training arguments parser.add_argument("--seed", type=int, default=42) parser.add_argument("--learning_rate", type=float, default=3e-4) parser.add_argument("--seq_len", type=int, default=32) parser.add_argument("--micro_batch_size", type=int, default=1) # Logging arguments parser.add_argument("--run_name", type=str, default="default_run") parser.add_argument("--use_wandb", action="store_true") args = parser.parse_args() # Set environment variables os.environ["OMP_NUM_THREADS"] = args.omp_num_threads os.environ["TOKENIZERS_PARALLELISM"] = args.tokenizers_parallelism os.environ["DEVICE"] = "cuda" local_rank = int(os.environ["LOCAL_RANK"]) global_rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) backend = "nccl" torch.cuda.set_device(local_rank) device = torch.device("cuda", local_rank) dtype = torch.bfloat16 dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=2)) set_all_seed(args.seed) model_config = AutoConfig.from_pretrained(args.model_name) model_config.num_hidden_layers = args.num_hidden_layers model_config.num_attention_heads = args.num_attention_heads model_config.num_key_value_heads = args.num_key_value_heads model_config.max_position_embeddings = args.seq_len model = Llama(config=model_config) model.to(dtype).to(device) model.train() dist.barrier() optimizer = AdamW(model.parameters(), lr=args.learning_rate) dist.barrier() # Create dummy data input_ids = torch.randint(0, model_config.vocab_size, (args.micro_batch_size, args.seq_len), device=device) target_ids = torch.randint(0, model_config.vocab_size, (args.micro_batch_size, args.seq_len), device=device) # Training step optimizer.zero_grad() # Forward pass outputs = model(input_ids=input_ids) # Compute loss target_ids = target_ids.reshape(-1) outputs = outputs.view(-1, model_config.vocab_size) loss = F.cross_entropy(outputs, target_ids) # Backward pass loss.backward() # Optimizer step optimizer.step() print(f"Loss: {loss.item():.4f}", is_print_rank=(global_rank == 0)) dist.destroy_process_group()