def training_step()

in codes/net.py [0:0]
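
PyTorch Lightning training hook for one mini-batch. Depending on hparams.train_mode, it computes either an NCE-style binary cross-entropy over concatenated true and corrupted scores ("nce") or a paired true/false loss through self.loss_fn (calc_ref_scores for "ref_score", otherwise calc_cont_scores).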


    def training_step(self, batch, batch_nb):
        # REQUIRED (PyTorch Lightning hook)
        if self.hparams.train_mode == "nce":
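            # NOTE: hardcoded to the batched path; the non-batched branch
            # below is retained for reference but never taken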
            batched = True
            if batched:
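                # pick the scoring routine based on the corruption strategy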
                if self.hparams.corrupt_type == "all_context":
                    (
                        pred_true,
                        pred_false_scores,
                        target,
                    ) = self.calc_nce_scores_with_context_batched(batch)
                else:
                    pred_true, pred_false_scores, target = self.calc_nce_scores_batched(
                        batch
                    )
                pred_scores = torch.cat([pred_true, pred_false_scores], dim=0)
            else:
                pred_true, pred_false_scores = self.calc_nce_scores(batch)
                pred_scores = torch.cat([pred_true] + pred_false_scores, dim=0)
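                # NOTE: true_weight is computed but never used below; it
                # looks intended as a per-sample weight for the BCE loss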
                true_weight = torch.ones_like(pred_true)
                true_weight = true_weight * len(pred_false_scores)
                target = torch.cat(
                    [torch.ones_like(pred_true)]
                    + [torch.zeros_like(pf) for pf in pred_false_scores],
                    dim=0,
                )
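            # F.binary_cross_entropy expects probabilities in [0, 1], so
            # the score functions are assumed to return sigmoid outputs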
            loss = F.binary_cross_entropy(pred_scores, target, reduction="mean")
        else:
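            # score-regression branch: push true-pair scores toward 1 and
            # false-pair scores toward 0 with self.loss_fn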
            if self.hparams.train_mode == "ref_score":
                pred_true, pred_false = self.calc_ref_scores(batch)
            else:
                pred_true, pred_false = self.calc_cont_scores(batch)
            device = pred_true.device
            t_loss = self.loss_fn(
                pred_true, torch.ones(pred_true.size(0), 1).to(device)
            )
            f_loss = self.loss_fn(
                pred_false, torch.zeros(pred_false.size(0), 1).to(device)
            )
            loss = t_loss + f_loss
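        # release cached GPU memory between steps (no-op if CUDA is not
        # initialized)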
        torch.cuda.empty_cache()
        # Disabled periodic metric logging; note t_loss/f_loss only exist in
        # the non-NCE branch, so this would need guarding before re-enabling.
        # if batch_nb % 500 == 0:
        #     metrics = {
        #         "mode": "train",
        #         "minibatch": self.minibatch_step,
        #         "loss": loss.mean().item(),
        #         "true_loss": t_loss.mean().item(),
        #         "false_loss": f_loss.mean().item(),
        #         "epoch": self.trainer.current_epoch,
        #     }
        #     self.minibatch_step += 1
        #     self.logbook.write_metric_logs(metrics)
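        # DP/DDP2 gather per-replica outputs, so the scalar loss needs an
        # explicit batch dimension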
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss = loss.unsqueeze(0)
        tqdm_dict = {"train_loss": loss}
        output = OrderedDict(
            {"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
        )
        return output
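
For reference, a minimal self-contained sketch of the NCE branch, assuming the score functions return sigmoid probabilities in [0, 1]; the toy tensors below stand in for the outputs of calc_nce_scores_batched:

    import torch
    import torch.nn.functional as F

    # toy stand-ins for the scores returned by calc_nce_scores_batched
    pred_true = torch.sigmoid(torch.randn(4))   # scores for 4 true pairs
    pred_false = torch.sigmoid(torch.randn(8))  # scores for 8 corrupted pairs

    pred_scores = torch.cat([pred_true, pred_false], dim=0)
    target = torch.cat(
        [torch.ones_like(pred_true), torch.zeros_like(pred_false)], dim=0
    )
    loss = F.binary_cross_entropy(pred_scores, target, reduction="mean")

Concatenating scores and labels this way lets a single BCE call cover both the true and corrupted examples in one pass.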