def train_step()

in classy_vision/tasks/classification_task.py [0:0]


    def train_step(self):
        """Train step to be executed in train loop.

        Fetches the next sample from ``self.data_iterator``, optionally copies
        it to the GPU and applies ``self.mixup_transform``, then runs the
        forward pass and loss (under ``torch.cuda.amp.autocast`` when
        ``self.amp_type == AmpType.PYTORCH``), updates meters, and hands the
        loss to ``self.run_optimizer`` for the backward pass / optimizer step.

        Side effects:
            - ``self.last_batch`` is reset to ``None`` at the start and set to
              a ``LastBatchInfo`` at the end of a successful step.
            - The scalar loss value is appended to ``self.losses``.
            - ``self.num_updates`` is increased by
              ``self.get_global_batchsize()`` (not by 1).

        Raises:
            AssertionError: if the fetched sample is not a dict containing
                both ``"input"`` and ``"target"`` keys.
            StopIteration: propagated from ``next(self.data_iterator)`` when
                the iterator is exhausted.
        """

        # Cleared up front so a step that raises does not leave the previous
        # batch visible to hooks.
        self.last_batch = None

        # Process next sample; only the fetch itself is timed.
        with Timer() as timer:
            sample = next(self.data_iterator)

        assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
            f"Returned sample [{sample}] is not a map with 'input' and "
            "'target' keys"
        )

        # Keep a reference to the target as fetched, before the GPU copy and
        # mixup below rebind `sample`; this is what LastBatchInfo reports.
        target = sample["target"]

        # Copy sample to GPU
        if self.use_gpu:
            sample = recursive_copy_to_gpu(sample, non_blocking=True)

        if self.mixup_transform is not None:
            sample = self.mixup_transform(sample)

        # Optional Pytorch AMP context. `contextlib.suppress()` with no
        # arguments is used as a do-nothing context manager.
        torch_amp_context = (
            torch.cuda.amp.autocast()
            if self.amp_type == AmpType.PYTORCH
            else contextlib.suppress()
        )

        # Only sync with DDP when we need to perform an optimizer step; an
        # optimizer step can be skipped if gradient accumulation is enabled,
        # in which case `no_sync()` avoids a redundant gradient all-reduce.
        do_step = self._should_do_step()
        ctx_mgr_model = (
            self.distributed_model.no_sync()
            if self.distributed_model is not None and not do_step
            else contextlib.suppress()
        )
        ctx_mgr_loss = (
            self.distributed_loss.no_sync()
            if self.distributed_loss is not None and not do_step
            else contextlib.suppress()
        )

        with ctx_mgr_model, ctx_mgr_loss:
            # Forward pass
            with torch.enable_grad(), torch_amp_context:
                output = self.compute_model(sample)

                local_loss = self.compute_loss(output, sample)
                # Detached copy for logging/hooks; `local_loss` keeps the
                # graph for the backward pass in run_optimizer.
                loss = local_loss.detach().clone()
                self.losses.append(loss.item())

                self.update_meters(output, sample)

            # Backwards pass + optimizer step
            self.run_optimizer(local_loss)

        self.num_updates += self.get_global_batchsize()

        # Move some data to the task so hooks get a chance to access it
        self.last_batch = LastBatchInfo(
            loss=loss,
            output=output,
            target=target,
            sample=sample,
            step_data={"sample_fetch_time": timer.elapsed_time},
        )