vision/m4/training/trainer.py (1,513 lines of code) (raw):

import copy import gc import json import logging import os import pickle import socket import subprocess import time from collections import defaultdict from pathlib import Path from typing import Dict, List, Union import accelerate import psutil import torch import torch.optim as torch_optim import transformers.optimization as transformers_optim import wandb from packaging import version from m4.training.config import ( DataParams, GlobalBatchSizeRampUpRunningParams, Hparams, OptimizerParams, Parameters, ResumeParams, ) from m4.training.dataset import DatasetNames from m4.training.debug_utils import validate_optim_states_are_reset from m4.training.utils import ( # deepspeed_gathered_parameters_context_manager, IMAGE_TOKEN, JSONEncoderForDataclasses, LoggingTypes, SigtermListener, get_deepspeed_engine, get_stats, get_stats_format, is_deepspeed_used, is_deepspeed_zero3_used, is_deepspeed_zero_init_enabled, lora_unload, mem_usage_formatted, pynmvl_handle, pynvml_get_total_energy_in_joules, ) from m4.utils.activation_tracker import ActivationTracker from m4.utils.debug import printflock as print from m4.utils.progress import BarColumn, MofNCompleteColumn, Progress, TaskProgressColumn, TimeElapsedColumn from m4.utils.training.timer import DeviceAgnosticTimer, Timer, format_secs_to_sec_fractions, format_secs_to_time logger = logging.getLogger(__name__) fqdn = socket.getfqdn() if "compute.internal" in fqdn: # hfc: 1.1TB RAM _MEMORY_EXPLOSION_THRESHOLD = 93.0 elif "idris.fr" in fqdn or "idrsrv" in fqdn: # jz: 0.5TB RAM _MEMORY_EXPLOSION_THRESHOLD = 90.0 else: _MEMORY_EXPLOSION_THRESHOLD = 90.0 METRICS_TO_DEFAULT_VALUE_FN = { "lr": lambda: None, "num_opt_steps": lambda: 0, "num_epochs": lambda: 0, "per_token_loss": lambda: defaultdict(lambda: None), "z_loss": lambda: defaultdict(lambda: None), "watt/s": lambda: defaultdict(lambda: None), "tflops": lambda: defaultdict(lambda: None), "tflop_counter": lambda: defaultdict(lambda: None), "fwd_bwd_time": lambda: defaultdict(lambda: None), "tflops_acc": lambda: defaultdict(lambda: None), "num_per_device_batches": lambda: defaultdict(lambda: None), "num_images": lambda: defaultdict(lambda: None), "num_image_tokens": lambda: defaultdict(lambda: None), "num_tokens": lambda: defaultdict(lambda: None), "image_to_text_ratio": lambda: defaultdict(lambda: None), "pixel_values_sum": lambda: defaultdict(lambda: None), "num_padding": lambda: defaultdict(lambda: None), "num_per_device_batches_in_curr_epoch": lambda: defaultdict(lambda: None), "num_batches": lambda: defaultdict(int), "num_batches_in_curr_epoch": lambda: defaultdict(int), "per_token_loss_acc": lambda: defaultdict(lambda: None), "z_loss_acc": lambda: defaultdict(lambda: None), "num_batches_since_training_logged": lambda: defaultdict(int), "num_per_device_batches_since_training_logged": lambda: defaultdict(lambda: None), "tflop_counter_since_training_logged": lambda: defaultdict(lambda: None), "total_energy_delta_since_training_logged": lambda: defaultdict(lambda: None), "fwd_bwd_time_since_training_logged": lambda: defaultdict(lambda: None), } METRICS_TO_RESET_AFTER_LOGGING = [ "per_token_loss_acc", "z_loss_acc", "num_batches_since_training_logged", "num_per_device_batches_since_training_logged", "tflop_counter_since_training_logged", "total_energy_delta_since_training_logged", "fwd_bwd_time_since_training_logged", ] class Trainer(object): """ Trainer object to monitor training and validation :- config -: json config file """ def __init__(self, accelerator, vl_model, tokenizer, train_loader, 
val_loader, config): # Initialize params self.config: Parameters = config self.optim_param: OptimizerParams = config.optim_param self.hparams: Hparams = config.hparams self.resume_param: ResumeParams = config.resume_param self.data_param: DataParams = config.data_param # Initialize last step directory self.last_opt_step_dir = "" # Initialize the model self.vl_model = vl_model # Gradient checkpointing if self.hparams.gradient_checkpointing: self.vl_model.gradient_checkpointing_enable() # Debug if accelerator.is_main_process and self.hparams.train_logging_activations: self.activation_tracker = ActivationTracker(self.vl_model) else: self.activation_tracker = None # Initialize tokenizer self.tokenizer = tokenizer self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) # Initialize accelerator self.accelerator = accelerator # Initialize loaders self.train_loader = train_loader self.val_loader = val_loader # Checks self._compatibility_checks() # Initialize everything related to distributed training self._configure_optimizer_and_scheduler() # Prepare and/or register model, optimizer, dataloaders and scheduler self._prepare_register() # now that we have num_processes, figure out batch_size-related variables self.setup_batch_size_related_configs() # Compute useful variables self.optim_param.opt_batch_size = self.hparams.global_batch_size if self.hparams.max_num_opt_steps is None and self.hparams.max_num_epochs is None: if hasattr(self.train_loader, "__len__") and self.hparams.global_batch_size_ramp_up.start is not None: raise ValueError("Currently global batch size ramp up doesn't work with MappedDataset") try: self.hparams.max_num_opt_steps = int(len(self.train_loader) // self.hparams.grad_acc_size) except TypeError: raise ValueError("max_num_opt_steps or max_num_epochs must be defined if you use IterableDataset") # self._set_model_tflops_per_batch_per_gpu() # Init trackers self._init_trackers() # Handle jz timing and memory self.jz_training_time_over = [False] self.memory_explosion = False # Stopping on demand self.kill_switch_activated = False # Sigterm signal listener self.sigterm_signal_received = False self.sigterm_listener = SigtermListener() sizes = defaultdict(int) trainable_params = [] numel_fn = lambda p: p.ds_numel if is_deepspeed_zero_init_enabled() else p.numel() # noqa for name, param in self.accelerator.unwrap_model(self.vl_model).named_parameters(): numel = numel_fn(param) sizes["total"] += numel sizes["total_lora"] += numel if "lora_" in name else 0 if "vision_model" in name: sizes["vision_model"] += numel sizes["vision_model_lora"] += numel if "lora_" in name else 0 if "perceiver_resampler" in name: sizes["perceiver_resampler"] += numel if "modality_projection" in name: sizes["modality_projection"] += numel if param.requires_grad: sizes["trainable"] += numel sizes["trainable_lora"] += numel if "lora_" in name else 0 trainable_params += [name] if self.accelerator.is_main_process: logger.info(f""" ------------------------------------- Model: - Total size: {sizes["total"]} ---- Lora size: {sizes["total_lora"]} - Vision encoder size: {sizes["vision_model"]} ---- Lora size: {sizes["vision_model_lora"]} - Perceiver resampler size: {sizes["perceiver_resampler"]} - Modality projection: {sizes["modality_projection"]} - Number of trainable parameters: {sizes["trainable"]} ---- Lora trainable parameters: {sizes["trainable_lora"]} Maximum number of optimizer steps: {self.hparams.max_num_opt_steps} Maximum number of epochs: {self.hparams.max_num_epochs if 
self.hparams.max_num_epochs else "N/A"} Number of gradient accumulation steps: {self.hparams.grad_acc_size} Number of processes: {self.accelerator.num_processes} Batch sizes: - Per device batch size: {self.hparams.batch_size_per_gpu} - Optimizer batch size: {self.optim_param.opt_batch_size} ------------------------------------- """) logger.info("Trainable/non-trainable parameters:") for name, param in vl_model.named_parameters(): logger.info(f" Name: {name} | Trainable: {param.requires_grad}") if len(self.hparams.train_logging_grad_param_deepspeed) > 0: from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state self.safe_get_full_fp32_param = safe_get_full_fp32_param self.safe_get_full_grad = safe_get_full_grad self.safe_get_full_optimizer_state = safe_get_full_optimizer_state self.float_placeholder_tensor = torch.tensor(-1.0, device=self.accelerator.device, dtype=torch.float) self.long_placeholder_tensor = torch.tensor(-1, device=self.accelerator.device, dtype=torch.long) def setup_batch_size_related_configs(self): """ batch_size-related configs are processed here. All this work is done here because it requires knowing the value of num_processes """ hparams = self.hparams if hparams.global_batch_size_ramp_up.start is not None: # case 1. global batch size ramp up # 1a. ramp up constraints if ( hparams.global_batch_size_ramp_up.finish is None or hparams.global_batch_size_ramp_up.increment is None or hparams.global_batch_size_ramp_up.samples is None ): raise ValueError( "When using batch size ramp up hparam config entries global_batch_size_ramp_up.start," " global_batch_size_ramp_up.finish, global_batch_size_ramp_up.increment and" " global_batch_size_ramp_up.samples have to be defined." ) # range checks ramp_up_range = hparams.global_batch_size_ramp_up.finish - hparams.global_batch_size_ramp_up.start if ramp_up_range < hparams.global_batch_size_ramp_up.increment: raise ValueError( f"{hparams.global_batch_size_ramp_up.start=} has to be smaller than" f" {hparams.global_batch_size_ramp_up.finish=}." ) if not (ramp_up_range / hparams.global_batch_size_ramp_up.increment).is_integer(): raise ValueError( f"({hparams.global_batch_size_ramp_up.finish=} -" f" {hparams.global_batch_size_ramp_up.start=}) /" f" {hparams.global_batch_size_ramp_up.increment=} has to be a whole number" ) if not ( hparams.global_batch_size_ramp_up.increment / (hparams.batch_size_per_gpu * self.accelerator.num_processes) ).is_integer(): raise ValueError( f"{hparams.global_batch_size_ramp_up.increment=} has to be a multiple of" f" {hparams.batch_size_per_gpu * self.accelerator.num_processes=}" ) if self.accelerator.is_main_process: logger.info( "Will ramp up global batch size from" f" {hparams.global_batch_size_ramp_up.start} to" f" {hparams.global_batch_size_ramp_up.finish}, in increments of" f" {hparams.global_batch_size_ramp_up.increment} every" f" {hparams.global_batch_size_ramp_up.samples} samples." ) # 1b. 
hparam.grad_acc_size constraints and derivations if hparams.grad_acc_size > 1: raise ValueError("When using batch size ramp up hparam.grad_acc_size must be None or 1.") hparams.grad_acc_size = hparams.global_batch_size_ramp_up.start / ( hparams.batch_size_per_gpu * self.accelerator.num_processes ) if not hparams.grad_acc_size.is_integer(): raise ValueError( f"{hparams.global_batch_size_ramp_up.start=} has to be a multiple of" f" {hparams.batch_size_per_gpu * self.accelerator.num_processes}" ) hparams.grad_acc_size = int(hparams.grad_acc_size) logger.info(f"Derived {hparams.grad_acc_size=}") # 1c. hparams.global_batch_size constraints and derivation if hparams.global_batch_size is not None: raise ValueError("When using batch size ramp up hparam.global_batch_size must be None.") # in the first run we start at global_batch_size == global_batch_size_ramp_up.start hparams.global_batch_size = hparams.global_batch_size_ramp_up.start else: # case 2. fixed global batch size or fixed grad_acc_size # 2a. constraints if hparams.grad_acc_size > 1: # when global_batch_size is used grad_acc_size will be derived automatically from global_batch_size and n_gpus if hparams.global_batch_size is not None and hparams.global_batch_size > 1: raise ValueError("set either hparams.grad_acc_size>1 or hparams.global_batch_size>1, but not both") if hparams.global_batch_size is not None and hparams.global_batch_size > 1: # 2b. have global_batch_size need to derive grad_acc_size hparams.grad_acc_size = hparams.global_batch_size / ( hparams.batch_size_per_gpu * self.accelerator.num_processes ) if not hparams.grad_acc_size.is_integer(): raise ValueError( f"The derived {hparams.grad_acc_size=} is not an integer," f" {hparams.global_batch_size=} / ({hparams.batch_size_per_gpu=} *" f" {self.accelerator.num_processes=})" ) hparams.grad_acc_size = int(hparams.grad_acc_size) logger.info(f"Derived {hparams.grad_acc_size=}") else: # 2c. have grad_acc_size need to derive global_batch_size hparams.global_batch_size = ( hparams.batch_size_per_gpu * hparams.grad_acc_size * self.accelerator.num_processes ) logger.info(f"Derived {hparams.global_batch_size=}") def update_gas_and_gbs(self, grad_acc_size_current, global_batch_size_current): """ Update m4, deepspeed and accelerate with the derived global_batch_size and grad_acc_size """ self.hparams.grad_acc_size = grad_acc_size_current self.hparams.global_batch_size = global_batch_size_current self.accelerator.gradient_accumulation_steps = grad_acc_size_current if is_deepspeed_used(): get_deepspeed_engine(self.accelerator).set_train_batch_size(global_batch_size_current) def _compatibility_checks(self): # BF16 requires cuda>=11 and nccl>=2.10.3 if self.accelerator.mixed_precision == "bf16": if version.parse(torch.version.cuda) < version.parse("11.0"): raise ValueError(f"mixed precision dtype BF16 requires cuda>=11, but got {torch.version.cuda}") if torch.cuda.nccl.version() < (2, 10, 3): raise ValueError( f"mixed precision dtype BF16 requires NCCL>=2.10.3, but got {'.'.join(torch.cuda.nccl.version())}" ) def _init_trackers(self): if self.hparams.wandb_enable and self.accelerator.is_main_process: # Initialize wandb_run_id if self.hparams.resume_run and self.hparams.wandb_run_id != "": pass else: if self.hparams.resume_run and not self.hparams.wandb_run_id != "": logger.warning( "** The run you are resuming was not logging into wandb. 
Therefore a new wandb_run_id has" " been generated **" ) self.hparams.wandb_run_id = wandb.util.generate_id() logger.info(f"** `wandb_run_id`: {self.hparams.wandb_run_id} **") # Initialize all trackers run_name = self.hparams.save_dir.name wdb_config = {} for k, v in vars(self.config).items(): if not hasattr(v, "__dict__"): wdb_config[k] = v continue for key, value in vars(v).items(): wdb_config[f"{k}-{key}"] = str(value) wandb_logger = wandb.init( config=wdb_config, id=self.hparams.wandb_run_id, resume="allow", project=self.hparams.wandb_project, entity=self.hparams.wandb_entity, name=run_name, allow_val_change=True, tags=self.hparams.wandb_tags, ) self.dummy_module = torch.nn.LayerNorm(1) wandb.watch( self.dummy_module, log="all", log_freq=self.hparams.wandb_log_freq * self.hparams.grad_acc_size, idx=0, # To avoid creating a new panel each we un-watch and then re-watch ) tb_run_name = "tb_run_" + self.hparams.wandb_run_id tensorboard_tracker = accelerate.tracking.TensorBoardTracker( run_name=tb_run_name, logging_dir=self.hparams.save_dir ) logger.info(f"** TensorBoardTracker logging into { self.hparams.save_dir / tb_run_name } **") self.accelerator.trackers = [tensorboard_tracker, wandb_logger] # Alert WB in replacement of slurm notifications and emails wandb.alert( title="Training has either started or resumed", text=( f"Run name = {run_name}, Jobid = {self.hparams.job_id}, Resuming = {self.hparams.resume_run}," f" Experiment folder = {str(self.hparams.save_dir)}" ), ) def _configure_optimizer_and_scheduler(self): """defines model optimizer and lr scheduler""" vl_optim = getattr(torch_optim, self.optim_param.vl_optim) if issubclass(vl_optim, torch_optim.AdamW): no_decay = self.optim_param.vl_optim_params.pop("no_decay", []) weight_decay = self.optim_param.vl_optim_params.pop("weight_decay", 0.0) optim_grouped_params = [ # Group: tunable parameters with weight decay { "params": [ p for n, p in self.vl_model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad ], "weight_decay": weight_decay, }, # Group: tunable parameters without weight decay at all { "params": [ p for n, p in self.vl_model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad ], "weight_decay": 0.0, }, ] vl_optim = vl_optim( optim_grouped_params, **self.optim_param.vl_optim_params, ) else: vl_optim = vl_optim( self.vl_model.parameters(), **self.optim_param.vl_optim_params, ) try: vl_scheduler_class = getattr(torch_optim.lr_scheduler, self.optim_param.vl_lr_scheduler) except AttributeError: vl_scheduler_class = getattr(transformers_optim, self.optim_param.vl_lr_scheduler) else: raise ValueError( f"Could not find {self.optim_param.vl_lr_scheduler} type of LR Scheduler in neither `torch.optim` nor" " `transformers.optimization`" ) vl_scheduler = vl_scheduler_class( optimizer=vl_optim, **self.optim_param.vl_lr_scheduler_params, ) self.vl_optim = vl_optim self.vl_scheduler = vl_scheduler def _prepare_register(self): """ Prepare model, optimizer and dataloader if necessary. Register the scheduler for checkpointing. """ if isinstance(self.train_loader.dataset, torch.utils.data.IterableDataset): # `dummy_dataloader`: trick as suggested here: https://discuss.huggingface.co/t/when-using-deepspeed-why-do-i-need-to-pass-dataloaders-to-the-accelerator-prepare/22432/2?u=victorsanh # In our IterableDataset, *WE* handle dispatch (instead of `Accelerate`) for each process ourselves as we need # better shuffling support. # =>> See `DATA_PROCESSING.md` for more information! 
dummy_dataloader = torch.utils.data.DataLoader( [0 for _ in range(20)], batch_size=self.hparams.batch_size_per_gpu ) # important: do note add lr scheduler in `prepare`. We are doing something non canonical # with our custom data loaders that leads to not being able to use the standard # arguments in the accelerator (typically split_batches which would have ensure the LR # was increased every x steps, where x>1 and would be the correct value for grad acc). # context: https://github.com/huggingface/m4/pull/386 self.vl_model, self.vl_optim, dummy_dataloader = self.accelerator.prepare( self.vl_model, self.vl_optim, dummy_dataloader ) else: ( self.vl_model, self.vl_optim, self.train_loader, self.val_loader, ) = self.accelerator.prepare(self.vl_model, self.vl_optim, self.train_loader, self.val_loader) self.accelerator.register_for_checkpointing(self.vl_scheduler) def _set_up_training(self): """ Prepare variables for trainings. """ if self.hparams.resume_run: # 1. resume train_logs = self.resume_param.train_logs curr_opt_step = self.resume_param.resume_opt_step curr_epoch = self.resume_param.resume_epoch gbs_running = self.resume_param.gbs_running # This check is necessary because the info is saved as json in the checkpoint # and when it is loaded back it is converted to a normal dictionary which can # fail downstream in case one of the dataset keys were missing in the saved info train_logs = self._check_default_dict_in_train_logs(train_logs) self.train_loader.load_state(self.resume_param.opt_step_dir / "resumable_states") if self.hparams.load_optimizer_states: self.accelerator.load_state(self.resume_param.accelerator_state_dir) else: # don't load the optimizer states and start with a fresh optimizer self.accelerator.load_state(self.resume_param.accelerator_state_dir, load_optimizer_states=False) validate_optim_states_are_reset(self) self.accelerator.wait_for_everyone() opt_step_is_saved = True eval_is_done = True else: # 2. non-resume (first run) train_logs = self._reset_train_logs(None) curr_opt_step = 0 curr_epoch = 0 opt_step_is_saved = False eval_is_done = False gbs_running = GlobalBatchSizeRampUpRunningParams( global_seen_samples=0, global_batch_size_current=self.hparams.global_batch_size, next_goal_samples=self.hparams.global_batch_size_ramp_up.samples, grad_acc_size_current=self.hparams.grad_acc_size, ) self.train_loader.reset_state() # rng = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(self.hparams.seed))) # self.main_rng_seed = rng.get_state() # self.main_rng_seed = rng.RandomState.get_state() self.update_gas_and_gbs(gbs_running.grad_acc_size_current, gbs_running.global_batch_size_current) max_num_epochs = self.hparams.max_num_epochs try: num_batches = int(len(self.train_loader) // self.hparams.grad_acc_size) max_num_updates = min(self.hparams.max_num_opt_steps, num_batches) if max_num_epochs is not None: logger.info( "** Setting `max_num_updates` to `max_num_epochs * num_batches` since `max_num_epochs` " "was specified and `max_num_epochs * num_batches` is smaller than `max_num_updates`. 
**" ) max_num_updates = min(max_num_updates, max_num_epochs * num_batches) except TypeError: # For iterable datasets len(dataset) is not defined max_num_updates = self.hparams.max_num_opt_steps if self.hparams.max_num_opt_steps_this_run is not None: self.max_num_updates_this_run = min( max_num_updates, curr_opt_step + self.hparams.max_num_opt_steps_this_run ) else: self.max_num_updates_this_run = max_num_updates progress_columns = ( "[progress.description]{task.description}", BarColumn(), TaskProgressColumn(), "Time Elapsed:", TimeElapsedColumn(), "Steps Completed", MofNCompleteColumn(), ) return ( progress_columns, train_logs, max_num_epochs, max_num_updates, curr_opt_step, curr_epoch, opt_step_is_saved, eval_is_done, gbs_running, ) def _do_batch(self, batch, curr_opt_step, dataset_name=None, dataset_idx=None, validation=False): # Use effective max_num_images per batch. ie: if the max_num_images of this batch is 3, the pixel_values and image mask are truncated accordingly. # Same for max_height and max_width effective_max_num_images = max(batch["num_images"]) if effective_max_num_images > 0: images_heights = batch["pixel_attention_mask"][:, :, :, 0].sum(dim=-1) images_widths = batch["pixel_attention_mask"][:, :, 0].sum(dim=-1) effective_max_height = images_heights.max().int() effective_max_width = images_widths.max().int() batch["pixel_values"] = batch["pixel_values"][ :, :effective_max_num_images, :, :effective_max_height, :effective_max_width ] batch["pixel_attention_mask"] = batch["pixel_attention_mask"][ :, :effective_max_num_images, :effective_max_height, :effective_max_width ] else: # This case is a security check: if there are no images, then it should not appear in `batch` in the first place batch.pop("pixel_values", None) batch.pop("pixel_attention_mask", None) effective_max_num_tokens = max(batch["attention_mask"].sum(dim=-1)) batch["input_ids"] = batch["input_ids"][:, :effective_max_num_tokens] if "labels" in batch: batch["labels"] = batch["labels"][:, :effective_max_num_tokens] batch["attention_mask"] = batch["attention_mask"][:, :effective_max_num_tokens] batch = accelerate.utils.operations.send_to_device(batch, self.accelerator.device) num_images = batch["num_images"].sum() num_image_tokens = (batch["input_ids"] == self.image_token_id).sum() num_text_tokens = batch["num_text_tokens"].sum() total_tokens = batch["attention_mask"].numel() num_padding = total_tokens - batch["attention_mask"].sum() if "pixel_values" in batch: pixel_values_sum = batch["pixel_values"].sum() else: pixel_values_sum = torch.tensor(0.0, device=self.accelerator.device) image_to_text_ratio = torch.div(num_images, num_text_tokens) vl_output = self.vl_model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], pixel_values=batch["pixel_values"] if "pixel_values" in batch else None, pixel_attention_mask=batch["pixel_attention_mask"] if "pixel_attention_mask" in batch else None, labels=batch["labels"] if "labels" in batch else batch["input_ids"], ) per_token_loss = vl_output.loss if validation: return ( per_token_loss, num_images, num_text_tokens, image_to_text_ratio, num_padding, pixel_values_sum, ) else: if self.hparams.loss_weights_per_dataset is not None: per_token_loss *= self.hparams.loss_weights_per_dataset[dataset_idx] if self.optim_param.z_loss > 0.0: logits = vl_output.logits attention_mask = batch["attention_mask"] * (1 - (batch["input_ids"] == self.image_token_id).long()) log_z = torch.logsumexp(logits, dim=-1) * attention_mask z_loss = log_z**2 z_loss = z_loss.sum() / 
attention_mask.sum() combined_loss = per_token_loss + self.optim_param.z_loss * z_loss else: z_loss = torch.tensor(0.0, device=self.accelerator.device) combined_loss = per_token_loss sync_gradients = self.accelerator.sync_gradients deepspeed = hasattr(self.accelerator, "deepspeed_engine_wrapped") # accelerate's deepspeed `backward` calls `engine.step`, which is a problem if we want # to investigate things before step, so override with just a backward call and then call # `engine.step` along with `optim.step` a bit lower if deepspeed: self.accelerator.deepspeed_engine_wrapped.engine.backward(combined_loss) else: self.accelerator.backward(combined_loss) if sync_gradients: self.accelerator.clip_grad_norm_(self.vl_model.parameters(), self.hparams.grad_clip) if ( len(self.hparams.train_logging_grad_param_deepspeed) != 0 and (curr_opt_step + 1) % self.hparams.train_logging_grad_param_deepspeed_opt_steps == 0 ): self._log_deepspeed_training_stats(curr_opt_step=curr_opt_step) if deepspeed: self.accelerator.deepspeed_engine_wrapped.engine.step() self.vl_optim.step() # 1. sync_gradients is used for this dirty trick: https://github.com/huggingface/m4/pull/386 # 2. since we don't accelerate prepare the lr scheduler we need to manually skip it if # optimizer skipped (otherwise accelerate would do that for us) if sync_gradients and not self.accelerator.optimizer_step_was_skipped: self.vl_scheduler.step() self.vl_optim.zero_grad(set_to_none=True) tflops_per_batch_per_gpu = self.vl_model.get_model_tflops_per_batch_per_gpu( hparams=self.hparams, data_param=getattr(self.data_param, dataset_name), tokenizer=self.tokenizer, max_num_images=effective_max_num_images, max_num_tokens=effective_max_num_tokens, ).to(self.accelerator.device) # Reset batch return ( per_token_loss, z_loss, num_images, num_image_tokens, num_text_tokens, image_to_text_ratio, num_padding, pixel_values_sum, tflops_per_batch_per_gpu, ) def _log_deepspeed_training_stats(self, curr_opt_step): if self.hparams.job_id is not None: log_stats_file = self.hparams.save_dir / "logs" / f"{self.hparams.job_id}_logs_params_grads_stats.jsonl" else: log_stats_file = self.hparams.save_dir / "logs" / "logs_params_grads_stats.jsonl" beta1 = self.optim_param.vl_optim_params["betas"][0] beta2 = self.optim_param.vl_optim_params["betas"][1] eps = self.optim_param.vl_optim_params.get("eps", 1e-8) step = self.vl_optim.optimizer.state[list(self.vl_optim.optimizer.state.keys())[0]]["step"] bias_correction_1 = 1 / (1 - beta1**step) bias_correction_2 = 1 / (1 - beta2**step) for n, lp in self.vl_model.named_parameters(): self.accelerator.wait_for_everyone() if not lp.requires_grad: continue hp = self.safe_get_full_fp32_param(lp) exp_avg = self.safe_get_full_optimizer_state(lp, "exp_avg") exp_avg_sq = self.safe_get_full_optimizer_state(lp, "exp_avg_sq") hp_grad = self.safe_get_full_grad(lp) if not self.accelerator.is_main_process: continue if exp_avg_sq is not None and exp_avg is not None: effective_update = exp_avg * bias_correction_1 / (torch.sqrt(exp_avg_sq * bias_correction_2) + eps) else: effective_update = None grad_param_logs = { "name": n, "step": step.item(), **get_stats(hp, "hp"), **get_stats(exp_avg, "exp_avg"), **get_stats(exp_avg_sq, "exp_avg_sq"), **get_stats(hp_grad, "hp_grad"), **get_stats(effective_update, "effective_update"), } grad_param_format = { "step": "", "name": "", **get_stats_format("hp"), **get_stats_format("exp_avg"), **get_stats_format("exp_avg_sq"), **get_stats_format("hp_grad"), **get_stats_format("effective_update"), } if 
LoggingTypes.JSONL in self.hparams.train_logging_grad_param_deepspeed: with open(log_stats_file, "a") as f: f.write(json.dumps(grad_param_logs) + "\n") if LoggingTypes.WANDB in self.hparams.train_logging_grad_param_deepspeed and self.hparams.wandb_enable: self.accelerator.log({**grad_param_logs, **self._get_additional_step_logs()}, step=curr_opt_step + 1) if LoggingTypes.PRINT in self.hparams.train_logging_grad_param_deepspeed: log = "Grads and params stats: " log += self.format_print_logs(grad_param_logs, grad_param_format) print(log) def gather_metrics( self, local_metric_list: List[Dict[str, torch.Tensor]], placeholder_tensor: torch.Tensor, reduce_op_list, ds_name_suffix: str, ) -> List[Dict[str, torch.Tensor]]: """ Collating all metrics to gather into ONE call to `torch.distributed.all_gather` instead of doing one per metric x dataset_name. """ if self.accelerator.num_processes == 1: for local_metric in local_metric_list: for ds_name, tensor in local_metric.items(): if tensor is not None: local_metric[ds_name] = tensor.item() return local_metric_list dataset_names = sorted([f"{e.value}{ds_name_suffix}" for e in DatasetNames] + ["all"]) for local_metric in local_metric_list: for ds_name in dataset_names: if local_metric[ds_name] is None: local_metric[ds_name] = placeholder_tensor.clone() collated_local_metrics = torch.stack( [torch.stack([local_metric[ds_name] for ds_name in dataset_names]) for local_metric in local_metric_list] ) # Size len(local_metric_list) x len(dataset_names) broadcasted_placeholder_tensor = torch.stack( [torch.stack([placeholder_tensor for _ in dataset_names]) for _ in local_metric_list] ) # Size len(local_metric_list) x len(dataset_names) output_objects = [broadcasted_placeholder_tensor.clone() for _ in range(self.accelerator.num_processes)] torch.distributed.all_gather(output_objects, collated_local_metrics) gathered_metrics = torch.stack( output_objects ) # Size num_processes x len(local_metric_list) x len(dataset_names) gathered_metric_list = [] for metric_idx, (_, reduce_op) in enumerate(zip(local_metric_list, reduce_op_list)): result = {} for ds_idx, ds_name in enumerate(dataset_names): metrics = gathered_metrics[:, metric_idx, ds_idx] metrics = metrics[metrics != placeholder_tensor] if metrics.numel() == 0: result[ds_name] = None else: result[ds_name] = reduce_op(metrics).item() gathered_metric_list.append(result) return gathered_metric_list def _update_logs( self, curr_opt_step, curr_epoch, global_batch_size_current, train_logs, per_token_loss, z_loss, num_tokens, num_images, num_image_tokens, image_to_text_ratio, num_padding, fwd_bwd_time, pixel_values, tflops_per_batch_per_gpu, total_energy_delta_per_gpu, dataset_name, ds_name_suffix="", ): def init_dict(): return {f"{e.value}{ds_name_suffix}": None for e in DatasetNames} local_per_token_loss = init_dict() local_z_loss = init_dict() local_num_tokens = init_dict() local_num_images = init_dict() local_num_image_tokens = init_dict() local_image_to_text_ratio = init_dict() local_fwd_bwd_time = init_dict() local_pixel_values = init_dict() local_tflops_per_batch_per_gpu = init_dict() local_total_energy_delta_per_gpu = init_dict() local_num_padding = init_dict() local_num_batches = init_dict() for key_name in [f"{dataset_name}{ds_name_suffix}", "all"]: local_per_token_loss[key_name] = per_token_loss local_z_loss[key_name] = z_loss local_num_tokens[key_name] = num_tokens local_num_images[key_name] = num_images local_num_image_tokens[key_name] = num_image_tokens local_image_to_text_ratio[key_name] = 
image_to_text_ratio local_fwd_bwd_time[key_name] = fwd_bwd_time local_pixel_values[key_name] = pixel_values local_tflops_per_batch_per_gpu[key_name] = tflops_per_batch_per_gpu local_total_energy_delta_per_gpu[key_name] = total_energy_delta_per_gpu local_num_padding[key_name] = num_padding local_num_batches[key_name] = torch.tensor(1.0, device=self.accelerator.device, dtype=torch.long) [ gathered_per_token_loss, gathered_z_loss, gathered_image_to_text_ratio, gathered_fwd_bwd_time, gathered_pixel_values, gathered_tflops_per_batch_per_gpu, gathered_total_energy_delta_per_gpu, ] = self.gather_metrics( local_metric_list=[ local_per_token_loss, local_z_loss, local_image_to_text_ratio, local_fwd_bwd_time, local_pixel_values, local_tflops_per_batch_per_gpu, local_total_energy_delta_per_gpu, ], reduce_op_list=[torch.sum, torch.sum, torch.mean, torch.sum, torch.sum, torch.sum, torch.sum], placeholder_tensor=self.float_placeholder_tensor, ds_name_suffix=ds_name_suffix, ) [ gathered_num_padding, gathered_num_tokens, gathered_num_batches, gathered_num_images, gathered_num_image_tokens, ] = self.gather_metrics( local_metric_list=[ local_num_padding, local_num_tokens, local_num_batches, local_num_images, local_num_image_tokens, ], reduce_op_list=[torch.sum, torch.sum, torch.sum, torch.sum, torch.sum], placeholder_tensor=self.long_placeholder_tensor, ds_name_suffix=ds_name_suffix, ) for ds_name in local_per_token_loss.keys(): for metric_name, new_value in [ ("per_token_loss_acc", gathered_per_token_loss[ds_name]), ("z_loss_acc", gathered_z_loss[ds_name]), ("num_images", gathered_num_images[ds_name]), ("num_image_tokens", gathered_num_image_tokens[ds_name]), ("num_tokens", gathered_num_tokens[ds_name]), ("num_padding", gathered_num_padding[ds_name]), ("pixel_values_sum", gathered_pixel_values[ds_name]), ("tflop_counter_since_training_logged", gathered_tflops_per_batch_per_gpu[ds_name]), ("fwd_bwd_time_since_training_logged", gathered_fwd_bwd_time[ds_name]), ("total_energy_delta_since_training_logged", gathered_total_energy_delta_per_gpu[ds_name]), ("fwd_bwd_time", gathered_fwd_bwd_time[ds_name]), ("tflop_counter", gathered_tflops_per_batch_per_gpu[ds_name]), ("num_per_device_batches_since_training_logged", gathered_num_batches[ds_name]), ("num_per_device_batches", gathered_num_batches[ds_name]), ("num_per_device_batches_in_curr_epoch", gathered_num_batches[ds_name]), ]: if new_value is None: continue if train_logs[metric_name][ds_name] is None: train_logs[metric_name][ds_name] = new_value else: train_logs[metric_name][ds_name] += new_value if gathered_image_to_text_ratio[ds_name] is not None: train_logs["image_to_text_ratio"][ds_name] = gathered_image_to_text_ratio[ds_name] if gathered_fwd_bwd_time[ds_name] is not None: train_logs["tflops_acc"][ds_name] = ( train_logs["tflop_counter"][ds_name] / train_logs["fwd_bwd_time"][ds_name] ) train_logs["num_batches_since_training_logged"]["all"] += 1 train_logs["num_batches"]["all"] += 1 train_logs["num_batches_in_curr_epoch"]["all"] += 1 train_logs["lr"] = self.vl_scheduler.get_last_lr()[0] train_logs["num_opt_steps"] = curr_opt_step train_logs["num_epochs"] = curr_epoch train_logs["global_batch_size_current"] = global_batch_size_current return train_logs def _update_datasets_states(self, dataset_idx, dataset_state): # TODO: This step will go away in future PRs. The dataloader already knows the state when it # sends it to the trainer. There is no need to send it to trainer and send it back. 
Let's # simplify this as well in the future self.train_loader.update_state(dataset_idx, dataset_state) def _get_additional_step_logs(self): if self.config.hparams.job_id is not None: return {"job_id": self.config.hparams.job_id, "commit": self.config.hparams.repo_commit_id} else: return {"commit": self.config.hparams.repo_commit_id} def format_print_logs(self, dict_logs, keys_known_formats, skip_keys=[]): """ compact formatting of the logs with pre-specified formatter for each log entry, plus a catch-all if new log entries are added but forgotten to be added in keys_known_formats the keys order is the one that controls how the logs are printed (py37+). even if there is no formatter there is still an empty value entry here as it tells use the order of keys. """ log = "" for key in keys_known_formats.keys(): if key in dict_logs: if isinstance(dict_logs[key], dict): for sub_key in dict_logs[key].keys(): prefix = f"{key}" if sub_key != "all": if LoggingTypes.PRINT not in self.hparams.train_logging_per_dataset_info: continue prefix += f"/{sub_key}" log += f"{prefix}: {dict_logs[key][sub_key]:{keys_known_formats[key]}} | " else: log += f"{key}: {dict_logs[key]:{keys_known_formats[key]}} | " # in case some new log entries were added that don't yet have the formatter string we dump them as is for key in set(dict_logs.keys() - set(skip_keys) - set(keys_known_formats.keys())): if key in dict_logs: log += f"{key}: {dict_logs[key]} | " return log def format_jsonl_logs(self, dict_logs): """ Similar to format_print_logs but for jsonl logs """ log = {} for key in dict_logs: # We don't want to log the accumulated values if "_acc" in key: continue elif isinstance(dict_logs[key], dict): for sub_key in dict_logs[key].keys(): prefix = f"{key}" if sub_key != "all": if LoggingTypes.JSONL not in self.hparams.train_logging_per_dataset_info: continue prefix += f"/{sub_key}" log[prefix] = dict_logs[key][sub_key] else: log[key] = dict_logs[key] return log def format_val_logs(self, val_logs, logger_type=LoggingTypes.PRINT): keys_known_formats = { "val_per_token_loss": ".4f", "val_num_images": "", "val_num_tokens": "", "val_num_padding": "", "val_image_to_text_ratio": ".4f", } if logger_type == LoggingTypes.PRINT: return self.format_print_logs(val_logs, keys_known_formats) elif logger_type == LoggingTypes.JSONL: return self.format_jsonl_logs(val_logs) else: raise ValueError(f"Unknown logger type: {logger_type}") def format_train_logs(self, train_logs, logger_type=LoggingTypes.PRINT) -> Union[str, Dict]: keys_known_formats = { "per_token_loss": ".4f", "lr": ".3E", "global_batch_size": "", "num_tokens": "", "num_images": "", "num_image_tokens": "", "num_padding": "", "fwd_bwd_time": ".1f", "image_to_text_ratio": ".4f", "num_batches": "", "num_batches_in_curr_epoch": "", "num_batches_since_training_logged": "", "num_per_device_batches": "", "num_per_device_batches_in_curr_epoch": "", "num_per_device_batches_since_training_logged": "", "tflops": ".1f", "watt/s": ".1f", "fwd_bwd_time_since_training_logged": ".1f", "num_epochs": "", "num_opt_steps": "", "z_loss": ".4f", "pixel_values_sum": ".5E", "tflop_counter": ".3E", "tflops_acc": ".1f", } # intermediary state accumulate keys skip_keys = [ "per_token_loss_acc", "z_loss_acc", "tflop_counter_since_training_logged", "num_batches_since_training_logged", "num_per_device_batches_since_training_logged", "total_energy_delta_since_training_logged", ] if logger_type == LoggingTypes.PRINT: return self.format_print_logs(train_logs, keys_known_formats, skip_keys) elif logger_type == 
LoggingTypes.JSONL: return self.format_jsonl_logs(train_logs) else: raise ValueError(f"Unknown logger type: {logger_type}") def _log_training(self, curr_opt_step, train_task, train_logs): for key in train_logs["per_token_loss_acc"].keys(): if train_logs["num_per_device_batches_since_training_logged"][key] is not None: train_logs["per_token_loss"][key] = ( train_logs["per_token_loss_acc"][key] / train_logs["num_per_device_batches_since_training_logged"][key] ) train_logs["z_loss"][key] = ( train_logs["z_loss_acc"][key] / train_logs["num_per_device_batches_since_training_logged"][key] ) else: train_logs["per_token_loss"][key] = None train_logs["z_loss"][key] = None if train_logs["fwd_bwd_time_since_training_logged"][key] is not None: train_logs["tflops"][key] = ( train_logs["tflop_counter_since_training_logged"][key] / train_logs["fwd_bwd_time_since_training_logged"][key] ) train_logs["watt/s"][key] = ( train_logs["total_energy_delta_since_training_logged"][key] / train_logs["fwd_bwd_time_since_training_logged"][key] ) else: train_logs["tflops"][key] = None train_logs["watt/s"][key] = None if self.accelerator.is_main_process: print_log = "" progress = f"{str(MofNCompleteColumn().render(train_task)):>12} {TaskProgressColumn().render(train_task)}" print_log += f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] iteration: {progress} | " elapsed_time = TimeElapsedColumn().render(train_task) print_log += f"elapsed time: {elapsed_time} | " print_log += self.format_train_logs(train_logs, logger_type=LoggingTypes.PRINT) # TODO: Allow mem usage to be logged according to LogginTypes passed in hparams if self.hparams.train_log_mem_usage: print_log += mem_usage_formatted(LoggingTypes.PRINT) print(print_log) jsonl_logs = { "iteration": progress.strip(), "elapsed_time": str(elapsed_time), "set": "train", } jsonl_logs.update(self.format_train_logs(train_logs, logger_type=LoggingTypes.JSONL)) if self.hparams.train_log_mem_usage: jsonl_logs.update(mem_usage_formatted(LoggingTypes.JSONL)) if self.hparams.job_id is not None: log_jsonl_file = self.hparams.save_dir / "logs" / f"{self.hparams.job_id}_logs.jsonl" else: log_jsonl_file = self.hparams.save_dir / "logs" / "logs.jsonl" log_jsonl_file.parent.mkdir(parents=True, exist_ok=True) with open(log_jsonl_file, "a") as f: f.write(json.dumps(jsonl_logs) + "\n") if self.hparams.wandb_enable: filtered_train_logs = train_logs if LoggingTypes.WANDB not in self.hparams.train_logging_per_dataset_info: filtered_train_logs = {} for key in train_logs.keys(): if isinstance(train_logs[key], dict): filtered_train_logs[key] = train_logs[key]["all"] else: filtered_train_logs[key] = train_logs[key] # remove nested None values as wandb doesn't support them filtered_train_logs = {k: v for k, v in filtered_train_logs.items() if v is not None} for k, v in filtered_train_logs.items(): if isinstance(v, dict): filtered_train_logs[k] = {k2: v2 for k2, v2 in v.items() if v2 is not None} self.accelerator.log({**filtered_train_logs, **self._get_additional_step_logs()}, step=curr_opt_step) train_logs = self._reset_train_logs(train_logs) return train_logs def _log_activations(self, curr_opt_step): if not self.activation_tracker.jsonl_stats: return if LoggingTypes.JSONL in self.hparams.train_logging_activations: if self.hparams.job_id is not None: log_activations_filename = ( self.hparams.save_dir / "logs" / f"{self.hparams.job_id}_logs_activations.jsonl" ) else: log_activations_filename = self.hparams.save_dir / "logs" / "logs_activations.jsonl" 
self.activation_tracker.dump_stats(log_activations_filename, curr_opt_step=curr_opt_step) # if LoggingTypes.WANDB in self.hparams.train_logging_activations and self.hparams.wandb_enable: # for stats in self.activation_tracker.jsonl_stats: # self.accelerator.log({**stats, **self._get_additional_step_logs()}, step=curr_opt_step) if LoggingTypes.PRINT in self.hparams.train_logging_activations: for stats in self.activation_tracker.jsonl_stats: stats["step"] = curr_opt_step activation_format = { k: "" if ("nonzero" in k or "step" in k or "name" in k or "type" in k or "batches" in k) else "e" for k in stats.keys() } log = "Activation stats: " log += self.format_print_logs(stats, activation_format) print(log) self.activation_tracker.reset_jsonl_stats() def _check_kill_switch(self): if self.hparams.kill_switch_path is not None and self.hparams.kill_switch_path.exists(): self.kill_switch_activated = True def _check_jz_time_and_memory(self, curr_opt_step): # From https://github.com/wandb/wandb/blob/9c777265f8cea1eaeb0407dd37ab889aeea81114/wandb/sdk/internal/stats.py#L263 if self.accelerator.is_local_main_process: self.memory_value = torch.tensor(psutil.virtual_memory().percent).to(self.accelerator.device) else: self.memory_value = torch.tensor(0.0).to(self.accelerator.device) self.accelerator.wait_for_everyone() memory_value_max = self.accelerator.gather(self.memory_value) memory_value_max = memory_value_max.max().item() self.memory_explosion = memory_value_max >= _MEMORY_EXPLOSION_THRESHOLD if self.hparams.jz_job_time_sec is not None: if self.accelerator.is_main_process: overall_time = time.time() - self.hparams.jz_start_time self.jz_training_time_over[0] = overall_time >= self.hparams.jz_job_time_sec self.accelerator.wait_for_everyone() accelerate.utils.broadcast_object_list(self.jz_training_time_over) if self.accelerator.is_main_process and self.hparams.wandb_enable: system_metrics_logs = self._get_system_metrics_logs(memory_value_max) self.accelerator.log({**system_metrics_logs, **self._get_additional_step_logs()}, step=curr_opt_step) def _check_sigterm_signal(self): if self.sigterm_listener.kill_now: self.sigterm_signal_received = True def _save( self, train_logs, curr_opt_step, curr_epoch, gbs_running, ): self.accelerator.wait_for_everyone() # create directory and file names self.last_opt_step_dir = self.hparams.save_dir / f"opt_step-{curr_opt_step}" # Make directory for this step self.last_opt_step_dir.mkdir(parents=True, exist_ok=True) self.train_loader.save_state(self.last_opt_step_dir / "resumable_states") # XXX: why is there a hardcoded accelerator_state path? should be coming from config, no? self.accelerator.save_state(self.last_opt_step_dir / "accelerator_state") if self.accelerator.is_main_process: # Save model and accelerator state unwrapped_model = self.accelerator.unwrap_model(self.vl_model) # fix the model class name to be of VLOOOM type the first time it's saved unwrapped_model.config.architectures = [unwrapped_model.__class__.__name__] # deepspeed doesn't need the overhead of gathering the model from all gpus if not is_deepspeed_zero3_used(): if self.hparams.use_lora: unwrapped_model.save_pretrained(self.last_opt_step_dir / "unwrapped_adapter") # Manual unloading with a simple PeftMixin to avoid having to deal with PeftModel state dict base_model = lora_unload(copy.deepcopy(unwrapped_model)) # Save pretrained with _hf_peft_config_loaded=True will save the adapters only. 
So we set it manually to False base_model._hf_peft_config_loaded = False base_model.save_pretrained(self.last_opt_step_dir / "unwrapped_model") del base_model else: unwrapped_model.save_pretrained(self.last_opt_step_dir / "unwrapped_model") else: # For deepspeed, `save_checkpoint` done by accelerate takes care of saving the model in a # special format per gpu which on resume will be used to load the model - so we don't need to # save pytorch state_dict separately, which can be costly or impossible if there not enough CPU RAM. # We only need to save the config unwrapped_model.config.save_pretrained( self.last_opt_step_dir / "unwrapped_model", ) if self.hparams.use_lora: unwrapped_model.peft_config["default"].save_pretrained( self.last_opt_step_dir / "unwrapped_adapter", ) # Save tokenizer directly into the same dir self.tokenizer.save_pretrained( self.last_opt_step_dir / "tokenizer", ) # infos to resume run at this step data = { "train_logs": train_logs, "wandb_run_id": self.hparams.wandb_run_id, "seed": self.hparams.seed, "resume_opt_step": curr_opt_step, "resume_epoch": curr_epoch, "gbs_running": gbs_running, } with open(self.last_opt_step_dir / "resume_run_infos.json", "w") as fp: # json.dump(data, fp, indent=2) json.dump(data, fp, indent=2, cls=JSONEncoderForDataclasses) self.accelerator.wait_for_everyone() if self.accelerator.is_main_process: # mark this checkpoint as finished - needed for async slurm jobs like s3 uploader latest_path_finished = self.last_opt_step_dir / "finished-saving" latest_path_finished.touch() # mark which is latest saved checkpoint for correct resume latest_path = self.hparams.save_dir / "latest_opt_step_dir" with open(latest_path, "w") as fd: fd.write(str(self.last_opt_step_dir)) logger.info(f"** Saving finished at `{self.last_opt_step_dir}` **") if self.accelerator.is_main_process and self.hparams.upload_to_s3: # We keep around the last checkpoint (which was saved just above) locally, and delete the previous to last one. locally_present_saved_steps_inds = [ int(os.path.split(dir)[-1].split("opt_step-")[-1]) for dir in self.hparams.save_dir.iterdir() if (dir.is_dir() and "opt_step" in str(dir)) ] if len(locally_present_saved_steps_inds) >= 2: previous_to_last_saved_step = sorted(locally_present_saved_steps_inds)[-2] previous_to_last_folder = f"opt_step-{previous_to_last_saved_step}" else: previous_to_last_folder = "" # Subprocess command is inspired from https://stackoverflow.com/questions/5772873/python-spawn-off-a-child-subprocess-detach-and-exit/64145368#64145368 # `stdout` and `stderr` are supressed to avoid polluting the logs with the output of the command. 
cmd = ( str(Path(__file__).resolve().parents[2]) + "/experiments/pretraining/vloom/common/sync_and_upload.sh", os.path.split(self.hparams.save_dir)[0], os.path.split(self.hparams.save_dir)[1], previous_to_last_folder, "opt_step-" + str(curr_opt_step), ) subprocess.Popen(cmd, start_new_session=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT) def _save_batch(self, batch, curr_idx): dir_path = self.hparams.save_dir / "batches" dir_path.mkdir(parents=True, exist_ok=True) with open(dir_path / f"batch_idx_{curr_idx}_proc_{self.accelerator.process_index}.pkl", "wb") as file: pickle.dump(batch, file) def _check_if_training_is_over(self, curr_opt_step, max_num_updates): self._check_kill_switch() self._check_jz_time_and_memory(curr_opt_step=curr_opt_step) self._check_sigterm_signal() finished_training = True if self.kill_switch_activated: logger.info(f"** Kill switch activated (Don't forget to remove {self.hparams.kill_switch_path}) **") elif self.sigterm_signal_received: logger.info("** SIGTERM signal received. Please restart training **") elif self.jz_training_time_over[0]: logger.info("** Training time is over **") elif self.memory_explosion: logger.info("** CPU memory is close to explosion. Please restart training **") elif curr_opt_step >= max_num_updates: logger.info("** Maximum number of steps has been reached for this training **") elif curr_opt_step >= self.max_num_updates_this_run: logger.info("** Maximum number of steps has been reached for this run **") else: finished_training = False return finished_training def _check_if_training_is_over_and_maybe_save_model( self, curr_opt_step, curr_epoch, gbs_running, max_num_updates, train_logs, opt_step_is_saved, ): # check if should finish the training finished_training = self._check_if_training_is_over(curr_opt_step, max_num_updates) # save the model # 1. either because it's a scheduled saving is_opt_to_save_model = curr_opt_step != 0 and curr_opt_step % self.hparams.train_saving_opt_steps == 0 # 2. 
it was requested via the save switch if self.hparams.save_switch_path is not None and self.hparams.save_switch_path.exists(): is_opt_to_save_model = True logger.info("** Save switch activated - forcing checkpoint save **") self.hparams.save_switch_path.unlink() if not opt_step_is_saved and (finished_training or is_opt_to_save_model): self._save( train_logs, curr_opt_step, curr_epoch, gbs_running, ) opt_step_is_saved = True return finished_training, opt_step_is_saved def _reset_train_logs(self, train_logs): if train_logs is None: train_logs = {} for metric_name, default_value_fn in METRICS_TO_DEFAULT_VALUE_FN.items(): if metric_name in METRICS_TO_RESET_AFTER_LOGGING: continue train_logs[metric_name] = default_value_fn() # reset counters that get zeroed after every log event train_logs.update( {metric_name: METRICS_TO_DEFAULT_VALUE_FN[metric_name]() for metric_name in METRICS_TO_RESET_AFTER_LOGGING} ) return train_logs def _check_default_dict_in_train_logs(self, train_logs): for metric_name, default_value_fn in METRICS_TO_DEFAULT_VALUE_FN.items(): value = default_value_fn() # if the metric was initialized as a defaultdict, we need to convert it again to a defaultdict if not isinstance(value, defaultdict): continue value.update(train_logs.get(metric_name, {})) train_logs[metric_name] = value return train_logs def _end_of_epoch_reset_train_logs(self, train_logs): if train_logs is None: raise ValueError("`train_logs` should not be `None` at the end of an epoch.") train_logs.update( { metric_name: METRICS_TO_DEFAULT_VALUE_FN[metric_name]() for metric_name in ["num_batches_in_curr_epoch", "num_per_device_batches_in_curr_epoch"] } ) return train_logs def _do_validation(self, progress, curr_opt_step): try: val_len = len(self.val_loader) except TypeError: val_len = None self.vl_model.eval() val_loop = progress.add_task( f"[cyan]Validation step-{curr_opt_step}", disable=not self.accelerator.is_main_process, total=val_len, visible=False, ) curr_val_task = progress.tasks[-1] starter_dict = {e.value: 0 for e in DatasetNames} starter_dict["all"] = 0 val_per_token_loss_acc = copy.deepcopy(starter_dict) val_steps = copy.deepcopy(starter_dict) val_num_images = copy.deepcopy(starter_dict) val_num_tokens = copy.deepcopy(starter_dict) val_num_padding = copy.deepcopy(starter_dict) val_image_to_text_ratio = {e.value: [] for e in DatasetNames} val_image_to_text_ratio["all"] = [] for _, dataset_name, _, batch in self.val_loader: with torch.no_grad(): ( curr_val_per_token_loss, curr_val_num_images, curr_val_num_tokens, curr_val_image_to_text_ratio, curr_val_num_padding, _, ) = self._do_batch(batch, curr_opt_step, validation=True) val_per_token_loss_acc["all"] += curr_val_per_token_loss val_num_images["all"] += curr_val_num_images val_num_tokens["all"] += curr_val_num_tokens val_num_padding["all"] += curr_val_num_padding val_image_to_text_ratio["all"].append(curr_val_image_to_text_ratio) val_steps["all"] += 1 val_per_token_loss_acc[dataset_name] += curr_val_per_token_loss val_num_images[dataset_name] += curr_val_num_images val_num_tokens[dataset_name] += curr_val_num_tokens val_num_padding[dataset_name] += curr_val_num_padding val_image_to_text_ratio[dataset_name].append(curr_val_image_to_text_ratio) val_steps[dataset_name] += 1 progress.update(val_loop, advance=1) if ( curr_val_task.completed % self.hparams.val_inline_logging_opt_steps == 0 and self.accelerator.is_main_process ): logger.info( "Validation" f" step-{curr_opt_step} state:{TaskProgressColumn().render(curr_val_task)} Time" f" Elapsed: 
{TimeElapsedColumn().render(curr_val_task)} Steps" f" Completed:{MofNCompleteColumn().render(curr_val_task)}" ) self.vl_model.train() return ( val_steps, val_per_token_loss_acc, val_num_images, val_num_tokens, val_image_to_text_ratio, val_num_padding, ) def _log_validation( self, val_steps, curr_opt_step, val_per_token_loss_acc, val_num_images, val_num_tokens, val_image_to_text_ratio, val_num_padding, ): def convert_to_tensor(x): if not torch.is_tensor(x): return torch.tensor(x, device=self.accelerator.device) return x for key, value in val_image_to_text_ratio.items(): if len(value) != 0: val_image_to_text_ratio[key] = sum(value) / len(value) else: val_image_to_text_ratio[key] = 0.0 gathered_val_per_token_loss = {} gathered_val_num_images = {} gathered_val_num_tokens = {} gathered_val_num_padding = {} gathered_val_image_to_text_ratio = {} gathered_val_steps = {} for key in val_image_to_text_ratio.keys(): ( gathered_val_per_token_loss[key], gathered_val_num_images[key], gathered_val_num_tokens[key], gathered_val_num_padding[key], gathered_val_image_to_text_ratio[key], gathered_val_steps[key], ) = self.accelerator.gather( ( convert_to_tensor(val_per_token_loss_acc[key]), convert_to_tensor(val_num_images[key]), convert_to_tensor(val_num_tokens[key]), convert_to_tensor(val_num_padding[key]), convert_to_tensor(val_image_to_text_ratio[key]), convert_to_tensor(val_steps[key]), ) ) # No overall steps so we should skip this if gathered_val_steps[key].sum() == 0: gathered_val_per_token_loss.pop(key) gathered_val_num_images.pop(key) gathered_val_num_tokens.pop(key) gathered_val_num_padding.pop(key) gathered_val_image_to_text_ratio.pop(key) continue gathered_val_per_token_loss[key] = ( gathered_val_per_token_loss[key].sum().item() / gathered_val_steps[key].sum().item() ) gathered_val_num_images[key] = gathered_val_num_images[key].sum().item() gathered_val_num_tokens[key] = gathered_val_num_tokens[key].sum().item() gathered_val_num_padding[key] = gathered_val_num_padding[key].sum().item() gathered_val_image_to_text_ratio[key] = ( gathered_val_image_to_text_ratio[key][gathered_val_steps[key] != 0.0].mean().item() ) val_logs = { "val_per_token_loss": gathered_val_per_token_loss, "val_num_images": gathered_val_num_images, "val_num_tokens": gathered_val_num_tokens, "val_num_padding": gathered_val_num_padding, "val_image_to_text_ratio": gathered_val_image_to_text_ratio, } if self.accelerator.is_main_process: print(f"Validation logs: {self.format_val_logs(val_logs, LoggingTypes.PRINT)}") jsonl_logs = {"current step": curr_opt_step, "set": "validation"} jsonl_logs.update(self.format_val_logs(val_logs, LoggingTypes.JSONL)) if self.hparams.job_id is not None: log_jsonl_file = self.hparams.save_dir / "logs" / f"{self.hparams.job_id}_logs.jsonl" else: log_jsonl_file = self.hparams.save_dir / "logs" / "logs.jsonl" with open(log_jsonl_file, "a") as f: f.write(json.dumps(jsonl_logs) + "\n") if self.hparams.wandb_enable: self.accelerator.log({**val_logs, **self._get_additional_step_logs()}, step=curr_opt_step) def train(self, maybe_torch_profile_scheduler=None): # timing_break_down = self.accelerator.is_main_process and self.hparams.timing_break_down if self.accelerator.is_main_process: logger.info(f"** Global main process pid={os.getpid()} **") elif self.accelerator.is_local_main_process: logger.info(f"** Local main process pid={os.getpid()} **") # -------------------- # Set-up everything needed for training # -------------------- ( progress_columns, train_logs, max_num_epochs, max_num_updates, curr_opt_step, 
curr_epoch, opt_step_is_saved, eval_is_done, gbs_running, ) = self._set_up_training() # -------------------- # Training loop # -------------------- self.vl_model.train() pynvml_handle = pynmvl_handle(self.accelerator) with Progress(*progress_columns, refresh_per_second=5, disable=True) as progress: progress_bar = progress.add_task( "[red]Training", disable=not self.accelerator.is_main_process, total=max_num_updates, visible=False ) train_task = progress.tasks[-1] progress.update(progress_bar, advance=curr_opt_step) finished_training = False training_logged = True timer = DeviceAgnosticTimer() timer.start() if self.hparams.timing_break_down: self.accelerator.wait_for_everyone() if self.accelerator.is_main_process: timer2 = Timer() time_deltas = {} while not finished_training: self.train_loader.set_epoch(curr_epoch) # Handle resume based on `realtime_processing` if self.hparams.resume_run: if not self.data_param.realtime_processing: # When preparing, `accelerate` has a bug that overwrites the top-level `sampler`... but this is aesthetic! # The "actual" sampler that the DataLoader uses is the one that's tucked inside the `batch_sampler` # attribute when "single-process" (world_size = 1), or the `batch_sampler.batch_sampler` when # "multi-process" (world_size > 1) -- both of which point to our specialized ResumableSampler! # # This is pretty annoying and nuanced; should PR into `accelerate` to fix this... logger.warning("NOT realtime processing has not been extensively tested yet") if self.accelerator.num_processes == 1: # TODO; Fix this self.train_loader.batch_sampler.sampler.set_state(self.train_loader.get_resume_state(0)) else: # TODO :: This is actually broken and not respected by `accelerate` - fails! # self.train_loader.batch_sampler.batch_sampler.sampler.set_state( # self.resumable_state.get_resume_state() # ) raise NotImplementedError("Map Dataset Resume w/ DDP not yet implemented!") else: self.train_loader.load_resume_states() if self.hparams.timing_break_down: self.accelerator.wait_for_everyone() if self.accelerator.is_main_process: timer2.start() for curr_idx, (dataset_idx, dataset_name, dataset_state, batch) in enumerate(self.train_loader): # -------------------- # Check if the training is over and if so may be save the model before training batch # -------------------- if self.hparams.timing_break_down: self.accelerator.wait_for_everyone() if self.accelerator.is_main_process: time_deltas["dl"] = timer2.delta() finished_training, opt_step_is_saved = self._check_if_training_is_over_and_maybe_save_model( curr_opt_step, curr_epoch, gbs_running, max_num_updates, train_logs, opt_step_is_saved, ) if finished_training: break # -------------------- # Activate/deactivate hooks for logging activations or not # -------------------- # We are logging everything at `curr_opt_step`, but `curr_opt_step` is incremented a few lines later, so activating # the activation tracking hooks based on `curr_opt_step + 1`. See `_log_activations` for more details. 
                    if self.activation_tracker:
                        if (curr_opt_step + 1) % self.hparams.train_logging_activations_opt_steps == 0 and (
                            curr_idx + 1
                        ) % self.hparams.grad_acc_size == 0:
                            self.activation_tracker.activate_hooks()
                        else:
                            self.activation_tracker.deactivate_hooks()

                    if (
                        self.hparams.save_batch_max_idx is not None
                        and self.hparams.save_batch_min_idx is not None
                        and curr_idx <= self.hparams.save_batch_max_idx
                        and curr_idx >= self.hparams.save_batch_min_idx
                    ):
                        self._save_batch(batch, curr_idx)

                    # right before fwd-bwd-step
                    if self.hparams.timing_break_down:
                        self.accelerator.wait_for_everyone()
                        if self.accelerator.is_main_process:
                            time_deltas["between_dl_fwd_bwd"] = timer2.delta()

                    total_energy_start = pynvml_get_total_energy_in_joules(pynvml_handle)

                    with self.accelerator.accumulate(self.vl_model):
                        (
                            per_token_loss,
                            z_loss,
                            num_images,
                            num_image_tokens,
                            num_tokens,
                            image_to_text_ratio,
                            num_padding,
                            pixel_values_sum,
                            tflops_per_batch_per_gpu,
                        ) = self._do_batch(
                            batch,
                            curr_opt_step=curr_opt_step,
                            dataset_name=dataset_name,
                            dataset_idx=dataset_idx,
                            validation=False,
                        )

                    # right after fwd-bwd-step
                    if self.hparams.timing_break_down:
                        self.accelerator.wait_for_everyone()
                        if self.accelerator.is_main_process:
                            time_deltas["fwd-bwd-step"] = timer2.delta()

                    fwd_bwd_time = timer.stop()
                    fwd_bwd_time = torch.tensor(fwd_bwd_time, device=self.accelerator.device)
                    total_energy_delta_per_gpu = torch.tensor(
                        pynvml_get_total_energy_in_joules(pynvml_handle) - total_energy_start,
                        device=self.accelerator.device,
                    )

                    if (curr_idx + 1) % self.hparams.grad_acc_size == 0:
                        curr_opt_step += 1
                        opt_step_is_saved, eval_is_done, training_logged = False, False, False
                        progress.update(progress_bar, advance=1)

                    # --------------------
                    # Update logs
                    # --------------------
                    train_logs = self._update_logs(
                        curr_opt_step,
                        curr_epoch,
                        gbs_running.global_batch_size_current,
                        train_logs,
                        per_token_loss,
                        z_loss,
                        num_tokens,
                        num_images,
                        num_image_tokens,
                        image_to_text_ratio,
                        num_padding,
                        fwd_bwd_time,
                        pixel_values_sum,
                        tflops_per_batch_per_gpu,
                        total_energy_delta_per_gpu,
                        dataset_name,
                        self.hparams.train_logging_per_dataset_suffix,
                    )

                    timer = DeviceAgnosticTimer()
                    timer.start()

                    # --------------------
                    # Update datasets states
                    # --------------------
                    self._update_datasets_states(dataset_idx, dataset_state)

                    # --------------------
                    # Log training info
                    # --------------------
                    if curr_opt_step % self.hparams.train_logging_opt_steps == 0 and not training_logged:
                        train_logs = self._log_training(curr_opt_step, train_task, train_logs)
                        training_logged = True

                    # --------------------
                    # Log activations
                    # --------------------
                    if self.activation_tracker:
                        batch_idx = train_logs["num_batches"]["all"]
                        self.activation_tracker.fill_in_batch_idx(batch_idx=batch_idx)
                        if curr_opt_step % self.hparams.train_logging_activations_opt_steps == 0:
                            self._log_activations(curr_opt_step=curr_opt_step)

                    # ---------------------------
                    # Global batch size ramp-up
                    # ---------------------------
                    #
                    # This logic needs to happen after the batch has been processed and results
                    # logged, but before the model is saved for resume, so that the updated ramp-up
                    # variables will have the correct values on resume.
                    gbs_running.global_seen_samples += self.accelerator.num_processes * self.hparams.batch_size_per_gpu
                    if (
                        self.hparams.global_batch_size_ramp_up.start is not None
                        and self.hparams.global_batch_size_ramp_up.finish > gbs_running.global_batch_size_current
                        and gbs_running.global_seen_samples >= gbs_running.next_goal_samples
                    ):
                        gbs_running.next_goal_samples += self.hparams.global_batch_size_ramp_up.samples
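                        # Each ramp-up step below raises the global batch size by `increment` and re-derives the
                        # gradient-accumulation size so that
                        # global_batch_size == batch_size_per_gpu * num_processes * grad_acc_size.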
                        gbs_running.global_batch_size_current += self.hparams.global_batch_size_ramp_up.increment
                        gbs_running.grad_acc_size_current = int(
                            gbs_running.global_batch_size_current
                            / (self.hparams.batch_size_per_gpu * self.accelerator.num_processes)
                        )
                        self.update_gas_and_gbs(
                            gbs_running.grad_acc_size_current, gbs_running.global_batch_size_current
                        )

                    # --------------------
                    # Check if the training is over and, if so, maybe save the model before validation
                    # --------------------
                    finished_training, opt_step_is_saved = self._check_if_training_is_over_and_maybe_save_model(
                        curr_opt_step,
                        curr_epoch,
                        gbs_running,
                        max_num_updates,
                        train_logs,
                        opt_step_is_saved,
                    )
                    if finished_training:
                        break

                    if self.hparams.timing_break_down:
                        self.accelerator.wait_for_everyone()
                        if self.accelerator.is_main_process:
                            time_deltas["post_fwd"] = timer2.delta()
                            time_deltas["iteration"] = timer2.elapsed()

                    # --------------------
                    # Validation loop
                    # --------------------
                    if (
                        self.config.hparams.do_validation
                        and not eval_is_done
                        and curr_opt_step != 0
                        and curr_opt_step % self.hparams.val_logging_opt_steps == 0
                    ):
                        if self.hparams.timing_break_down:
                            self.accelerator.wait_for_everyone()
                            if self.accelerator.is_main_process:
                                timer3 = Timer()
                                timer3.start()

                        gc.collect()
                        if self.accelerator.is_main_process and self.hparams.wandb_enable:
                            wandb.unwatch(self.dummy_module)
                        if self.activation_tracker:
                            self.activation_tracker.is_eval()

                        logger.info("** Starting validation **")
                        (
                            val_steps,
                            val_per_token_loss_acc,
                            val_num_images,
                            val_num_tokens,
                            val_image_to_text_ratio,
                            val_num_padding,
                        ) = self._do_validation(progress, curr_opt_step)

                        # --------------------
                        # Log validation info
                        # --------------------
                        self._log_validation(
                            val_steps,
                            curr_opt_step,
                            val_per_token_loss_acc,
                            val_num_images,
                            val_num_tokens,
                            val_image_to_text_ratio,
                            val_num_padding,
                        )
                        eval_is_done = True
                        logger.info("** Finished validation **")

                        if self.accelerator.is_main_process and self.hparams.wandb_enable:
                            wandb.watch(
                                self.dummy_module,
                                log="all",
                                log_freq=self.hparams.wandb_log_freq * self.hparams.grad_acc_size,
                                idx=0,
                            )
                        if self.activation_tracker:
                            self.activation_tracker.is_train()

                        gc.collect()
                        self.accelerator.wait_for_everyone()

                        if self.hparams.timing_break_down:
                            self.accelerator.wait_for_everyone()
                            if self.accelerator.is_main_process:
                                print(f"[TIME] Validation: {format_secs_to_time(timer3.stop())}")

                        # restart timer from zero to avoid accounting for validation
                        timer = DeviceAgnosticTimer()
                        timer.start()

                    if self.hparams.timing_break_down:
                        self.accelerator.wait_for_everyone()
                        if self.accelerator.is_main_process:
                            # Finalize
                            print(
                                f'[TIME] Iteration {format_secs_to_sec_fractions(time_deltas["iteration"])}:'
                                f' {format_secs_to_sec_fractions(time_deltas["dl"]):>6} dl |'
                                f' {format_secs_to_sec_fractions(time_deltas["between_dl_fwd_bwd"]):>6} between'
                                f' dl/fwd-bwd | {format_secs_to_sec_fractions(time_deltas["fwd-bwd-step"]):>6} fwd/bwd'
                                f' | {format_secs_to_sec_fractions(time_deltas["post_fwd"]):>6} post'
                            )

                            # restart for __iter__
                            timer2.stop()
                            timer2.start()

                    if maybe_torch_profile_scheduler is not None and self.accelerator.is_main_process:
                        maybe_torch_profile_scheduler.step()

                if not finished_training:
                    curr_epoch += 1
                    train_logs = self._end_of_epoch_reset_train_logs(train_logs)
                    self.train_loader.reset_state()

                    if curr_epoch == max_num_epochs:
                        self._save(
                            train_logs,
                            curr_opt_step,
                            curr_epoch,
                            gbs_running,
                        )
                        finished_training = True
                        logger.info("** Maximum number of epochs has been reached **")
                        break

        if self.hparams.wandb_enable:
            self.accelerator.end_training()

        return train_logs
    def _get_system_metrics_logs(self, memory_value_max):
        return {"memory_max_over_nodes": memory_value_max}
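For reference, the global batch size ramp-up bookkeeping in train() can be sanity-checked in isolation. The snippet below is a minimal sketch and is not part of trainer.py: the function name simulate_gbs_ramp_up and the example numbers are hypothetical, and it only assumes, as the loop above does, that global_batch_size == batch_size_per_gpu * num_processes * grad_acc_size.

def simulate_gbs_ramp_up(
    batch_size_per_gpu: int,
    num_processes: int,
    start: int,
    finish: int,
    increment: int,
    samples_per_step: int,
    total_samples: int,
):
    """Return [(samples_seen, global_batch_size, grad_acc_size), ...] for a hypothetical run."""
    global_batch_size = start
    grad_acc_size = global_batch_size // (batch_size_per_gpu * num_processes)
    seen, next_goal = 0, samples_per_step
    schedule = [(seen, global_batch_size, grad_acc_size)]
    while seen < total_samples:
        # one micro-batch per process, mirroring `global_seen_samples += num_processes * batch_size_per_gpu`
        seen += num_processes * batch_size_per_gpu
        if finish > global_batch_size and seen >= next_goal:
            next_goal += samples_per_step
            global_batch_size += increment
            grad_acc_size = global_batch_size // (batch_size_per_gpu * num_processes)
            schedule.append((seen, global_batch_size, grad_acc_size))
    return schedule


if __name__ == "__main__":
    # e.g. 8 GPUs x 4 samples per GPU, ramping the global batch size 32 -> 128 in steps of 32 every 10_000 samples
    for row in simulate_gbs_ramp_up(4, 8, 32, 128, 32, 10_000, 50_000):
        print(row)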