in codes/net.py [0:0]
def validation_step(self, batch, batch_nb):
    """Compute the validation loss and scores for one batch.

    The scoring path is selected by ``self.hparams.train_mode``:

    * ``"nce"`` -- an NCE scorer returns scores for the true samples and for
      corrupted samples; the loss is binary cross-entropy (mean reduction)
      of the concatenated scores against the 0/1 ``target``.  When
      ``self.hparams.corrupt_type == "all_context"`` the context-corrupting
      scorer is used, otherwise the plain batched scorer.
    * ``"ref_score"`` -- ``self.calc_ref_scores`` supplies true/false scores.
    * anything else -- ``self.calc_cont_scores`` supplies true/false scores.

    For the two non-NCE modes the loss is
    ``self.loss_fn(true, ones) + self.loss_fn(false, zeros)`` where the
    targets have shape ``(N, 1)`` and match the scores' dtype and device.

    Args:
        batch: Validation batch, passed unchanged to the scoring helper.
        batch_nb: Batch index (unused).

    Returns:
        OrderedDict with keys ``"val_loss"``, ``"true_score"`` and
        ``"false_score"``.  In NCE mode ``"false_score"`` is the mean of the
        corrupted-sample scores over the last dimension.
    """
    if self.hparams.train_mode == "nce":
        # NOTE(review): hard-coded, so the unbatched branch below is
        # currently unreachable; it is kept for reference.
        batched = True
        if batched:
            if self.hparams.corrupt_type == "all_context":
                pred_true, pred_false_scores, target = (
                    self.calc_nce_scores_with_context_batched(batch)
                )
            else:
                pred_true, pred_false_scores, target = (
                    self.calc_nce_scores_batched(batch)
                )
            pred_false = pred_false_scores.mean(dim=-1)
            pred_scores = torch.cat([pred_true, pred_false_scores], dim=0)
        else:
            # Here the scorer returns a list of corrupted-score tensors.
            pred_true, pred_false_scores = self.calc_nce_scores(batch)
            pred_scores = torch.cat([pred_true] + pred_false_scores, dim=0)
            target = torch.cat(
                [torch.ones_like(pred_true)]
                + [torch.zeros_like(pf) for pf in pred_false_scores],
                dim=0,
            )
            pred_false = torch.cat(pred_false_scores, dim=-1).mean(dim=-1)
        loss = F.binary_cross_entropy(pred_scores, target, reduction="mean")
    else:
        if self.hparams.train_mode == "ref_score":
            pred_true, pred_false = self.calc_ref_scores(batch)
        else:
            pred_true, pred_false = self.calc_cont_scores(batch)
        # new_ones/new_zeros allocate the targets directly with the scores'
        # dtype and device (no host allocation + copy, no dtype mismatch).
        true_target = pred_true.new_ones((pred_true.size(0), 1))
        false_target = pred_false.new_zeros((pred_false.size(0), 1))
        loss = self.loss_fn(pred_true, true_target) + self.loss_fn(
            pred_false, false_target
        )
    # Releases cached CUDA memory after every step (a no-op without CUDA).
    # This forces re-allocation on the next step and can slow validation;
    # presumably kept to limit peak memory -- confirm before removing.
    torch.cuda.empty_cache()
    return OrderedDict(
        {"val_loss": loss, "true_score": pred_true, "false_score": pred_false}
    )