# vision/m4/training/trainer.py
def train(self, maybe_torch_profile_scheduler=None):
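    """Run the main training loop.

    Iterates over the training dataloader with gradient accumulation, optional global-batch-size
    ramp-up, periodic logging, activation tracking, validation and checkpointing, until the
    maximum number of updates or epochs is reached.

    Args:
        maybe_torch_profile_scheduler: optional torch profiler whose schedule is stepped once per
            dataloader iteration on the main process.

    Returns:
        train_logs: the accumulated training logs.
    """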
    # timing_break_down = self.accelerator.is_main_process and self.hparams.timing_break_down
    if self.accelerator.is_main_process:
        logger.info(f"** Global main process pid={os.getpid()} **")
    elif self.accelerator.is_local_main_process:
        logger.info(f"** Local main process pid={os.getpid()} **")

    # --------------------
    # Set-up everything needed for training
    # --------------------
    (
        progress_columns,
        train_logs,
        max_num_epochs,
        max_num_updates,
        curr_opt_step,
        curr_epoch,
        opt_step_is_saved,
        eval_is_done,
        gbs_running,
    ) = self._set_up_training()

    # --------------------
    # Training loop
    # --------------------
    self.vl_model.train()
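
    # NVML handle used to read the total GPU energy right before and after each fwd-bwd step,
    # so that a per-batch energy delta can be logged.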
    pynvml_handle = pynmvl_handle(self.accelerator)

    with Progress(*progress_columns, refresh_per_second=5, disable=True) as progress:
        progress_bar = progress.add_task(
            "[red]Training", disable=not self.accelerator.is_main_process, total=max_num_updates, visible=False
        )
        train_task = progress.tasks[-1]
        progress.update(progress_bar, advance=curr_opt_step)
        finished_training = False
        training_logged = True

        timer = DeviceAgnosticTimer()
        timer.start()
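
        # `timer` measures the fwd-bwd wall time on every rank; when `timing_break_down` is set,
        # `timer2`/`time_deltas` additionally give the main process a per-iteration breakdown
        # (dataloading, fwd/bwd, post-processing) that is printed at the end of each iteration.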
        if self.hparams.timing_break_down:
            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process:
                timer2 = Timer()
                time_deltas = {}

        while not finished_training:
            self.train_loader.set_epoch(curr_epoch)

            # Handle resume based on `realtime_processing`
            if self.hparams.resume_run:
                if not self.data_param.realtime_processing:
                    # When preparing, `accelerate` has a bug that overwrites the top-level `sampler`... but this is aesthetic!
                    # The "actual" sampler that the DataLoader uses is the one that's tucked inside the `batch_sampler`
                    # attribute when "single-process" (world_size = 1), or the `batch_sampler.batch_sampler` when
                    # "multi-process" (world_size > 1) -- both of which point to our specialized ResumableSampler!
                    #
                    # This is pretty annoying and nuanced; we should PR a fix into `accelerate`...
                    logger.warning("Non-realtime processing has not been extensively tested yet")
                    if self.accelerator.num_processes == 1:
                        # TODO: Fix this
                        self.train_loader.batch_sampler.sampler.set_state(self.train_loader.get_resume_state(0))
                    else:
                        # TODO :: This is actually broken and not respected by `accelerate` - fails!
                        # self.train_loader.batch_sampler.batch_sampler.sampler.set_state(
                        #     self.resumable_state.get_resume_state()
                        # )
                        raise NotImplementedError("Map Dataset Resume w/ DDP not yet implemented!")
                else:
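                    # Realtime (iterable) processing: the dataloader restores its saved resume states directly.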
                    self.train_loader.load_resume_states()

            if self.hparams.timing_break_down:
                self.accelerator.wait_for_everyone()
                if self.accelerator.is_main_process:
                    timer2.start()
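
            # Each iteration yields a micro-batch together with the index/name of the dataset it
            # came from and that dataset's resumable state.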
            for curr_idx, (dataset_idx, dataset_name, dataset_state, batch) in enumerate(self.train_loader):
                # --------------------
                # Check if training is over and, if so, maybe save the model before processing the batch
                # --------------------
                if self.hparams.timing_break_down:
                    self.accelerator.wait_for_everyone()
                    if self.accelerator.is_main_process:
                        time_deltas["dl"] = timer2.delta()
                finished_training, opt_step_is_saved = self._check_if_training_is_over_and_maybe_save_model(
                    curr_opt_step,
                    curr_epoch,
                    gbs_running,
                    max_num_updates,
                    train_logs,
                    opt_step_is_saved,
                )
                if finished_training:
                    break
                # --------------------
                # Activate/deactivate hooks for logging activations
                # --------------------
                # Everything is logged at `curr_opt_step`, but `curr_opt_step` is only incremented a few lines below,
                # so the activation-tracking hooks are activated based on `curr_opt_step + 1`. See `_log_activations`
                # for more details.
                if self.activation_tracker:
                    if (curr_opt_step + 1) % self.hparams.train_logging_activations_opt_steps == 0 and (
                        curr_idx + 1
                    ) % self.hparams.grad_acc_size == 0:
                        self.activation_tracker.activate_hooks()
                    else:
                        self.activation_tracker.deactivate_hooks()
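
                # Optionally dump raw batches whose index falls within
                # [save_batch_min_idx, save_batch_max_idx] (e.g. for offline inspection).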
                if (
                    self.hparams.save_batch_max_idx is not None
                    and self.hparams.save_batch_min_idx is not None
                    and curr_idx <= self.hparams.save_batch_max_idx
                    and curr_idx >= self.hparams.save_batch_min_idx
                ):
                    self._save_batch(batch, curr_idx)

                # right before fwd-bwd-step
                if self.hparams.timing_break_down:
                    self.accelerator.wait_for_everyone()
                    if self.accelerator.is_main_process:
                        time_deltas["between_dl_fwd_bwd"] = timer2.delta()

                total_energy_start = pynvml_get_total_energy_in_joules(pynvml_handle)
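
                # `accelerator.accumulate` skips gradient synchronization on all but the last
                # micro-batch of each accumulation window; `_do_batch` performs the fwd-bwd-step itself.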
                with self.accelerator.accumulate(self.vl_model):
                    (
                        per_token_loss,
                        z_loss,
                        num_images,
                        num_image_tokens,
                        num_tokens,
                        image_to_text_ratio,
                        num_padding,
                        pixel_values_sum,
                        tflops_per_batch_per_gpu,
                    ) = self._do_batch(
                        batch,
                        curr_opt_step=curr_opt_step,
                        dataset_name=dataset_name,
                        dataset_idx=dataset_idx,
                        validation=False,
                    )

                # right after fwd-bwd-step
                if self.hparams.timing_break_down:
                    self.accelerator.wait_for_everyone()
                    if self.accelerator.is_main_process:
                        time_deltas["fwd-bwd-step"] = timer2.delta()

                fwd_bwd_time = timer.stop()
                fwd_bwd_time = torch.tensor(fwd_bwd_time, device=self.accelerator.device)

                total_energy_delta_per_gpu = torch.tensor(
                    pynvml_get_total_energy_in_joules(pynvml_handle) - total_energy_start,
                    device=self.accelerator.device,
                )
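
                # A full optimizer step has been accumulated every `grad_acc_size` micro-batches:
                # bump the step counter and reset the per-step bookkeeping flags.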
                if (curr_idx + 1) % self.hparams.grad_acc_size == 0:
                    curr_opt_step += 1
                    opt_step_is_saved, eval_is_done, training_logged = False, False, False
                    progress.update(progress_bar, advance=1)

                # --------------------
                # Update logs
                # --------------------
                train_logs = self._update_logs(
                    curr_opt_step,
                    curr_epoch,
                    gbs_running.global_batch_size_current,
                    train_logs,
                    per_token_loss,
                    z_loss,
                    num_tokens,
                    num_images,
                    num_image_tokens,
                    image_to_text_ratio,
                    num_padding,
                    fwd_bwd_time,
                    pixel_values_sum,
                    tflops_per_batch_per_gpu,
                    total_energy_delta_per_gpu,
                    dataset_name,
                    self.hparams.train_logging_per_dataset_suffix,
                )
                timer = DeviceAgnosticTimer()
                timer.start()

                # --------------------
                # Update dataset states
                # --------------------
                self._update_datasets_states(dataset_idx, dataset_state)

                # --------------------
                # Log training infos
                # --------------------
                if curr_opt_step % self.hparams.train_logging_opt_steps == 0 and not training_logged:
                    train_logs = self._log_training(curr_opt_step, train_task, train_logs)
                    training_logged = True

                # --------------------
                # Log activations
                # --------------------
                if self.activation_tracker:
                    batch_idx = train_logs["num_batches"]["all"]
                    self.activation_tracker.fill_in_batch_idx(batch_idx=batch_idx)
                    if curr_opt_step % self.hparams.train_logging_activations_opt_steps == 0:
                        self._log_activations(curr_opt_step=curr_opt_step)
                # ---------------------------
                # Global batch size ramp-up
                # ---------------------------
                #
                # This logic needs to happen after the batch has been processed and the results
                # logged, but before the model is saved for resume, so that the updated ramp-up
                # variables have the correct values on resume.
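                #
                # Illustration (hypothetical numbers): with batch_size_per_gpu=4, 64 processes and
                # global_batch_size_current=2048, grad_acc_size_current = 2048 / (4 * 64) = 8.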
                gbs_running.global_seen_samples += self.accelerator.num_processes * self.hparams.batch_size_per_gpu
                if (
                    self.hparams.global_batch_size_ramp_up.start is not None
                    and self.hparams.global_batch_size_ramp_up.finish > gbs_running.global_batch_size_current
                    and gbs_running.global_seen_samples >= gbs_running.next_goal_samples
                ):
                    gbs_running.next_goal_samples += self.hparams.global_batch_size_ramp_up.samples
                    gbs_running.global_batch_size_current += self.hparams.global_batch_size_ramp_up.increment
                    gbs_running.grad_acc_size_current = int(
                        gbs_running.global_batch_size_current
                        / (self.hparams.batch_size_per_gpu * self.accelerator.num_processes)
                    )
                    self.update_gas_and_gbs(
                        gbs_running.grad_acc_size_current, gbs_running.global_batch_size_current
                    )
                # --------------------
                # Check if training is over and, if so, maybe save the model before validation
                # --------------------
                finished_training, opt_step_is_saved = self._check_if_training_is_over_and_maybe_save_model(
                    curr_opt_step,
                    curr_epoch,
                    gbs_running,
                    max_num_updates,
                    train_logs,
                    opt_step_is_saved,
                )
                if finished_training:
                    break
                if self.hparams.timing_break_down:
                    self.accelerator.wait_for_everyone()
                    if self.accelerator.is_main_process:
                        time_deltas["post_fwd"] = timer2.delta()
                        time_deltas["iteration"] = timer2.elapsed()

                # --------------------
                # Validation loop
                # --------------------
                if (
                    self.config.hparams.do_validation
                    and not eval_is_done
                    and curr_opt_step != 0
                    and curr_opt_step % self.hparams.val_logging_opt_steps == 0
                ):
                    if self.hparams.timing_break_down:
                        self.accelerator.wait_for_everyone()
                        if self.accelerator.is_main_process:
                            timer3 = Timer()
                            timer3.start()

                    gc.collect()
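                    # Detach the wandb parameter/gradient watcher for the duration of validation;
                    # it is re-attached below once validation is done.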
                    if self.accelerator.is_main_process and self.hparams.wandb_enable:
                        wandb.unwatch(self.dummy_module)

                    if self.activation_tracker:
                        self.activation_tracker.is_eval()
                    logger.info("** Starting validation **")
                    (
                        val_steps,
                        val_per_token_loss_acc,
                        val_num_images,
                        val_num_tokens,
                        val_image_to_text_ratio,
                        val_num_padding,
                    ) = self._do_validation(progress, curr_opt_step)

                    # --------------------
                    # Log validation infos
                    # --------------------
                    self._log_validation(
                        val_steps,
                        curr_opt_step,
                        val_per_token_loss_acc,
                        val_num_images,
                        val_num_tokens,
                        val_image_to_text_ratio,
                        val_num_padding,
                    )
                    eval_is_done = True
                    logger.info("** Finished validation **")
                    if self.accelerator.is_main_process and self.hparams.wandb_enable:
                        wandb.watch(
                            self.dummy_module,
                            log="all",
                            log_freq=self.hparams.wandb_log_freq * self.hparams.grad_acc_size,
                            idx=0,
                        )
                    if self.activation_tracker:
                        self.activation_tracker.is_train()
                    gc.collect()

                    self.accelerator.wait_for_everyone()

                    if self.hparams.timing_break_down:
                        self.accelerator.wait_for_everyone()
                        if self.accelerator.is_main_process:
                            print(f"[TIME] Validation: {format_secs_to_time(timer3.stop())}")

                    # restart timer from zero to avoid accounting for validation
                    timer = DeviceAgnosticTimer()
                    timer.start()
                if self.hparams.timing_break_down:
                    self.accelerator.wait_for_everyone()
                    if self.accelerator.is_main_process:
                        # Finalize
                        print(
                            f'[TIME] Iteration {format_secs_to_sec_fractions(time_deltas["iteration"])}:'
                            f' {format_secs_to_sec_fractions(time_deltas["dl"]):>6} dl |'
                            f' {format_secs_to_sec_fractions(time_deltas["between_dl_fwd_bwd"]):>6} between'
                            f' dl/fwd-bwd | {format_secs_to_sec_fractions(time_deltas["fwd-bwd-step"]):>6} fwd/bwd'
                            f' | {format_secs_to_sec_fractions(time_deltas["post_fwd"]):>6} post'
                        )

                        # restart for __iter__
                        timer2.stop()
                        timer2.start()
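
                # Advance the optional torch profiler schedule once per iteration (main process only).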
                if maybe_torch_profile_scheduler is not None and self.accelerator.is_main_process:
                    maybe_torch_profile_scheduler.step()
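
            # The dataloader is exhausted: unless training just finished, start a new epoch and
            # reset the per-epoch logs and the dataloader's resumable state.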
            if not finished_training:
                curr_epoch += 1
                train_logs = self._end_of_epoch_reset_train_logs(train_logs)
                self.train_loader.reset_state()

                if curr_epoch == max_num_epochs:
                    self._save(
                        train_logs,
                        curr_opt_step,
                        curr_epoch,
                        gbs_running,
                    )
                    finished_training = True
                    logger.info("** Maximum number of epochs has been reached **")
                    break
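
    # Let the trackers (wandb) know that training is complete.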
    if self.hparams.wandb_enable:
        self.accelerator.end_training()

    return train_logs