def standard_train_step()

in vissl/trainer/train_steps/standard_train_step.py
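
For context, the trainer calls a train step like this once per batch of a phase. The sketch below is illustrative only: run_phase is a hypothetical driver, not VISSL's trainer, and it assumes a fully constructed task object as VISSL would pass it.

def run_phase(task, num_batches):
    # Hypothetical driver: call the step once per batch. Each call reads one
    # sample, runs forward/loss (and backward/optimizer step when training),
    # fires the registered hooks and returns the mutated task.
    for _ in range(num_batches):
        task = standard_train_step(task)
    return task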


def standard_train_step(task):
    """
    Single training iteration loop of the model.

    Performs: data read, forward, loss computation, backward, optimizer step, parameter updates.

    Various intermediate steps are also performed:
    - logging the training loss, training eta, LR, etc to loggers
    - logging to tensorboard,
    - performing any self-supervised method specific operations (like in MoCo approach, the
    momentum encoder is updated), computing the scores in swav
    - checkpointing model if user wants to checkpoint in the middle
    of an epoch
    """
    assert isinstance(task, ClassyTask), "task is not instance of ClassyTask"

    # reset the last batch info at every step
    task.last_batch = LastBatchInfo()

    # We'll time train_step and some of its sections, and accumulate the
    # values into the task's perf_stats:
    perf_stats = task.perf_stats
    timer_train_step = PerfTimer("train_step_total", perf_stats)
    timer_train_step.start()

    # Process next sample
    with PerfTimer("read_sample", perf_stats):
        sample = next(task.data_iterator)

    sample = construct_sample_for_model(sample, task)

    # Only need gradients during training
    grad_context = torch.enable_grad() if task.train else torch.no_grad()
    ddp_context = (
        task.model.no_sync()
        if task.enable_manual_gradient_reduction
        else contextlib.suppress()
    )
    torch_amp_context = (
        torch.cuda.amp.autocast()
        if task.amp_type == AmpType.PYTORCH
        else contextlib.suppress()
    )
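    # Note: contextlib.suppress() with no arguments is a no-op context manager,
    # so the three contexts above can always be stacked in the "with" below.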

    with grad_context, ddp_context, torch_amp_context:
        # Forward pass of the model
        with PerfTimer("forward", perf_stats), record_function("forward"):
            if task.enable_manual_gradient_reduction:
                # Manually sync params and buffers for DDP.
                manual_sync_params(task.model)
            model_output = task.model(sample["input"])

        # If the model outputs only one tensor, we take it out of the list.
        if len(model_output) == 1:
            model_output = model_output[0]

        task.last_batch.sample = sample
        task.last_batch.model_output = model_output
        target = sample["target"]

        # Run hooks on forward pass
        task.run_hooks(SSLClassyHookFunctions.on_forward.name)

        # Compute loss
        with PerfTimer("loss_compute", perf_stats), record_function("loss_compute"):
            local_loss = task.loss(model_output, target)

        # Reduce the loss value across all nodes and gpus.
        with PerfTimer("loss_all_reduce", perf_stats):
            loss = local_loss.detach().clone()
            task.last_batch.loss = all_reduce_mean(loss)

        task.losses.append(task.last_batch.loss.data.cpu().item() * target.size(0))
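        # The stored loss is weighted by the local batch size above, so that
        # phase-level loss averages can later be normalized per sample.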

        # Update meters
        if len(task.meters) > 0 and (
            (task.train and task.config["METERS"]["enable_training_meter"])
            or (not task.train)
        ):
            with PerfTimer("meters_update", perf_stats):
                if isinstance(model_output, list):
                    model_output_cpu = [x.cpu() for x in model_output]
                else:
                    model_output_cpu = model_output.cpu()

                for meter in task.meters:
                    meter.update(model_output_cpu, target.detach().cpu())

        task.last_batch.model_output = model_output
        task.last_batch.target = target

        # Update the iteration number, check that the loss is not NaN and measure the
        # batch time here if this is a test phase, since test phases have no update step.
        task.run_hooks(SSLClassyHookFunctions.on_loss_and_meter.name)

    # Run backward now and update the optimizer
    if task.train:
        with PerfTimer("backward", perf_stats), record_function("backward"):

            task.optimizer.zero_grad()
            if task.amp_type == AmpType.APEX:
                with apex.amp.scale_loss(
                    local_loss, task.optimizer.optimizer
                ) as scaled_loss:
                    scaled_loss.backward()
                    if task.enable_manual_gradient_reduction:
                        manual_gradient_all_reduce(task.model)

            elif task.amp_type == AmpType.PYTORCH:
                task.amp_grad_scaler.scale(local_loss).backward()
                if task.enable_manual_gradient_reduction:
                    manual_gradient_all_reduce(task.model)
            else:
                local_loss.backward()
                if task.enable_manual_gradient_reduction:
                    manual_gradient_all_reduce(task.model)

        task.run_hooks(SSLClassyHookFunctions.on_backward.name)

        # Stepping the optimizer also updates learning rate, momentum etc
        # according to the schedulers (if any).
        with PerfTimer("optimizer_step", perf_stats), record_function("optimizer_step"):
            assert task.where < 1.0, (
                "Optimizer being called with where=1.0. This should not happen "
                "as where=1.0 means training is already finished. Please debug your "
                "training setup. A common cause is checkpointing the model at every "
                "iteration without using the stateful data sampler, or an issue in "
                "properly resuming the data sampler."
            )
            if task.amp_type == AmpType.PYTORCH:
                task.amp_grad_scaler.step(task.optimizer, where=task.where)
                task.amp_grad_scaler.update()
            else:
                task.optimizer.step(where=task.where)

            # Set the model grads to None to save memory
            # (only in the case of an FSDP model).
            if is_fsdp_model(task.model):
                zero_grad(task.model)

        task.run_hooks(SSLClassyHookFunctions.on_update.name)
        task.num_updates += task.get_global_batchsize()

    timer_train_step.stop()
    timer_train_step.record()

    return task
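
For reference, the AmpType.PYTORCH branches above follow the standard torch.cuda.amp pattern: run the forward under autocast, call backward on the scaled loss, then step and update the scaler. A minimal, self-contained sketch of that pattern (generic model/optimizer names, not VISSL's task wrappers; assumes a CUDA device):

import torch

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

def amp_step(inputs, targets):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.step(optimizer)          # unscales grads, skips the step on inf/NaN
    scaler.update()                 # adjusts the scale for the next iteration
    return loss.detach()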