# vision/m4/training/trainer.py
def _set_up_training(self):
    """
    Prepare variables for training.
    """
    if self.hparams.resume_run:
        # 1. resume
        train_logs = self.resume_param.train_logs
        curr_opt_step = self.resume_param.resume_opt_step
        curr_epoch = self.resume_param.resume_epoch
        gbs_running = self.resume_param.gbs_running

        # This check is necessary because the info is saved as json in the checkpoint
        # and when it is loaded back it is converted to a normal dictionary, which can
        # fail downstream if one of the dataset keys was missing in the saved info.
        train_logs = self._check_default_dict_in_train_logs(train_logs)
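
        # Restore the dataloader position so training resumes where the checkpoint left off.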
        self.train_loader.load_state(self.resume_param.opt_step_dir / "resumable_states")

        if self.hparams.load_optimizer_states:
            self.accelerator.load_state(self.resume_param.accelerator_state_dir)
        else:
            # don't load the optimizer states and start with a fresh optimizer
            self.accelerator.load_state(self.resume_param.accelerator_state_dir, load_optimizer_states=False)
            validate_optim_states_are_reset(self)

        self.accelerator.wait_for_everyone()
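
        # The checkpoint we resumed from was already saved and evaluated at this step,
        # so don't redo either until the next opt step.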
        opt_step_is_saved = True
        eval_is_done = True
    else:
        # 2. non-resume (first run)
        train_logs = self._reset_train_logs(None)
        curr_opt_step = 0
        curr_epoch = 0
        opt_step_is_saved = False
        eval_is_done = False
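
        # Start the global-batch-size ramp-up schedule from zero seen samples,
        # beginning at the configured base global batch size and grad-acc size.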
        gbs_running = GlobalBatchSizeRampUpRunningParams(
            global_seen_samples=0,
            global_batch_size_current=self.hparams.global_batch_size,
            next_goal_samples=self.hparams.global_batch_size_ramp_up.samples,
            grad_acc_size_current=self.hparams.grad_acc_size,
        )
        self.train_loader.reset_state()

    # rng = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(self.hparams.seed)))
    # self.main_rng_seed = rng.get_state()
    # self.main_rng_seed = rng.RandomState.get_state()
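
    # Apply the current (possibly ramped-up) grad accumulation and global batch sizes.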
    self.update_gas_and_gbs(gbs_running.grad_acc_size_current, gbs_running.global_batch_size_current)

    max_num_epochs = self.hparams.max_num_epochs
    try:
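        # One optimizer update consumes `grad_acc_size` dataloader batches,
        # so the number of updates per epoch is len(loader) // grad_acc_size.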
        num_batches = int(len(self.train_loader) // self.hparams.grad_acc_size)
        max_num_updates = min(self.hparams.max_num_opt_steps, num_batches)
        if max_num_epochs is not None:
            if max_num_epochs * num_batches < max_num_updates:
                logger.info(
                    "** Setting `max_num_updates` to `max_num_epochs * num_batches` since `max_num_epochs` "
                    "was specified and `max_num_epochs * num_batches` is smaller than `max_num_updates`. **"
                )
            max_num_updates = min(max_num_updates, max_num_epochs * num_batches)
    except TypeError:
        # For iterable datasets len(dataset) is not defined
        max_num_updates = self.hparams.max_num_opt_steps
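
    # A single job may be limited to fewer updates than the full training run
    # (e.g. when training is split across several time-limited jobs).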
    if self.hparams.max_num_opt_steps_this_run is not None:
        self.max_num_updates_this_run = min(
            max_num_updates, curr_opt_step + self.hparams.max_num_opt_steps_this_run
        )
    else:
        self.max_num_updates_this_run = max_num_updates
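
    # Columns for the rich progress bar displayed during training.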
    progress_columns = (
        "[progress.description]{task.description}",
        BarColumn(),
        TaskProgressColumn(),
        "Time Elapsed:",
        TimeElapsedColumn(),
        "Steps Completed",
        MofNCompleteColumn(),
    )

    return (
        progress_columns,
        train_logs,
        max_num_epochs,
        max_num_updates,
        curr_opt_step,
        curr_epoch,
        opt_step_is_saved,
        eval_is_done,
        gbs_running,
    )
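
# A minimal sketch (not part of trainer.py) of how the tuple returned above might be
# consumed by a training loop. `Progress` is rich.progress.Progress (the column
# objects above come from rich.progress); `trainer` and the loop body are
# hypothetical stand-ins, not names from this repo.
#
#     from rich.progress import Progress
#
#     (
#         progress_columns,
#         train_logs,
#         max_num_epochs,
#         max_num_updates,
#         curr_opt_step,
#         curr_epoch,
#         opt_step_is_saved,
#         eval_is_done,
#         gbs_running,
#     ) = trainer._set_up_training()
#
#     with Progress(*progress_columns) as progress:
#         task = progress.add_task("Training", total=trainer.max_num_updates_this_run)
#         while curr_opt_step < trainer.max_num_updates_this_run:
#             ...  # one optimizer update
#             curr_opt_step += 1
#             progress.update(task, advance=1)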