def training_step()

in dataflux_pytorch/benchmark/checkpointing/singlenode/train.py [0:0]


    def training_step(self, batch: Tuple[Tensor, Tensor],
                      batch_idx: int) -> Tensor:
        """Compute the training loss for a single batch.

        Args:
            batch: A pair ``(inputs, target)``. Both tensors are passed to
                the module's own ``__call__``. ``target`` is also flattened
                with ``view(-1)`` to serve as the class-index labels for the
                loss.
            batch_idx: Index of the batch within the epoch. It is part of the
                hook signature but is not used here.

        Returns:
            The scalar negative log-likelihood loss from
            ``torch.nn.functional.nll_loss`` with its default ``mean``
            reduction.
        """
        features, labels = batch
        # The module is called with the labels as well as the inputs. Its
        # output goes straight into nll_loss, which expects log-probabilities,
        # so forward presumably ends in a log_softmax. TODO: confirm against
        # forward().
        log_probs = self(features, labels)
        flat_labels = labels.view(-1)
        return torch.nn.functional.nll_loss(log_probs, flat_labels)