in models/base.py [0:0]
def forward(self, x, mode='forward_only', **kwargs):
    """Run the feature extractor and classifier, optionally appending OOD samples.

    Args:
        x: Input batch passed to ``self.forward_features``.
        mode: ``'forward_only'`` returns logits (and optionally features);
            ``'calc_loss'`` passes logits and features to ``self.calc_loss``.
            Any other value raises ``NotImplementedError``.
        **kwargs: Keys consumed here:
            ``ood_data`` -- extra OOD samples given at the feature level (they
                are fed directly to ``forward_classifier``). Their logits and
                features are concatenated after the in-distribution ones along
                dim 0.
            ``return_features`` -- if truthy (or if ``self.return_features``
                is truthy), the features are returned as well.
            All remaining kwargs are forwarded to ``self.calc_loss`` in
            ``'calc_loss'`` mode.

    Returns:
        ``'forward_only'``: ``logits`` or ``(logits, features)``. When
        ``ood_data`` is given, both include the OOD rows.
        ``'calc_loss'``: the result of ``self.calc_loss``, or, when features
        are requested, ``(*calc_loss_result, id_features)``, where
        ``id_features`` contains only the in-distribution rows.

    Raises:
        NotImplementedError: if ``mode`` is not supported.
    """
    # Reject unsupported modes before running any (potentially expensive or
    # stateful) forward computation.
    if mode not in ('forward_only', 'calc_loss'):
        raise NotImplementedError(mode)

    p4 = self.forward_features(x)
    logits = self.forward_classifier(p4)
    # Remember how many in-distribution rows there are so they can be sliced
    # back out after concatenation. Slicing with ``[:-len(ood)]`` would yield
    # an empty tensor for an empty OOD batch, because ``[:-0]`` == ``[:0]``.
    num_id = len(p4)

    ood_p4 = kwargs.pop('ood_data', None)
    if ood_p4 is not None:
        ood_logits = self.forward_classifier(ood_p4)
        logits = torch.cat((logits, ood_logits), dim=0)
        p4 = torch.cat((p4, ood_p4), dim=0)

    return_features = kwargs.pop('return_features', False) or self.return_features

    if mode == 'forward_only':
        # Both tensors include OOD rows (if any), so they stay row-aligned.
        return (logits, p4) if return_features else logits

    # mode == 'calc_loss'
    res = self.calc_loss(logits, p4, **kwargs)
    if not return_features:
        return res
    # NOTE(review): assumes calc_loss returns a tuple/iterable of results;
    # a bare tensor would be unpacked along dim 0 here — confirm.
    return (*res, p4[:num_id])