def train()

in vision/m4/training/trainer.py [0:0]


    def train(self, maybe_torch_profile_scheduler=None):
        """Run the main training loop and return the accumulated training logs.

        The method iterates over `self.train_loader`, epoch after epoch. For every batch it:

        1. checks whether training is over (possibly saving the model) and stops if so,
        2. toggles the activation-tracking hooks and optionally dumps the batch to disk,
        3. runs `_do_batch` inside `accelerator.accumulate`,
        4. updates the running logs and the datasets states,
        5. optionally logs training metrics and activations,
        6. applies the global-batch-size ramp-up,
        7. checks again whether training is over (possibly saving the model),
        8. optionally runs and logs a validation pass.

        The loop ends either when `_check_if_training_is_over_and_maybe_save_model` reports completion
        or when `curr_epoch` reaches `max_num_epochs` (in which case `_save` is called first).

        When `self.hparams.timing_break_down` is set, every process joins extra
        `wait_for_everyone()` barriers and the main process prints a per-iteration timing summary.

        Args:
            maybe_torch_profile_scheduler: Optional object exposing a `step()` method. When not `None`,
                `step()` is called on the main process at the end of every batch iteration that does not
                break out of the loop.

        Returns:
            The `train_logs` object, as last returned by the logging helpers.
        """
        # timing_break_down = self.accelerator.is_main_process and self.hparams.timing_break_down

        if self.accelerator.is_main_process:
            logger.info(f"** Global main process pid={os.getpid()} **")
        elif self.accelerator.is_local_main_process:
            logger.info(f"** Local main process pid={os.getpid()} **")
        # --------------------
        # Set-up everything needed for training
        # --------------------
        (
            progress_columns,
            train_logs,
            max_num_epochs,
            max_num_updates,
            curr_opt_step,
            curr_epoch,
            opt_step_is_saved,
            eval_is_done,
            gbs_running,
        ) = self._set_up_training()

        # --------------------
        # Training loop
        # --------------------
        self.vl_model.train()
        # Handle passed to `pynvml_get_total_energy_in_joules` below to measure the energy used per batch.
        pynvml_handle = pynmvl_handle(self.accelerator)
        # The live display is created with `disable=True` and the task with `visible=False`; the task object
        # is still advanced at every optimizer step and handed to `_log_training`.
        with Progress(*progress_columns, refresh_per_second=5, disable=True) as progress:
            progress_bar = progress.add_task(
                "[red]Training", disable=not self.accelerator.is_main_process, total=max_num_updates, visible=False
            )
            train_task = progress.tasks[-1]
            # Fast-forward the task by `curr_opt_step` (the value returned by `_set_up_training`).
            progress.update(progress_bar, advance=curr_opt_step)
            finished_training = False
            # Starts as True so that nothing is logged until a first optimizer step resets it to False below.
            training_logged = True
            # `timer` is started here, stopped right after `_do_batch`, and restarted right after `_update_logs`.
            # NOTE(review): the resulting `fwd_bwd_time` therefore appears to also cover everything executed between
            #   those two points (fetching the next batch, logging, saving checks, ...) -- confirm this is intended.
            timer = DeviceAgnosticTimer()
            timer.start()

            # `timer2` and `time_deltas` only exist on the main process; every later use is guarded the same way.
            if self.hparams.timing_break_down:
                self.accelerator.wait_for_everyone()
                if self.accelerator.is_main_process:
                    timer2 = Timer()
                    time_deltas = {}

            # One iteration of this `while` loop corresponds to one pass over `self.train_loader` (one epoch).
            while not finished_training:
                self.train_loader.set_epoch(curr_epoch)

                # Handle resume based on `realtime_processing`
                # NOTE(review): this block sits inside the epoch loop and is not guarded by any "first epoch" flag,
                #   so it runs at the start of every epoch when `resume_run` is set -- confirm this is intended.
                if self.hparams.resume_run:
                    if not self.data_param.realtime_processing:
                        # When preparing, `accelerate` has a bug that overwrites the top-level `sampler`... but this is aesthetic!
                        #   The "actual" sampler that the DataLoader uses is the one that's tucked inside the `batch_sampler`
                        #   attribute when "single-process" (world_size = 1), or the `batch_sampler.batch_sampler` when
                        #   "multi-process" (world_size > 1) -- both of which point to our specialized ResumableSampler!
                        #
                        #   This is pretty annoying and nuanced; should PR into `accelerate` to fix this...
                        logger.warning("NOT realtime processing has not been extensively tested yet")
                        if self.accelerator.num_processes == 1:
                            # TODO; Fix this

                            self.train_loader.batch_sampler.sampler.set_state(self.train_loader.get_resume_state(0))

                        else:
                            # TODO :: This is actually broken and not respected by `accelerate` - fails!
                            # self.train_loader.batch_sampler.batch_sampler.sampler.set_state(
                            #     self.resumable_state.get_resume_state()
                            # )
                            raise NotImplementedError("Map Dataset Resume w/ DDP not yet implemented!")
                    else:
                        self.train_loader.load_resume_states()

                # Start the breakdown timer right before the dataloader iterator is created.
                if self.hparams.timing_break_down:
                    self.accelerator.wait_for_everyone()
                    if self.accelerator.is_main_process:
                        timer2.start()

                # `curr_idx` is the batch index within the current epoch (it restarts at 0 every epoch).
                for curr_idx, (dataset_idx, dataset_name, dataset_state, batch) in enumerate(self.train_loader):
                    # --------------------
                    # Check if the training is over and if so maybe save the model before training batch
                    # --------------------
                    # "dl": recorded right after the batch has been yielded by the dataloader.
                    if self.hparams.timing_break_down:
                        self.accelerator.wait_for_everyone()
                        if self.accelerator.is_main_process:
                            time_deltas["dl"] = timer2.delta()

                    finished_training, opt_step_is_saved = self._check_if_training_is_over_and_maybe_save_model(
                        curr_opt_step,
                        curr_epoch,
                        gbs_running,
                        max_num_updates,
                        train_logs,
                        opt_step_is_saved,
                    )
                    if finished_training:
                        break

                    # --------------------
                    # Activate/deactivate hooks for logging activations or not
                    # --------------------
                    # We are logging everything at `curr_opt_step`, but `curr_opt_step` is incremented a few lines later, so activating
                    # the activation tracking hooks based on `curr_opt_step + 1`. See `_log_activations` for more details.
                    if self.activation_tracker:
                        if (curr_opt_step + 1) % self.hparams.train_logging_activations_opt_steps == 0 and (
                            curr_idx + 1
                        ) % self.hparams.grad_acc_size == 0:
                            self.activation_tracker.activate_hooks()
                        else:
                            self.activation_tracker.deactivate_hooks()

                    # Dump the batch when its in-epoch index falls in the inclusive range
                    # [`save_batch_min_idx`, `save_batch_max_idx`]; both bounds must be set for this to happen.
                    if (
                        self.hparams.save_batch_max_idx is not None
                        and self.hparams.save_batch_min_idx is not None
                        and curr_idx <= self.hparams.save_batch_max_idx
                        and curr_idx >= self.hparams.save_batch_min_idx
                    ):
                        self._save_batch(batch, curr_idx)

                    # right before fwd-bwd-step
                    if self.hparams.timing_break_down:
                        self.accelerator.wait_for_everyone()
                        if self.accelerator.is_main_process:
                            time_deltas["between_dl_fwd_bwd"] = timer2.delta()
                    # Energy reading taken before the batch; the difference is computed after `_do_batch`.
                    total_energy_start = pynvml_get_total_energy_in_joules(pynvml_handle)

                    with self.accelerator.accumulate(self.vl_model):
                        (
                            per_token_loss,
                            z_loss,
                            num_images,
                            num_image_tokens,
                            num_tokens,
                            image_to_text_ratio,
                            num_padding,
                            pixel_values_sum,
                            tflops_per_batch_per_gpu,
                        ) = self._do_batch(
                            batch,
                            curr_opt_step=curr_opt_step,
                            dataset_name=dataset_name,
                            dataset_idx=dataset_idx,
                            validation=False,
                        )

                    # right after fwd-bwd-step
                    if self.hparams.timing_break_down:
                        self.accelerator.wait_for_everyone()
                        if self.accelerator.is_main_process:
                            time_deltas["fwd-bwd-step"] = timer2.delta()

                    # Both measurements are wrapped into tensors on the accelerator device before being passed
                    # to `_update_logs`.
                    fwd_bwd_time = timer.stop()
                    fwd_bwd_time = torch.tensor(fwd_bwd_time, device=self.accelerator.device)

                    total_energy_delta_per_gpu = torch.tensor(
                        pynvml_get_total_energy_in_joules(pynvml_handle) - total_energy_start,
                        device=self.accelerator.device,
                    )

                    # An optimizer step is counted once every `grad_acc_size` batches of the current epoch.
                    # The per-step flags are reset so that saving, validation and logging can happen again.
                    if (curr_idx + 1) % self.hparams.grad_acc_size == 0:
                        curr_opt_step += 1
                        opt_step_is_saved, eval_is_done, training_logged = False, False, False
                        progress.update(progress_bar, advance=1)

                    # --------------------
                    # Update logs
                    # --------------------
                    train_logs = self._update_logs(
                        curr_opt_step,
                        curr_epoch,
                        gbs_running.global_batch_size_current,
                        train_logs,
                        per_token_loss,
                        z_loss,
                        num_tokens,
                        num_images,
                        num_image_tokens,
                        image_to_text_ratio,
                        num_padding,
                        fwd_bwd_time,
                        pixel_values_sum,
                        tflops_per_batch_per_gpu,
                        total_energy_delta_per_gpu,
                        dataset_name,
                        self.hparams.train_logging_per_dataset_suffix,
                    )
                    # Fresh timer for the next measurement (stopped after the next `_do_batch`).
                    timer = DeviceAgnosticTimer()
                    timer.start()
                    # --------------------
                    # Update datasets states
                    # --------------------
                    self._update_datasets_states(dataset_idx, dataset_state)

                    # --------------------
                    # Log training infos
                    # --------------------
                    # `training_logged` prevents logging the same optimizer step once per accumulation micro-batch.
                    if curr_opt_step % self.hparams.train_logging_opt_steps == 0 and not training_logged:
                        train_logs = self._log_training(curr_opt_step, train_task, train_logs)
                        training_logged = True

                    # --------------------
                    # Log activations
                    # --------------------
                    if self.activation_tracker:
                        batch_idx = train_logs["num_batches"]["all"]
                        self.activation_tracker.fill_in_batch_idx(batch_idx=batch_idx)
                        if curr_opt_step % self.hparams.train_logging_activations_opt_steps == 0:
                            self._log_activations(curr_opt_step=curr_opt_step)

                    # ---------------------------
                    # Global batch size ramp up
                    # ---------------------------
                    #
                    # This logic needs to happen after the batch has been processed and results
                    # logged, but before the model is saved for resume, so that the updated ramp-up
                    # variables will have the correct values on resume
                    gbs_running.global_seen_samples += self.accelerator.num_processes * self.hparams.batch_size_per_gpu
                    # Ramp-up triggers only when it is configured (`start` is set), the target (`finish`) has not
                    # been reached yet, and the number of seen samples has reached the next goal.
                    if (
                        self.hparams.global_batch_size_ramp_up.start is not None
                        and self.hparams.global_batch_size_ramp_up.finish > gbs_running.global_batch_size_current
                        and gbs_running.global_seen_samples >= gbs_running.next_goal_samples
                    ):
                        gbs_running.next_goal_samples += self.hparams.global_batch_size_ramp_up.samples

                        gbs_running.global_batch_size_current += self.hparams.global_batch_size_ramp_up.increment
                        # New grad-accumulation size = global batch size / (per-GPU batch size * number of
                        # processes), truncated to an int.
                        gbs_running.grad_acc_size_current = int(
                            gbs_running.global_batch_size_current
                            / (self.hparams.batch_size_per_gpu * self.accelerator.num_processes)
                        )

                        self.update_gas_and_gbs(
                            gbs_running.grad_acc_size_current, gbs_running.global_batch_size_current
                        )

                    # --------------------
                    # Check if the training is over and if so maybe save the model before validation
                    # --------------------
                    finished_training, opt_step_is_saved = self._check_if_training_is_over_and_maybe_save_model(
                        curr_opt_step,
                        curr_epoch,
                        gbs_running,
                        max_num_updates,
                        train_logs,
                        opt_step_is_saved,
                    )

                    if finished_training:
                        break

                    # "post_fwd" and "iteration" are recorded before validation, so validation time is excluded.
                    if self.hparams.timing_break_down:
                        self.accelerator.wait_for_everyone()
                        if self.accelerator.is_main_process:
                            time_deltas["post_fwd"] = timer2.delta()
                            time_deltas["iteration"] = timer2.elapsed()

                    # --------------------
                    # Validation loop
                    # --------------------
                    # Runs at most once per optimizer step (`eval_is_done`), never at step 0, and only every
                    # `val_logging_opt_steps` optimizer steps.
                    if (
                        self.config.hparams.do_validation
                        and not eval_is_done
                        and curr_opt_step != 0
                        and curr_opt_step % self.hparams.val_logging_opt_steps == 0
                    ):
                        if self.hparams.timing_break_down:
                            self.accelerator.wait_for_everyone()
                            if self.accelerator.is_main_process:
                                timer3 = Timer()
                                timer3.start()

                        gc.collect()
                        # wandb watching of `self.dummy_module` is removed for the duration of validation and
                        # re-installed afterwards.
                        if self.accelerator.is_main_process and self.hparams.wandb_enable:
                            wandb.unwatch(self.dummy_module)
                        if self.activation_tracker:
                            self.activation_tracker.is_eval()
                        logger.info("** Starting validation **")
                        (
                            val_steps,
                            val_per_token_loss_acc,
                            val_num_images,
                            val_num_tokens,
                            val_image_to_text_ratio,
                            val_num_padding,
                        ) = self._do_validation(progress, curr_opt_step)

                        # --------------------
                        # Log validation infos
                        # --------------------
                        self._log_validation(
                            val_steps,
                            curr_opt_step,
                            val_per_token_loss_acc,
                            val_num_images,
                            val_num_tokens,
                            val_image_to_text_ratio,
                            val_num_padding,
                        )
                        eval_is_done = True
                        logger.info("** Finished validation **")
                        if self.accelerator.is_main_process and self.hparams.wandb_enable:
                            wandb.watch(
                                self.dummy_module,
                                log="all",
                                log_freq=self.hparams.wandb_log_freq * self.hparams.grad_acc_size,
                                idx=0,
                            )
                        if self.activation_tracker:
                            self.activation_tracker.is_train()
                        gc.collect()
                        self.accelerator.wait_for_everyone()

                        if self.hparams.timing_break_down:
                            self.accelerator.wait_for_everyone()
                            if self.accelerator.is_main_process:
                                print(f"[TIME] Validation: {format_secs_to_time(timer3.stop())}")

                        # restart timer from zero to avoid accounting for validation
                        timer = DeviceAgnosticTimer()
                        timer.start()

                    # Print the per-iteration timing summary (main process only), then restart `timer2` so the
                    # next "dl" measurement starts from the next dataloader fetch.
                    if self.hparams.timing_break_down:
                        self.accelerator.wait_for_everyone()
                        if self.accelerator.is_main_process:
                            # Finalize
                            print(
                                f'[TIME] Iteration {format_secs_to_sec_fractions(time_deltas["iteration"])}:'
                                f' {format_secs_to_sec_fractions(time_deltas["dl"]):>6} dl |'
                                f' {format_secs_to_sec_fractions(time_deltas["between_dl_fwd_bwd"]):>6} between'
                                f' dl/fwd-bwd | {format_secs_to_sec_fractions(time_deltas["fwd-bwd-step"]):>6} fwd/bwd'
                                f' | {format_secs_to_sec_fractions(time_deltas["post_fwd"]):>6} post'
                            )

                            # restart for __iter__
                            timer2.stop()
                            timer2.start()

                    # Advance the optional profiler schedule once per batch, on the main process only.
                    if maybe_torch_profile_scheduler is not None and self.accelerator.is_main_process:
                        maybe_torch_profile_scheduler.step()

                # The dataloader was exhausted without training being finished: move on to the next epoch.
                if not finished_training:
                    curr_epoch += 1
                    train_logs = self._end_of_epoch_reset_train_logs(train_logs)

                    self.train_loader.reset_state()

                # Stop (after saving) once the epoch counter equals `max_num_epochs`.
                if curr_epoch == max_num_epochs:
                    self._save(
                        train_logs,
                        curr_opt_step,
                        curr_epoch,
                        gbs_running,
                    )
                    finished_training = True
                    logger.info("** Maximum number of epochs has been reached **")
                    break

            if self.hparams.wandb_enable:
                self.accelerator.end_training()

        return train_logs