in vision/m4/training/trainer.py [0:0]
def __init__(self, accelerator, vl_model, tokenizer, train_loader, val_loader, config):
    # Initialize params
    self.config: Parameters = config
    self.optim_param: OptimizerParams = config.optim_param
    self.hparams: Hparams = config.hparams
    self.resume_param: ResumeParams = config.resume_param
    self.data_param: DataParams = config.data_param
    # Initialize last step directory
    self.last_opt_step_dir = ""
    # Initialize the model
    self.vl_model = vl_model
    # Gradient checkpointing
    if self.hparams.gradient_checkpointing:
        self.vl_model.gradient_checkpointing_enable()
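        # Gradient checkpointing trades compute for memory: activations are
        # recomputed during the backward pass instead of being stored, which
        # typically lowers peak memory at the cost of a slower training step.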
    # Optionally track activations for debug logging (main process only)
    if accelerator.is_main_process and self.hparams.train_logging_activations:
        self.activation_tracker = ActivationTracker(self.vl_model)
    else:
        self.activation_tracker = None
    # Initialize tokenizer
    self.tokenizer = tokenizer
    self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
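    # IMAGE_TOKEN is assumed to already be registered in the tokenizer's
    # vocabulary; otherwise convert_tokens_to_ids would typically fall back to
    # the unknown-token id instead of a dedicated image-token id.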
    # Initialize accelerator
    self.accelerator = accelerator
    # Initialize loaders
    self.train_loader = train_loader
    self.val_loader = val_loader
    # Checks
    self._compatibility_checks()
    # Initialize everything related to distributed training
    self._configure_optimizer_and_scheduler()
    # Prepare and/or register model, optimizer, dataloaders and scheduler
    self._prepare_register()
    # Now that we have num_processes, figure out batch-size-related variables
    self.setup_batch_size_related_configs()
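    # setup_batch_size_related_configs presumably resolves the usual relation
    #   global_batch_size = per_device_batch_size * grad_acc_size * num_processes
    # once accelerator.num_processes is known (per_device_batch_size is a
    # stand-in name here for the per-GPU micro-batch setting).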
    # Compute useful variables
    self.optim_param.opt_batch_size = self.hparams.global_batch_size
    if self.hparams.max_num_opt_steps is None and self.hparams.max_num_epochs is None:
        if hasattr(self.train_loader, "__len__") and self.hparams.global_batch_size_ramp_up.start is not None:
            raise ValueError("Currently global batch size ramp up doesn't work with MappedDataset")
        try:
            self.hparams.max_num_opt_steps = int(len(self.train_loader) // self.hparams.grad_acc_size)
        except TypeError:
            raise ValueError("max_num_opt_steps or max_num_epochs must be defined if you use IterableDataset")
    # self._set_model_tflops_per_batch_per_gpu()
    # Init trackers
    self._init_trackers()
    # Handle jz (Jean Zay) time limit and memory
    self.jz_training_time_over = [False]
    self.memory_explosion = False
    # Stopping on demand
    self.kill_switch_activated = False
    # SIGTERM signal listener
    self.sigterm_signal_received = False
    self.sigterm_listener = SigtermListener()
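    # Parameter accounting: walk the unwrapped model once and tally parameter
    # counts into `sizes` (total, trainable, vision_model, perceiver_resampler,
    # modality_projection, plus LoRA-only variants); the result is only used for
    # the informational log below.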
    sizes = defaultdict(int)
    trainable_params = []
    numel_fn = lambda p: p.ds_numel if is_deepspeed_zero_init_enabled() else p.numel()  # noqa
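    # Under DeepSpeed ZeRO-3 initialization, parameters are partitioned across
    # ranks and p.numel() only reflects the local shard (typically 0), so
    # ds_numel is presumably needed here to recover the full parameter count.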
    for name, param in self.accelerator.unwrap_model(self.vl_model).named_parameters():
        numel = numel_fn(param)
        sizes["total"] += numel
        sizes["total_lora"] += numel if "lora_" in name else 0
        if "vision_model" in name:
            sizes["vision_model"] += numel
            sizes["vision_model_lora"] += numel if "lora_" in name else 0
        if "perceiver_resampler" in name:
            sizes["perceiver_resampler"] += numel
        if "modality_projection" in name:
            sizes["modality_projection"] += numel
        if param.requires_grad:
            sizes["trainable"] += numel
            sizes["trainable_lora"] += numel if "lora_" in name else 0
            trainable_params += [name]
    if self.accelerator.is_main_process:
        logger.info(f"""