vision/m4/training/main.py (294 lines of code) (raw):

# Entry point for M4 vision-language model training.
#
# Run as a script, this module: reads the run configuration, creates the `Accelerator`,
# loads or initializes the vision-language model (resume / fine-tuning / fresh pretraining),
# optionally attaches a LoRA adapter, builds the tokenizer, image transforms and dataloaders,
# constructs the `Trainer`, runs training, and finally dumps the training logs to
# `<save_dir>/train_logs.json` on the main process.

import json
import logging
import os
import sys
import time
from datetime import timedelta

import accelerate
import datasets
import torch
import transformers
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.state import AcceleratorState
from peft import LoraConfig, PeftConfig
from torch.profiler.profiler import ProfilerActivity, profile, tensorboard_trace_handler
from transformers import AddedToken  # AddedToken is needed for the eval of the tokenizer params # noqa: F401
from transformers import AutoTokenizer  # noqa: F401
from transformers.utils import ContextManagers, is_torch_tf32_available

import m4
from m4.training.config import get_config
from m4.training.dataset import get_dataloaders
from m4.training.setup_language_model import model_name_to_classes
from m4.training.trainer import Trainer
from m4.training.types import DatasetNames
from m4.training.utils import VisionEncoderTypes, accelerate_torch_dtype, build_image_transform, get_tokenizer
from m4.utils.progress import M4_DISABLE_RICH
from m4.utils.training.timer import Timer, format_secs_to_time


# Root logging configuration, applied at import time.
logging.basicConfig(
    level=logging.INFO,
    format=" - %(process)d - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
# Unless rich output is disabled, also attach a RichHandler to the root logger.
if not M4_DISABLE_RICH:
    from rich.logging import RichHandler

    logging.getLogger("").addHandler(RichHandler(level=logging.INFO))
logger = logging.getLogger(__name__)


if __name__ == "__main__":
    # NOTE(review): START_TIME is not read anywhere else in this script.
    START_TIME = time.time()

    # this gives a very nice speed boost on Ampere
    if is_torch_tf32_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    config = get_config()

    # @TEMPORARY GATE -- if resuming, `realtime_processing` must be True
    if config.hparams.resume_run and not config.data_param.realtime_processing:
        raise NotImplementedError("Instant resume functionality not yet supported for non-iterable datasets!")

    # Initialize accelerator. The process-group timeout comes from the config (in seconds).
    kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(seconds=config.hparams.timeout))]
    accelerator = Accelerator(
        log_with="all",
        rng_types=["torch", "cuda", "generator"],
        gradient_accumulation_steps=config.hparams.grad_acc_size,
        kwargs_handlers=kwargs_handlers,
    )

    # Optional start-up timing. `timer1` and `time_deltas` exist ONLY on the main process and
    # ONLY when `timing_break_down` is set; every later use is guarded by the same two conditions.
    if config.hparams.timing_break_down:
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            timer1 = Timer()
            time_deltas = {}
            timer1.start()

    # logger behavior - this and sub-systems
    # The main process uses the level selected by the M4_VERBOSITY env var (default "info");
    # all other processes only log errors.
    main_process_log_level = m4.utils.logging.get_log_levels_dict()[os.getenv("M4_VERBOSITY", "info")]
    log_level = main_process_log_level if accelerator.is_main_process else logging.ERROR
    m4.utils.logging.set_verbosity(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    if AcceleratorState().deepspeed_plugin is not None:
        # Imported lazily so deepspeed is only required when the plugin is active.
        from deepspeed.utils import logger as ds_logger

        ds_logger.setLevel(log_level)

    # Abort before doing any heavy work if the kill-switch file already exists.
    if config.hparams.kill_switch_path is not None and config.hparams.kill_switch_path.exists():
        logger.info("** Kill switch activated. Exiting the training before it even starts. **")
        sys.exit()

    logger.info(f"** The job is running with the following arguments: **\n{config}\n **** ")

    accelerate.utils.set_seed(config.hparams.seed)

    # make dir if needed
    config.hparams.save_dir.mkdir(parents=True, exist_ok=True)

    # When fine-tuning, the model name often does not contain llama/idefics, so we try to get
    # this info from the model's config file if it exists.
    if config.hparams.is_fine_tuning and os.path.exists(f"{config.hparams.model_name}/config.json"):
        with open(f"{config.hparams.model_name}/config.json", "r") as f:
            model_config = json.loads(f.read())
        model_type = model_config["model_type"]
        config_class, model_class = model_name_to_classes(model_type)
    else:
        config_class, model_class = model_name_to_classes(config.hparams.model_name)

    # we want the target dtype in order to load the model in the most optimal way
    model_kwargs = dict(torch_dtype=accelerate_torch_dtype())

    # NOTE(review): several branches below call `eval` on strings taken from the config
    # (`tokenizer_add_tokens`, `tokenizer_params`). This executes arbitrary code, so it is only
    # acceptable while the config is trusted. The `AddedToken` import above exists for these evals.

    # Case when resuming a run. For both pretraining and fine-tuning.
    if config.hparams.resume_run:
        logger.info("Using saved model")
        vl_model = model_class.from_pretrained(
            config.resume_param.model_file,
            config=config.resume_param.model_config_file,
            is_resume=True,
            **model_kwargs,
        )
        if config.hparams.use_lora:
            # Re-attach the adapter described by the saved LoRA config.
            peft_config = PeftConfig.from_pretrained(config.resume_param.lora_file)
            vl_model.add_adapter(peft_config)
            vl_model.enable_adapters()
            logger.info("Resuming training with trained adapter")
    # Case when starting fine-tuning
    elif config.hparams.is_fine_tuning:
        # Additional vocabulary can be 3 instead of 2 for fine-tuning to integrate the <end_of_utterance> token.
        # However, if we want to keep training a model from the hub without the <end_of_utterance> token,
        # this works as well.
        # NOTE(review): the constant 39 is hard-coded and its origin is not visible here — confirm.
        additional_vocab_size = 39 + len(eval(config.hparams.tokenizer_add_tokens))
        print(f"additional_vocab_size: {additional_vocab_size}")
        logger.warning(
            "This is a fine tuning procedure, so the model parameters are inherited from the base model EXCEPT those"
            " regarding freezing and additional vocabulary. Finetuning with an additional vocab size of"
            f" {additional_vocab_size}"
        )
        vl_model = model_class.from_pretrained(
            config.hparams.model_name,
            is_resume=False,
            new_model=False,
            trust_remote_code=True,
            freeze_lm_head=config.hparams.model_config["freeze_lm_head"],
            freeze_text_layers=config.hparams.model_config["freeze_text_layers"],
            freeze_vision_layers=config.hparams.model_config["freeze_vision_layers"],
            additional_vocab_size=additional_vocab_size,
            **model_kwargs,
        )
        # A pre-trained adapter is loaded only when one is named; otherwise a new adapter is
        # created further down.
        if config.hparams.use_lora and config.hparams.lora_name is not None:
            vl_model.load_adapter(config.hparams.lora_name)
            vl_model.enable_adapters()
            logger.info("Loaded trained adapter")
    # Standard case for starting a pretraining
    else:
        logger.info("Using newly initialized model")
        # NOTE(review): `additional_special_tokens` is not used later in this script.
        additional_special_tokens = eval(config.hparams.tokenizer_params).get("additional_special_tokens", [])
        vl_config = config_class.from_pretrained(
            config.hparams.model_name,
            revision=config.hparams.revision,
            new_model=True,
            additional_vocab_size=len(eval(config.hparams.tokenizer_add_tokens)),
            **config.hparams.model_config,
        )
        vl_model = model_class.from_pretrained_models(config.hparams.model_name, config=vl_config, **model_kwargs)

    # If we want to use_lora and are starting a pretraining, or if we want to use a new lora for
    # fine-tuning, create the config and add the adapter.
    if (
        config.hparams.use_lora
        and not config.hparams.resume_run
        and not (config.hparams.is_fine_tuning and config.hparams.lora_name is not None)
    ):
        # Identify the target_modules with the patterns_to_loraify given in config.
        # A parameter matches when ALL substrings of at least one pattern list occur in its name.
        target_modules = []
        for name, param in vl_model.named_parameters():
            patterns_to_loraify_in_name = [
                all(pattern in name for pattern in pattern_list)
                for pattern_list in config.hparams.patterns_to_loraify
            ]
            if any(patterns_to_loraify_in_name):
                # Take off the suffixes ".weight" or ".bias"
                target_module_name = ".".join(name.split(".")[:-1])
                target_modules.append(target_module_name)
        peft_config = LoraConfig(
            target_modules=target_modules,
            **config.hparams.lora_config,
        )
        vl_model.add_adapter(peft_config)
        vl_model.enable_adapters()
        logger.info("Loaded new adapter")

    # If the model has a lora, we want to unfreeze some layers which got frozen when loading the lora.
    # Matching uses the same rule as above, against `patterns_to_unfreeze`.
    if config.hparams.use_lora:
        for name, param in vl_model.named_parameters():
            patterns_to_unfreeze_in_name = [
                all(pattern in name for pattern in pattern_list)
                for pattern_list in config.hparams.patterns_to_unfreeze
            ]
            if any(patterns_to_unfreeze_in_name):
                param.requires_grad_(True)

    # Get the seq_len for a single image as it is necessary for packing.
    # With a resampler it is the number of latents; otherwise it is the number of patches
    # ((image_size // patch_size) ** 2) divided by pixel_shuffle_factor ** 2.
    single_image_seq_len = (
        vl_model.config.perceiver_config.resampler_n_latents
        if vl_model.config.use_resampler
        else int(
            ((vl_model.config.vision_config.image_size // vl_model.config.vision_config.patch_size) ** 2)
            / (vl_model.config.pixel_shuffle_factor**2)
        )
        # else (vl_model.config.vision_config.image_size // vl_model.config.vision_config.patch_size) ** 2
    )

    if config.hparams.timing_break_down:
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            time_deltas["model_load"] = timer1.delta()

    tokenizer = get_tokenizer(
        tokenizer_name=config.hparams.tokenizer_name,
        tokenizer_add_tokens=config.hparams.tokenizer_add_tokens,
        tokenizer_add_special_tokens=config.hparams.tokenizer_add_special_tokens,
        tokenizer_params=config.hparams.tokenizer_params,
        additional_vocab_size=len(eval(config.hparams.tokenizer_add_tokens)),
        model_vocab_size=vl_model.config.vocab_size,
        is_fine_tuning=config.hparams.is_fine_tuning,
    )
    # NOTE(review): the pad token id is hard-coded to 2; which token that is depends on the
    # tokenizer in use — confirm.
    tokenizer.pad_token_id = 2

    if config.hparams.timing_break_down:
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            time_deltas["tokenizer_load"] = timer1.delta()

    # Pick the first VisionEncoderTypes member whose value is a substring of the lower-cased
    # vision model name; stays None when nothing matches.
    vision_model_name = vl_model.config.vision_config.vision_model_name
    vision_encoder_type = None
    for encoder_type in VisionEncoderTypes:
        if encoder_type.value in vision_model_name.lower():
            vision_encoder_type = encoder_type
            break

    # Build one train and one validation image transform per dataset, keyed by the lower-cased
    # enum member name.
    train_image_transforms = {}
    val_image_transforms = {}
    for dataset_name in DatasetNames:
        dataset_param = getattr(config.data_param, dataset_name.value)
        # Record the vision encoder's image size on the dataset params (mutates the config).
        setattr(dataset_param, "vision_encoder_max_image_size", vl_model.config.vision_config.image_size)
        train_image_transform = build_image_transform(
            max_image_size=vl_model.config.vision_config.image_size,
            min_image_size=dataset_param.min_image_size,
            image_size=None,
            vision_encoder_type=vision_encoder_type,
            dataset_name=dataset_name,
            scale_up_max=dataset_param.scale_up_max,
            scale_up_frequency=dataset_param.scale_up_frequency,
        )
        train_image_transforms[dataset_name.name.lower()] = train_image_transform

        val_image_transform = build_image_transform(
            max_image_size=vl_model.config.vision_config.image_size,
            min_image_size=dataset_param.min_image_size,
            image_size=None,
            eval=True,
            vision_encoder_type=vision_encoder_type,
            dataset_name=dataset_name,
        )
        val_image_transforms[dataset_name.name.lower()] = val_image_transform

    # Initialize data loaders.
    # The local main process of each node builds its dataloaders first; the remaining processes
    # build theirs only after the barrier below (presumably so that one-time preparation is
    # done once per node — confirm against `get_dataloaders`).
    if accelerator.is_local_main_process:
        train_loader, val_loader = get_dataloaders(
            config,
            rank=accelerator.process_index,
            world_size=accelerator.num_processes,
            tokenizer=tokenizer,
            train_image_transforms=train_image_transforms,
            val_image_transforms=val_image_transforms,
            image_seq_len=single_image_seq_len,
        )
        # NOTE(review): the nesting of the two checks below was reconstructed; they are kept
        # inside this branch because `train_loader` is not yet defined on other processes here.
        if config.hparams.loss_weights_per_dataset is not None:
            if config.hparams.grad_acc_size % train_loader.dataset.num_datasets != 0:
                raise ValueError(
                    "grad_acc_size must be a multiple of num_datasets when accumulating the loss over datasets"
                )
        if config.hparams.loss_weights_per_dataset is not None:
            if train_loader.dataset.num_datasets != len(config.hparams.loss_weights_per_dataset):
                raise ValueError(
                    "num_datasets must equal length of loss_weights_per_dataset when accumulating the loss over"
                    " datasets"
                )
    accelerator.wait_for_everyone()
    # And then send it to the rest of them
    if not accelerator.is_local_main_process:
        train_loader, val_loader = get_dataloaders(
            config,
            rank=accelerator.process_index,
            world_size=accelerator.num_processes,
            tokenizer=tokenizer,
            train_image_transforms=train_image_transforms,
            val_image_transforms=val_image_transforms,
            image_seq_len=single_image_seq_len,
        )
    accelerator.wait_for_everyone()

    if config.hparams.timing_break_down:
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            time_deltas["dl_load"] = timer1.delta()

    # If the sole purpose of the job is to pre-process the dataset, exit here.
    if config.hparams.just_preprocess:
        logger.info("Preprocessing finished. Exiting the job.")
        sys.exit()

    # Get max_num_tokens
    try:
        config.hparams.max_num_tokens = len(train_loader.dataset) * config.data_param.max_seq_len
    except TypeError:
        # Can't have max_num_tokens because it is an IterableDataset
        config.hparams.max_num_tokens = -1

    # Saving config after it has been auto-populated
    config.save_config_state()

    trainer = Trainer(
        accelerator=accelerator,
        vl_model=vl_model,
        tokenizer=tokenizer,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
    )

    if config.hparams.timing_break_down:
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            time_deltas["trainer_load"] = timer1.delta()
            time_deltas["all_load"] = timer1.elapsed()
            timer1.stop()

            # Finalize
            # NOTE(review): the literal below is reproduced exactly as found; it looks like it
            # originally spanned several lines (one "[TIME]" entry per line) — confirm.
            print(f""" [TIME] Model: {format_secs_to_time(time_deltas["model_load"])} [TIME] Tokenizer: {format_secs_to_time(time_deltas["tokenizer_load"])} [TIME] DataLoader: {format_secs_to_time(time_deltas["dl_load"])} [TIME] Trainer: {format_secs_to_time(time_deltas["trainer_load"])} [TIME] Total load: {format_secs_to_time(time_deltas["all_load"])} """)

    # Optional torch profiler, created on the main process only. The list holds zero or one
    # profiler so it can be handed to `ContextManagers` unconditionally.
    maybe_torch_profile = []
    if config.hparams.use_torch_profiler and accelerator.is_main_process:
        torch_profiler_export_path = config.hparams.save_dir / "torch_profiler"
        maybe_torch_profile_scheduler = torch.profiler.schedule(
            skip_first=10,
            wait=5,
            warmup=1,
            active=2,
            # repeat=2
        )
        maybe_torch_profile.append(
            profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
                on_trace_ready=tensorboard_trace_handler(torch_profiler_export_path),
                record_shapes=False,
                profile_memory=True,
                with_stack=True,
                schedule=maybe_torch_profile_scheduler,
            )
        )

    # The trainer receives the profiler instance when one was created, otherwise None.
    with ContextManagers(maybe_torch_profile):
        train_logs = trainer.train(maybe_torch_profile[0] if len(maybe_torch_profile) == 1 else None)

    # Only the main process reports and persists the training logs.
    if accelerator.is_main_process:
        logger.info(f"Last step directory: {trainer.last_opt_step_dir}")
        logger.info(f"Training logs: {train_logs}")
        train_log_file = config.hparams.save_dir / "train_logs.json"
        with open(train_log_file, "w") as fh:
            json.dump(train_logs, fh)
        # NOTE(review): nesting of this print was reconstructed; it is kept under the
        # main-process guard alongside the other reporting — confirm.
        print(f"LOSS: {train_logs.get('per_token_loss', 0)}")