def _set_up_training()

in vision/m4/training/trainer.py [0:0]


    def _set_up_training(self):
        """
        Prepare the variables needed before the training loop starts.

        Two paths are handled:

        * Resume (`self.hparams.resume_run` is truthy): the train logs, step/epoch
          counters and global-batch-size ramp-up state are read from
          `self.resume_param`, the train loader state is loaded from
          `<opt_step_dir>/resumable_states`, and the accelerator state is loaded
          (without the optimizer states when `self.hparams.load_optimizer_states`
          is falsy).
        * First run: counters start at 0, fresh train logs and ramp-up parameters
          are created, and the train loader state is reset.

        Side effects: calls `self.update_gas_and_gbs(...)` with the current gradient
        accumulation size and global batch size, and sets
        `self.max_num_updates_this_run`.

        Returns:
            A 9-tuple `(progress_columns, train_logs, max_num_epochs, max_num_updates,
            curr_opt_step, curr_epoch, opt_step_is_saved, eval_is_done, gbs_running)`.
        """

        if self.hparams.resume_run:
            # 1. resume
            train_logs = self.resume_param.train_logs

            curr_opt_step = self.resume_param.resume_opt_step
            curr_epoch = self.resume_param.resume_epoch
            gbs_running = self.resume_param.gbs_running

            # This check is necessary because the info is saved as json in the checkpoint
            # and when it is loaded back it is converted to a normal dictionary which can
            # fail downstream in case one of the dataset keys were missing in the saved info
            train_logs = self._check_default_dict_in_train_logs(train_logs)
            self.train_loader.load_state(self.resume_param.opt_step_dir / "resumable_states")

            if self.hparams.load_optimizer_states:
                self.accelerator.load_state(self.resume_param.accelerator_state_dir)
            else:
                # don't load the optimizer states and start with a fresh optimizer
                self.accelerator.load_state(self.resume_param.accelerator_state_dir, load_optimizer_states=False)
                validate_optim_states_are_reset(self)

            self.accelerator.wait_for_everyone()

            # The step we resume from is flagged as already saved and already evaluated.
            opt_step_is_saved = True
            eval_is_done = True

        else:
            # 2. non-resume (first run)
            train_logs = self._reset_train_logs(None)
            curr_opt_step = 0
            curr_epoch = 0
            opt_step_is_saved = False
            eval_is_done = False

            gbs_running = GlobalBatchSizeRampUpRunningParams(
                global_seen_samples=0,
                global_batch_size_current=self.hparams.global_batch_size,
                next_goal_samples=self.hparams.global_batch_size_ramp_up.samples,
                grad_acc_size_current=self.hparams.grad_acc_size,
            )

            self.train_loader.reset_state()

        self.update_gas_and_gbs(gbs_running.grad_acc_size_current, gbs_running.global_batch_size_current)

        max_num_epochs = self.hparams.max_num_epochs
        try:
            # NOTE(review): this uses `hparams.grad_acc_size`, not
            # `gbs_running.grad_acc_size_current` - confirm that is intended when
            # resuming a run whose gradient accumulation size has changed.
            num_batches = int(len(self.train_loader) // self.hparams.grad_acc_size)
            max_num_updates = min(self.hparams.max_num_opt_steps, num_batches)
            if max_num_epochs is not None:
                # NOTE(review): `max_num_updates` is already capped at `num_batches`
                # above, so this epoch cap can only bind when `max_num_epochs < 1` -
                # confirm this is the intended budget computation.
                epoch_cap = max_num_epochs * num_batches
                if epoch_cap < max_num_updates:
                    # Only announce the override when the epoch cap actually applies.
                    logger.info(
                        "** Setting `max_num_updates` to `max_num_epochs * num_batches` since `max_num_epochs` "
                        "was specified and `max_num_epochs * num_batches` is smaller than `max_num_updates`. **"
                    )
                    max_num_updates = epoch_cap
        except TypeError:
            # For iterable datasets len(dataset) is not defined
            max_num_updates = self.hparams.max_num_opt_steps

        # Optionally bound the number of optimizer steps performed by this run,
        # counted from the step we start (or resume) at.
        if self.hparams.max_num_opt_steps_this_run is not None:
            self.max_num_updates_this_run = min(
                max_num_updates, curr_opt_step + self.hparams.max_num_opt_steps_this_run
            )
        else:
            self.max_num_updates_this_run = max_num_updates

        progress_columns = (
            "[progress.description]{task.description}",
            BarColumn(),
            TaskProgressColumn(),
            "Time Elapsed:",
            TimeElapsedColumn(),
            "Steps Completed",
            MofNCompleteColumn(),
        )
        return (
            progress_columns,
            train_logs,
            max_num_epochs,
            max_num_updates,
            curr_opt_step,
            curr_epoch,
            opt_step_is_saved,
            eval_is_done,
            gbs_running,
        )