in dataflux_pytorch/benchmark/checkpointing/singlenode/train.py [0:0]
def training_step(self, batch: Tuple[Tensor, Tensor],
                  batch_idx: int) -> Tensor:
    """Compute the negative log-likelihood loss for one training batch.

    Args:
        batch: An ``(inputs, target)`` pair. Both tensors are handed to the
            module's forward pass. ``target`` is also flattened to 1-D and
            used as the class-index labels for the loss.
        batch_idx: Position of this batch in the epoch; not used here.

    Returns:
        The loss tensor produced by ``torch.nn.functional.nll_loss`` with its
        default (mean) reduction.
    """
    features, labels = batch
    # The targets are passed to the forward pass as well. Presumably the
    # model needs them (e.g. for teacher forcing); confirm against forward().
    # NOTE(review): nll_loss expects log-probabilities, so the forward output
    # is assumed to already be log-softmaxed. TODO confirm.
    log_probs = self(features, labels)
    flat_labels = labels.view(-1)
    return torch.nn.functional.nll_loss(log_probs, flat_labels)