in pytorchvideo_trainer/pytorchvideo_trainer/module/video_classification.py [0:0]
def _step(self, batch: Batch, batch_idx: int, phase_type: str) -> Optional[torch.Tensor]:
    """
    Shared forward/loss/metric logic for the train, val and test steps.

    Args:
        batch: A dict containing at least ``self.modality_key`` (the model
            input) and ``"label"``.
        batch_idx: Index of the current batch (required by the Lightning
            step signature; unused here).
        phase_type: One of ``"train"``, ``"val"`` or ``"test"``. Controls
            whether the loss is computed and how metrics are logged.

    Returns:
        The loss tensor when ``phase_type == "train"``, otherwise ``None``.
    """
    assert (
        isinstance(batch, dict) and self.modality_key in batch and "label" in batch
    ), (
        f"Returned batch [{batch}] is not a map with '{self.modality_key}' and "
        "'label' keys"
    )

    y_hat = self(batch[self.modality_key])
    if phase_type == "train":
        loss = self.loss(y_hat, batch["label"])
        self.log(
            f"Losses/{phase_type}_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
    else:
        loss = None

    # TODO: Move MixUp transform metrics to a separate method.
    if (
        phase_type == "train"
        and self.batch_transform is not None
        and isinstance(self.batch_transform, MixVideoBatchWrapper)
    ):
        # MixUp produces soft (two-hot) labels. Recover the two mixed
        # classes, fold the predicted mass of the secondary class into the
        # primary one, and use the dominant class as a hard label so that
        # standard top-k accuracy metrics stay meaningful.
        _top_max_k_vals, top_max_k_inds = torch.topk(
            batch["label"], 2, dim=1, largest=True, sorted=True
        )
        idx_top1 = torch.arange(batch["label"].shape[0]), top_max_k_inds[:, 0]
        idx_top2 = torch.arange(batch["label"].shape[0]), top_max_k_inds[:, 1]
        # Detach so the in-place metric bookkeeping below cannot touch
        # gradients flowing through the loss computed above.
        y_hat = y_hat.detach()
        y_hat[idx_top1] += y_hat[idx_top2]
        y_hat[idx_top2] = 0.0
        batch["label"] = top_max_k_inds[:, 0]

    pred = torch.nn.functional.softmax(y_hat, dim=-1)
    metrics_result = self._compute_metrics(pred, batch["label"], phase_type)
    self.log_dict(metrics_result, on_epoch=True)

    return loss