def __init__()

in vision/m4/training/trainer.py
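
Constructor of the trainer: it stores the parsed config sub-objects, the model, tokenizer, accelerator and data loaders, optionally enables gradient checkpointing and activation tracking, configures the optimizer/scheduler and batch-size settings, initializes trackers and signal handling, and tallies parameter counts for logging.

For orientation, a minimal sketch of how this constructor might be wired up. The helper name `build_trainer`, the import path and the class name `Trainer` are assumptions inferred from the file location; the actual construction happens in the repo's training entry point.

    from accelerate import Accelerator

    # Assumed import path / class name, based on vision/m4/training/trainer.py
    from m4.training.trainer import Trainer

    def build_trainer(vl_model, tokenizer, train_loader, val_loader, config):
        # The Accelerator is created first, then handed to the trainer together with
        # the model, tokenizer, data loaders and the parsed config (a Parameters object).
        accelerator = Accelerator()
        return Trainer(
            accelerator=accelerator,
            vl_model=vl_model,
            tokenizer=tokenizer,
            train_loader=train_loader,
            val_loader=val_loader,
            config=config,
        )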


    def __init__(self, accelerator, vl_model, tokenizer, train_loader, val_loader, config):
        # Initialize params
        self.config: Parameters = config
        self.optim_param: OptimizerParams = config.optim_param
        self.hparams: Hparams = config.hparams
        self.resume_param: ResumeParams = config.resume_param
        self.data_param: DataParams = config.data_param

        # Initialize last step directory
        self.last_opt_step_dir = ""

        # Initialize the model
        self.vl_model = vl_model

        # Gradient checkpointing
        if self.hparams.gradient_checkpointing:
            self.vl_model.gradient_checkpointing_enable()

        # Debug: track activations on the main process when activation logging is enabled
        if accelerator.is_main_process and self.hparams.train_logging_activations:
            self.activation_tracker = ActivationTracker(self.vl_model)
        else:
            self.activation_tracker = None

        # Initialize tokenizer
        self.tokenizer = tokenizer
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

        # Initialize accelerator
        self.accelerator = accelerator

        # Initialize loaders
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Checks
        self._compatibility_checks()

        # Initialize everything related to distributed training
        self._configure_optimizer_and_scheduler()

        # Prepare and/or register model, optimizer, dataloaders and scheduler
        self._prepare_register()

        # Now that we have num_processes, figure out batch_size-related variables
        self.setup_batch_size_related_configs()

        # Derived variable: the optimizer-step batch size equals the global batch size
        self.optim_param.opt_batch_size = self.hparams.global_batch_size

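        # If neither a step budget nor an epoch budget is given, derive the number of
        # optimizer steps from the dataloader length. This only works for map-style
        # datasets (which define __len__); for an IterableDataset, len() raises TypeError
        # and an explicit max_num_opt_steps or max_num_epochs must be provided.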
        if self.hparams.max_num_opt_steps is None and self.hparams.max_num_epochs is None:
            if hasattr(self.train_loader, "__len__") and self.hparams.global_batch_size_ramp_up.start is not None:
                raise ValueError("Currently global batch size ramp up doesn't work with MappedDataset")

            try:
                self.hparams.max_num_opt_steps = int(len(self.train_loader) // self.hparams.grad_acc_size)
            except TypeError:
                raise ValueError("max_num_opt_steps or max_num_epochs must be defined if you use IterableDataset")
        # self._set_model_tflops_per_batch_per_gpu()

        # Init trackers
        self._init_trackers()

        # Handle JZ (Jean Zay) job time limits and memory explosions
        self.jz_training_time_over = [False]
        self.memory_explosion = False

        # Stopping on demand
        self.kill_switch_activated = False

        # Sigterm signal listener
        self.sigterm_signal_received = False
        self.sigterm_listener = SigtermListener()

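        # Parameter accounting for logging: tally total, trainable and LoRA parameter
        # counts, plus per-submodule counts (vision_model, perceiver_resampler,
        # modality_projection). Under DeepSpeed ZeRO-3 init, parameters are partitioned
        # across ranks, so p.ds_numel is used instead of p.numel() to get the full size.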
        sizes = defaultdict(int)
        trainable_params = []
        numel_fn = lambda p: p.ds_numel if is_deepspeed_zero_init_enabled() else p.numel()  # noqa
        for name, param in self.accelerator.unwrap_model(self.vl_model).named_parameters():
            numel = numel_fn(param)
            sizes["total"] += numel
            sizes["total_lora"] += numel if "lora_" in name else 0
            if "vision_model" in name:
                sizes["vision_model"] += numel
                sizes["vision_model_lora"] += numel if "lora_" in name else 0
            if "perceiver_resampler" in name:
                sizes["perceiver_resampler"] += numel
            if "modality_projection" in name:
                sizes["modality_projection"] += numel
            if param.requires_grad:
                sizes["trainable"] += numel
                sizes["trainable_lora"] += numel if "lora_" in name else 0
                trainable_params += [name]

        if self.accelerator.is_main_process:
            logger.info(f"""