in src/fairseq/fairseq/criterions/binary_cross_entropy.py [0:0]
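# Assumes the surrounding file's module-level imports:
#   import torch
#   import torch.nn.functional as F
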
def forward(self, model, sample, reduce=True, log_pred=False):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    net_output = model(**sample['net_input'])
    logits = model.get_logits(net_output).float()
    target = model.get_targets(sample, net_output)

    weights = None
    if hasattr(model, 'get_target_weights') and not self.infonce:
        weights = model.get_target_weights(target, net_output)
        if torch.is_tensor(weights):
            weights = weights.float()
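
    # Two loss modes: InfoNCE-style contrastive cross-entropy over candidate
    # scores (the accuracy bookkeeping below assumes the positive candidate
    # sits at index 0), or element-wise binary cross-entropy against a 0/1
    # target mask.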
    losses = []
    if self.infonce:
        loss = F.cross_entropy(
            logits, target, reduction="sum" if reduce else "none"
        )
    else:
        loss = F.binary_cross_entropy_with_logits(
            logits, target.float(), weights,
            reduction="sum" if reduce else "none",
        )

    sample_size = target.numel() if self.infonce else target.long().sum().item()
    losses.append(loss)
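
    # Optional auxiliary losses exposed by the model via get_extra_losses();
    # each active term is scaled by its configured weight and by sample_size
    # so it is commensurate with the summed main loss.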
    if self.loss_weights is not None and hasattr(model, "get_extra_losses"):
        extra_losses = model.get_extra_losses(net_output)
        if torch.is_tensor(extra_losses):
            extra_losses = [extra_losses]
        if len(self.loss_weights) == 1 and len(extra_losses) != 1:
            self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
        assert len(extra_losses) == len(self.loss_weights), (
            f'{len(extra_losses)}, {len(self.loss_weights)}'
        )
        for p, coef in zip(extra_losses, self.loss_weights):
            if coef != 0 and p is not None:
                p = coef * p.float() * sample_size
                loss += p
                losses.append(p)
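
    # Per the docstring, sample_size (reported here as both 'ntokens' and
    # 'sample_size') is the denominator later applied to the gradient.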
    logging_output = {
        'loss': loss.item() if reduce else loss,
        'ntokens': sample_size,
        'nsentences': logits.size(0),
        'sample_size': sample_size,
    }
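
    # Pass through any extra scalar diagnostics the model placed in its output.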
    for lk in self.log_keys:
        if lk in net_output:
            logging_output[lk] = float(net_output[lk])
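
    # If auxiliary losses were added, log each component individually.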
    if len(losses) > 1:
        for i, l in enumerate(losses):
            logging_output[f'loss_{i}'] = l.item()
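
    # Contrastive accuracy: a row counts as correct when the positive
    # (index 0) scores highest; rows where argmax and argmin coincide at 0
    # (e.g. all candidates tied) are subtracted so degenerate ties don't
    # count as correct.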
    if self.infonce:
        with torch.no_grad():
            if logits.numel() == 0:
                corr = 0
                count = 0
            else:
                assert logits.dim() > 1, logits.shape
                is_max = logits.argmax(-1) == 0
                is_min = logits.argmin(-1) == 0
                both = is_max & is_min
                corr = is_max.long().sum().item() - both.long().sum().item()
                count = is_max.numel()
        logging_output["correct"] = corr
        logging_output["count"] = count
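
    # Optionally attach raw predictions for offline inspection; detach() is
    # needed because logits still carries grad during training.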
    if log_pred:
        logging_output['logits'] = logits.detach().cpu().numpy()
        logging_output['target'] = target.cpu().numpy()

    return loss, sample_size, logging_output
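
# Typical call site in a fairseq train step (a sketch; trainer wiring elided):
#   loss, sample_size, logging_output = criterion(model, sample)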