def _step()

in pytorchvideo_trainer/pytorchvideo_trainer/module/video_classification.py [0:0]


    def _step(self, batch: Batch, batch_idx: int, phase_type: str) -> Dict[str, Any]:
        """Shared step logic for the train/val/test phases.

        Runs the forward pass, computes and logs the loss during training,
        un-mixes MixUp soft labels so standard top-k metrics apply, and logs
        the phase metrics.

        Args:
            batch: Mapping containing at least ``self.modality_key`` (the
                model input) and ``"label"`` entries.
            batch_idx: Index of the current batch; unused here but part of
                the Lightning step signature.
            phase_type: Phase name; ``"train"`` enables loss computation and
                the MixUp label handling.

        Returns:
            The training loss tensor when ``phase_type == "train"``, else
            ``None``. NOTE(review): the declared ``Dict[str, Any]`` return
            annotation does not match the actual return value — left
            unchanged here because fixing it needs an ``Optional`` import
            outside this view; confirm and correct at file level.
        """
        assert (
            isinstance(batch, dict) and self.modality_key in batch and "label" in batch
        ), (
            # Trailing space added before the concatenation so the message
            # reads "... and 'label' keys" instead of "... and'label' keys".
            f"Returned batch [{batch}] is not a map with '{self.modality_key}' and "
            + "'label' keys"
        )

        y_hat = self(batch[self.modality_key])
        if phase_type == "train":
            loss = self.loss(y_hat, batch["label"])
            self.log(
                f"Losses/{phase_type}_loss",
                loss,
                on_step=True,
                on_epoch=True,
                prog_bar=True,
            )
        else:
            # Loss is only defined for training; val/test log metrics only.
            loss = None

        # TODO: Move MixUp transform metrics to a separate method.
        if (
            phase_type == "train"
            and self.batch_transform is not None
            and isinstance(self.batch_transform, MixVideoBatchWrapper)
        ):
            # MixUp/CutMix produce soft (two-class) labels; recover the two
            # mixed class indices as the two largest label weights per sample.
            _top_max_k_vals, top_max_k_inds = torch.topk(
                batch["label"], 2, dim=1, largest=True, sorted=True
            )
            idx_top1 = torch.arange(batch["label"].shape[0]), top_max_k_inds[:, 0]
            idx_top2 = torch.arange(batch["label"].shape[0]), top_max_k_inds[:, 1]
            # clone() after detach(): detach() alone shares storage (and the
            # autograd version counter) with the forward output, so the
            # in-place edits below could corrupt activations still needed by
            # the pending loss.backward().
            y_hat = y_hat.detach().clone()
            # Fold the secondary mixed class's score into the primary class
            # and zero it out, so predictions are scored against the
            # dominant label only.
            y_hat[idx_top1] += y_hat[idx_top2]
            y_hat[idx_top2] = 0.0
            batch["label"] = top_max_k_inds[:, 0]

        pred = torch.nn.functional.softmax(y_hat, dim=-1)
        metrics_result = self._compute_metrics(pred, batch["label"], phase_type)
        self.log_dict(metrics_result, on_epoch=True)

        return loss