# weak_to_strong/train.py (262 lines of code) (raw):

import itertools
import os
import pickle
import time
from dataclasses import dataclass
from typing import Callable, Optional

import datasets
import numpy as np
import torch
import torch_optimizer as toptim
from transformers.modeling_utils import load_sharded_checkpoint

import weak_to_strong.logger as logger
from weak_to_strong.common import clear_mem
from weak_to_strong.eval import eval_model_acc
from weak_to_strong.loss import xent_loss
from weak_to_strong.model import TransformerWithHead


@dataclass
class ModelConfig:
    """Static description of a model that can be trained by this module."""

    # Identifier passed to ``TransformerWithHead.from_pretrained``.
    name: str
    # Not read in this file; presumably the default learning rate used by callers.
    default_lr: float
    # Not read in this file; presumably the default evaluation batch size used by callers.
    eval_batch_size: int
    # Extra keyword arguments forwarded to ``TransformerWithHead.from_pretrained``.
    custom_kwargs: Optional[dict] = None
    # Whether ``train_model`` should enable gradient checkpointing on the model.
    gradient_checkpointing: bool = False
    # Whether to load the model with ``device_map="auto"`` instead of DataParallel.
    model_parallel: bool = False
    # Not read in this file; presumably the default optimizer name used by callers.
    default_optimizer: str = "adam"


def _unwrap_for(model, attr: str):
    """Return whichever of ``model`` / ``model.module`` exposes ``attr``.

    A model wrapped in ``torch.nn.DataParallel`` does not forward arbitrary
    attributes, so methods such as ``save_pretrained`` live on ``model.module``.
    """
    return model if hasattr(model, attr) else model.module


def train_model(
    model: torch.nn.Module,
    ds: datasets.Dataset,
    batch_size: int,
    lr: float = 1e-5,
    loss_fn: Callable = xent_loss,
    log_every: int = 10,
    eval_every: int = 100,
    eval_batch_size: int = 256,
    minibatch_size: int = 8,
    eval_ds: Optional[datasets.Dataset] = None,
    gradient_checkpointing: bool = False,
    train_with_dropout: bool = False,
    epochs: int = 1,
    lr_schedule: str = "cosine_anneal",
    optimizer_name: str = "adam",
):
    """Finetune ``model`` on ``ds`` and return the final evaluation results.

    Each optimizer step runs ``batch_size // minibatch_size`` forward passes,
    stacks all logits and labels, and calls ``loss_fn`` once on the full batch.

    Args:
        model: Model to train, optionally wrapped in ``torch.nn.DataParallel``.
        ds: Training examples; each must provide ``"input_ids"`` and ``"soft_label"``.
        batch_size: Examples per optimizer step; must be a multiple of ``minibatch_size``.
        lr: Initial learning rate.
        loss_fn: Called as ``loss_fn(logits, labels, step_frac=...)``.
        log_every: Print a progress line every this many steps (falsy disables).
        eval_every: Evaluate on ``eval_ds`` every this many steps and once after
            training (falsy disables both).
        eval_batch_size: Batch size passed to ``eval_model_acc``.
        minibatch_size: Examples per forward pass.
        eval_ds: Evaluation dataset. When ``None``, evaluation is skipped.
        gradient_checkpointing: Enable gradient checkpointing on the model.
        train_with_dropout: Put the model in train mode; otherwise eval mode is
            used so that dropout is off.
        epochs: Number of passes over ``ds``.
        lr_schedule: ``"cosine_anneal"`` or ``"constant"``.
        optimizer_name: ``"adam"`` or ``"adafactor"`` (case-insensitive).

    Returns:
        The output of the final ``eval_model_acc`` call, or ``None`` when
        evaluation is disabled or no ``eval_ds`` was given.

    Raises:
        AssertionError: On an invalid batch size, optimizer or lr schedule.
    """
    print("LR", lr, "batch_size", batch_size, "minibatch_size", minibatch_size)
    # Explicit raises (rather than ``assert``) so the checks survive ``python -O``.
    if batch_size % minibatch_size != 0:
        raise AssertionError("batch size must be divisible by minibatch size")

    # we purposefully turn off dropout, for determinism
    # this seems to help for 1 epoch finetuning anyways
    if train_with_dropout:
        model.train()
    else:
        model.eval()
    if gradient_checkpointing:
        _unwrap_for(model, "gradient_checkpointing_enable").gradient_checkpointing_enable()

    nsteps = len(ds) * epochs // batch_size

    chosen_optimizer = optimizer_name.lower()
    if chosen_optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif chosen_optimizer == "adafactor":
        optimizer = toptim.Adafactor(model.parameters(), lr=lr)
    else:
        raise AssertionError(f"invalid optimizer {optimizer_name}, must be adam or adafactor")

    if lr_schedule == "cosine_anneal":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, nsteps)
    elif lr_schedule == "constant":
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _step: 1)
    else:
        raise AssertionError(
            f"invalid lr schedule, {lr_schedule}, must be constant or cosine_anneal"
        )

    example_iter = itertools.chain.from_iterable(itertools.repeat(ds, epochs))
    losses = []
    accuracies = []

    # Evaluation needs both a cadence and a dataset to evaluate on.
    do_eval = bool(eval_every) and eval_ds is not None

    # If the model is wrapped by DataParallel, it doesn't have a device. In this case,
    # we use GPU 0 as the output device. This sadly means that this device will store
    # a bit more data than other ones, but hopefully should not be too big of a deal.
    io_device = model.device if hasattr(model, "device") else 0

    for step in range(nsteps):
        if do_eval and (step + 1) % eval_every == 0:
            eval_results = eval_model_acc(model, eval_ds, eval_batch_size)
            # Restore the training configuration after evaluation
            # (presumably eval_model_acc may alter it -- not visible here).
            if gradient_checkpointing:
                _unwrap_for(
                    model, "gradient_checkpointing_enable"
                ).gradient_checkpointing_enable()
            if train_with_dropout:
                model.train()
            logger.logkv("eval_accuracy", np.mean([r["acc"] for r in eval_results]))

        all_logits = []
        all_labels = []
        for _ in range(batch_size // minibatch_size):
            try:
                mbatch = [next(example_iter) for _ in range(minibatch_size)]
            except StopIteration:
                break
            # pad_sequence yields (seq, batch); transpose to (batch, seq).
            input_ids = (
                torch.nn.utils.rnn.pad_sequence([torch.tensor(ex["input_ids"]) for ex in mbatch])
                .transpose(0, 1)
                .to(io_device)
            )
            labels = torch.tensor([ex["soft_label"] for ex in mbatch]).to(io_device)

            logits = model(input_ids)

            all_logits.extend(logits.to(io_device))
            all_labels.extend(labels)

        all_logits = torch.stack(all_logits)
        all_labels = torch.stack(all_labels)
        loss = loss_fn(all_logits, all_labels, step_frac=step / nsteps)
        loss_value = loss.item()
        loss.backward()

        losses.append(loss_value)
        is_correct = torch.argmax(all_logits, dim=1) == torch.argmax(all_labels, dim=1)
        accuracies.append(torch.mean(is_correct.to(torch.float32)).item())
        logger.logkvs(
            {
                "step": step,
                "progress": step / nsteps,
                "loss": loss_value,
                "train_accuracy": accuracies[-1],
                "lr": lr_scheduler.get_last_lr()[0],
            }
        )

        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

        if log_every and step % log_every == 0:
            print(
                f"Step: {step}/{nsteps} Recent losses: {np.mean(losses)} {np.mean(accuracies)} {len(losses)}"
            )
            losses = []
            accuracies = []
        logger.dumpkvs()

    final_eval_results = None
    if do_eval:
        print("Final evaluation:")
        final_eval_results = eval_model_acc(model, eval_ds, eval_batch_size)
        logger.logkv("eval_accuracy", np.mean([r["acc"] for r in final_eval_results]))
        logger.dumpkvs()
    return final_eval_results


def train_and_save_model(
    model_config: ModelConfig,
    train_ds: datasets.Dataset,
    test_ds: datasets.Dataset,
    inference_ds: Optional[datasets.Dataset] = None,
    *,
    batch_size: int,
    lr: float,
    epochs: int,
    eval_batch_size: Optional[int] = None,
    minibatch_size_per_device: Optional[int] = None,
    save_path: Optional[str] = None,
    loss_fn: Callable = xent_loss,
    label: str = "default",
    force_retrain: bool = False,
    train_with_dropout: bool = False,
    linear_probe: bool = False,
    lr_schedule: str = "constant",
    optimizer_name: str = "adam",
    eval_every: Optional[int] = None,
):
    """Load or train the model in ``model_config``, evaluate it and save results.

    If ``save_path`` already contains ``results.pkl`` and ``force_retrain`` is
    false, the saved weights are loaded and only evaluation is run. Otherwise
    the model is trained with ``train_model`` and, when ``save_path`` is set,
    saved with ``save_pretrained``.

    Args:
        model_config: Model name and loading options.
        train_ds: Training dataset passed to ``train_model``.
        test_ds: Dataset used for test evaluation.
        inference_ds: Optional extra dataset to evaluate after training.
        batch_size: Training batch size.
        lr: Learning rate.
        epochs: Number of training epochs.
        eval_batch_size: Evaluation batch size; defaults to ``batch_size``.
        minibatch_size_per_device: Per-device forward-pass size; defaults to 1.
        save_path: Directory for the checkpoint and ``results.pkl``; nothing is
            loaded or saved when it is unset.
        loss_fn: Loss passed to ``train_model``.
        label: Unused in this function.
        force_retrain: Ignore any existing checkpoint in ``save_path``.
        train_with_dropout: Forwarded to ``train_model``.
        linear_probe: Forwarded to ``TransformerWithHead.from_pretrained``.
        lr_schedule: Forwarded to ``train_model``.
        optimizer_name: Forwarded to ``train_model``.
        eval_every: Forwarded to ``train_model``.

    Returns:
        ``(test_results, inference_results)``; ``inference_results`` is ``None``
        when no ``inference_ds`` is given.
    """
    if eval_batch_size is None:
        eval_batch_size = batch_size
    if minibatch_size_per_device is None:
        minibatch_size_per_device = 1

    gradient_checkpointing = model_config.gradient_checkpointing
    custom_kwargs = model_config.custom_kwargs or {}

    def maybe_load_model(model):
        """Load saved weights into ``model`` in place; return True if loaded."""
        if not save_path or force_retrain:
            return False
        if not os.path.exists(os.path.join(save_path, "results.pkl")):
            return False
        print("loading from", save_path)
        checkpoint_path = os.path.join(save_path, "pytorch_model.bin")
        if not os.path.exists(checkpoint_path):
            # Assume this means we have a sharded checkpoint, and load it appropriately.
            # load_sharded_checkpoint takes the checkpoint *folder*, not a file path.
            load_sharded_checkpoint(model, save_path)
        else:
            # NOTE: torch.load unpickles the file; only point save_path at trusted checkpoints.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            state_dict = {
                k.replace("transformer.module", "transformer"): v
                for (k, v) in state_dict.items()
            }
            # The model is already constructed at this point, so the weights
            # must be loaded into it directly to take effect.
            model.load_state_dict(state_dict)
        return True

    # Load the model
    if model_config.model_parallel:
        if torch.cuda.device_count() <= 1:
            raise AssertionError(f"you might want more gpus for {model_config.name}")
        model = TransformerWithHead.from_pretrained(
            model_config.name,
            num_labels=2,
            device_map="auto",
            linear_probe=linear_probe,
            **custom_kwargs,
        )
        already_trained = maybe_load_model(model)
        # slight misnomer, more like minibatch_size_per_dp_replica
        minibatch_size = minibatch_size_per_device
    else:
        model = TransformerWithHead.from_pretrained(
            model_config.name, num_labels=2, linear_probe=linear_probe, **custom_kwargs
        ).to("cuda")
        already_trained = maybe_load_model(model)
        # data parallel: currently not supported with model parallel
        if torch.cuda.device_count() > 1:
            minibatch_size = min(
                minibatch_size_per_device * torch.cuda.device_count(), batch_size
            )
            model = torch.nn.DataParallel(model, output_device=0)
            print(
                "Using",
                torch.cuda.device_count(),
                "GPUs, setting minibatch_size to",
                minibatch_size,
            )
        else:
            minibatch_size = minibatch_size_per_device

    if already_trained:
        test_results = eval_model_acc(model, test_ds, eval_batch_size)
    else:
        start = time.time()
        test_results = train_model(
            model,
            train_ds,
            batch_size,
            lr=lr,
            epochs=epochs,
            eval_ds=test_ds,
            gradient_checkpointing=gradient_checkpointing,
            loss_fn=loss_fn,
            eval_batch_size=eval_batch_size,
            eval_every=eval_every,
            minibatch_size=minibatch_size,
            train_with_dropout=train_with_dropout,
            lr_schedule=lr_schedule,
            optimizer_name=optimizer_name,
        )
        print("Model training took", time.time() - start, "seconds")
        if save_path:
            # Note: If the model is wrapped by DataParallel, we need to unwrap it before saving
            _unwrap_for(model, "save_pretrained").save_pretrained(save_path)
            print("saved", save_path)
        if test_results is None:
            # train_model skips its final evaluation when eval_every is falsy;
            # evaluate here so the return value and results.pkl match the
            # already-trained path.
            test_results = eval_model_acc(model, test_ds, eval_batch_size)

    inference_results = None
    if inference_ds:
        inference_results = eval_model_acc(model, inference_ds, eval_batch_size)
        logger.logkv("inference_accuracy", np.mean([r["acc"] for r in inference_results]))

    if save_path:
        # NaN marks "no inference set"; computed directly to avoid numpy's
        # empty-slice RuntimeWarning.
        if inference_results:
            avg_acc_inference = float(np.mean([r["acc"] for r in inference_results]))
        else:
            avg_acc_inference = float("nan")
        with open(os.path.join(save_path, "results.pkl"), "wb") as f:
            pickle.dump(
                {
                    "avg_acc_test": float(np.mean([r["acc"] for r in test_results])),
                    "avg_acc_inference": avg_acc_inference,
                    "test_results": test_results,
                    "inference_results": inference_results if inference_results else [],
                },
                f,
            )

    # try to clean up memory
    clear_mem()
    logger.shutdown()

    return test_results, inference_results