in models/criterion.py [0:0]
def forward(self, preds, targets, weight=None):
    """Compute the loss for one prediction or a sequence of predictions.

    The per-prediction loss is delegated to ``self._forward(pred, target,
    weight)``, which is defined elsewhere in the class.

    Args:
        preds: A ``torch.Tensor``, or a list/tuple of tensors.
        targets: Target(s) for ``preds``. When ``preds`` is a sequence,
            ``targets[n]`` is paired with ``preds[n]``.
        weight: Optional weight. When ``preds`` is a sequence it must be
            indexable so that ``weight[n]`` is the weight for ``preds[n]``.
            If omitted, each prediction gets a one-element tensor of ones
            created with ``new_ones``, so it shares that prediction's
            dtype and device.

    Returns:
        For a sequence, the mean of the stacked per-prediction losses
        (assumes ``_forward`` returns same-shaped tensors so they can be
        stacked -- TODO confirm). For a single tensor, whatever
        ``self._forward`` returns.

    Raises:
        ValueError: If ``preds`` is an empty sequence.
        TypeError: If ``preds`` is neither a tensor nor a list/tuple.
    """
    if isinstance(preds, (list, tuple)):
        if len(preds) == 0:
            raise ValueError("forward() received an empty sequence of predictions")
        if weight is None:
            # One default weight per prediction. A single shared 1-element
            # tensor would raise IndexError on weight[n] for every n >= 1.
            weight = [pred.new_ones(1) for pred in preds]
        # Index targets/weight positionally (not zip) so a length mismatch
        # still raises instead of being silently truncated.
        errs = [self._forward(pred, targets[n], weight[n])
                for n, pred in enumerate(preds)]
        return torch.mean(torch.stack(errs))
    if isinstance(preds, torch.Tensor):
        if weight is None:
            weight = preds.new_ones(1)
        return self._forward(preds, targets, weight)
    # Without this, an unsupported type fell through to an unbound variable.
    raise TypeError(
        "preds must be a torch.Tensor or a list/tuple of tensors, "
        f"got {type(preds).__name__}"
    )