in vissl/trainer/train_steps/standard_train_step.py [0:0]
def standard_train_step(task):
"""
Single training iteration loop of the model.
Performs: data read, forward, loss computation, backward, optimizer step, parameter updates.
Various intermediate steps are also performed:
- logging the training loss, training eta, LR, etc to loggers
- logging to tensorboard,
- performing any self-supervised method specific operations (like in MoCo approach, the
momentum encoder is updated), computing the scores in swav
- checkpointing model if user wants to checkpoint in the middle
of an epoch
"""
    assert isinstance(task, ClassyTask), "task is not instance of ClassyTask"

    # reset the last batch info at every step
    task.last_batch = LastBatchInfo()

    # Time train_step and some of its sections, accumulating the values
    # into perf_stats.
    perf_stats = task.perf_stats
    timer_train_step = PerfTimer("train_step_total", perf_stats)
    timer_train_step.start()

    # Process next sample
    with PerfTimer("read_sample", perf_stats):
        sample = next(task.data_iterator)

    sample = construct_sample_for_model(sample, task)
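    # `sample` is now a dict exposing the "input" and "target" entries
    # consumed by the model and the loss below.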

    # Only need gradients during training
    grad_context = torch.enable_grad() if task.train else torch.no_grad()
    ddp_context = (
        task.model.no_sync()
        if task.enable_manual_gradient_reduction
        else contextlib.suppress()
    )
    torch_amp_context = (
        torch.cuda.amp.autocast()
        if task.amp_type == AmpType.PYTORCH
        else contextlib.suppress()
    )
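    # NOTE: contextlib.suppress() with no arguments suppresses nothing, so it
    # serves as a no-op context manager when DDP no_sync / autocast is not
    # needed.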

    with grad_context, ddp_context, torch_amp_context:
        # Forward pass of the model
        with PerfTimer("forward", perf_stats), record_function("forward"):
            if task.enable_manual_gradient_reduction:
                # Manually sync params and buffers for DDP.
                manual_sync_params(task.model)
            model_output = task.model(sample["input"])

        # If the model outputs only one tensor, we take it out of the list.
        if len(model_output) == 1:
            model_output = model_output[0]
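
        # Cache the current batch and its outputs on task.last_batch so that
        # the hooks run below can access them.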
        task.last_batch.sample = sample
        task.last_batch.model_output = model_output
        target = sample["target"]

        # Run hooks on forward pass
        task.run_hooks(SSLClassyHookFunctions.on_forward.name)

        # Compute loss
        with PerfTimer("loss_compute", perf_stats), record_function("loss_compute"):
            local_loss = task.loss(model_output, target)

        # Reduce the loss value across all nodes and gpus.
        with PerfTimer("loss_all_reduce", perf_stats):
            loss = local_loss.detach().clone()
            task.last_batch.loss = all_reduce_mean(loss)
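
        # Book-keeping: record the globally averaged loss, weighted by the
        # local batch size.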
        task.losses.append(task.last_batch.loss.data.cpu().item() * target.size(0))

        # Update meters
        if len(task.meters) > 0 and (
            (task.train and task.config["METERS"]["enable_training_meter"])
            or (not task.train)
        ):
            with PerfTimer("meters_update", perf_stats):
                if isinstance(model_output, list):
                    model_output_cpu = [x.cpu() for x in model_output]
                else:
                    model_output_cpu = model_output.cpu()
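
                # Meters consume CPU copies of the model outputs and targets.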
                for meter in task.meters:
                    meter.update(model_output_cpu, target.detach().cpu())

        task.last_batch.model_output = model_output
        task.last_batch.target = target

        # Update the iteration number, check that the loss is not NaN, and
        # measure the batch time now if this is a test phase (test phases
        # have no update step).
        task.run_hooks(SSLClassyHookFunctions.on_loss_and_meter.name)

    # Run backward now and update the optimizer
    if task.train:
        with PerfTimer("backward", perf_stats), record_function("backward"):
            task.optimizer.zero_grad()
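
            # Three cases below: Apex AMP, PyTorch native AMP, and training
            # without AMP.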
            if task.amp_type == AmpType.APEX:
                with apex.amp.scale_loss(
                    local_loss, task.optimizer.optimizer
                ) as scaled_loss:
                    scaled_loss.backward()
                    if task.enable_manual_gradient_reduction:
                        manual_gradient_all_reduce(task.model)
            elif task.amp_type == AmpType.PYTORCH:
                task.amp_grad_scaler.scale(local_loss).backward()
                if task.enable_manual_gradient_reduction:
                    manual_gradient_all_reduce(task.model)
            else:
                local_loss.backward()
                if task.enable_manual_gradient_reduction:
                    manual_gradient_all_reduce(task.model)

        task.run_hooks(SSLClassyHookFunctions.on_backward.name)

        # Stepping the optimizer also updates learning rate, momentum etc
        # according to the schedulers (if any).
        with PerfTimer("optimizer_step", perf_stats), record_function("optimizer_step"):
            assert task.where < 1.0, (
                "Optimizer being called with where=1.0. This should not happen "
                "as where=1.0 means training is already finished. Please debug your "
                "training setup. A common issue is checkpointing the model at every "
                "iteration without using the stateful data sampler, OR an issue in "
                "properly resuming the data sampler."
            )
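            # `where` is the fraction of training completed so far (0.0 at the
            # start, 1.0 at the end); the schedulers use it to pick the LR,
            # momentum, etc. for this step.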
            if task.amp_type == AmpType.PYTORCH:
                task.amp_grad_scaler.step(task.optimizer, where=task.where)
                task.amp_grad_scaler.update()
            else:
                task.optimizer.step(where=task.where)

            # set the model grads to None to save memory
            # only in case of FSDP model
            if is_fsdp_model(task.model):
                zero_grad(task.model)

        task.run_hooks(SSLClassyHookFunctions.on_update.name)
        task.num_updates += task.get_global_batchsize()

    timer_train_step.stop()
    timer_train_step.record()

    return task
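

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the VISSL trainer normally drives this
# function once per iteration. The loop below is a hypothetical outline; it
# assumes `task` is a fully built ClassyTask with data_iterator, model, loss,
# optimizer, and hooks already attached, and `num_train_iterations` is a
# made-up name for this example.
#
#     for _ in range(num_train_iterations):
#         task = standard_train_step(task)
# ---------------------------------------------------------------------------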