in classy_vision/tasks/classification_task.py [0:0]
def train_step(self):
    """Train step to be executed in train loop."""

    self.last_batch = None

    # Process next sample
    with Timer() as timer:
        sample = next(self.data_iterator)

    assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
        f"Returned sample [{sample}] is not a map with 'input' and"
        + " 'target' keys"
    )

    # Copy sample to GPU
    target = sample["target"]
    if self.use_gpu:
        sample = recursive_copy_to_gpu(sample, non_blocking=True)
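
    # Mixup, when configured, blends pairs of inputs and targets in the batch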
    if self.mixup_transform is not None:
        sample = self.mixup_transform(sample)

    # Optional Pytorch AMP context
    torch_amp_context = (
        torch.cuda.amp.autocast()
        if self.amp_type == AmpType.PYTORCH
        else contextlib.suppress()
    )
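    # Note: contextlib.suppress() with no arguments acts as a no-op context
    # manager; it is the fallback used here when PyTorch AMP is disabled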

    # only sync with DDP when we need to perform an optimizer step
    # an optimizer step can be skipped if gradient accumulation is enabled
    do_step = self._should_do_step()
    ctx_mgr_model = (
        self.distributed_model.no_sync()
        if self.distributed_model is not None and not do_step
        else contextlib.suppress()
    )
    ctx_mgr_loss = (
        self.distributed_loss.no_sync()
        if self.distributed_loss is not None and not do_step
        else contextlib.suppress()
    )
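    # DDP's no_sync() skips the gradient all-reduce for this iteration, so
    # gradients only accumulate locally until the iteration that actually steps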

    with ctx_mgr_model, ctx_mgr_loss:
        # Forward pass
        with torch.enable_grad(), torch_amp_context:
            output = self.compute_model(sample)
            local_loss = self.compute_loss(output, sample)
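
            # Detach a copy of the loss for logging so it does not keep the
            # autograd graph alive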
            loss = local_loss.detach().clone()
            self.losses.append(loss.data.cpu().item())
            self.update_meters(output, sample)

        # Backwards pass + optimizer step
        self.run_optimizer(local_loss)
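
    # Count the samples processed across all replicas for this update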
    self.num_updates += self.get_global_batchsize()

    # Move some data to the task so hooks get a chance to access it
    self.last_batch = LastBatchInfo(
        loss=loss,
        output=output,
        target=target,
        sample=sample,
        step_data={"sample_fetch_time": timer.elapsed_time},
    )
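
For reference, a minimal sketch (not part of the file above) of how a hook-style callback might consume task.last_batch after train_step; the function name log_last_batch and the print-based logging are illustrative assumptions, not Classy Vision APIs:

def log_last_batch(task) -> None:
    # Reads only the fields that train_step stores on LastBatchInfo above
    batch = task.last_batch
    if batch is None:
        return
    print(
        f"loss={batch.loss.item():.4f} "
        f"sample_fetch_time={batch.step_data['sample_fetch_time']:.3f}s"
    )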