def forward()

in models/base.py [0:0]


    def forward(self, x, mode='forward_only', **kwargs):
        """Run feature extraction and classification, optionally computing the loss.

        Args:
            x: Input handed to ``self.forward_features``.
            mode: ``'forward_only'`` returns the logits; ``'calc_loss'``
                returns whatever ``self.calc_loss`` produces.
            **kwargs: Two keys are consumed here, the rest is forwarded
                untouched to ``self.calc_loss``:

                * ``ood_data`` -- optional extra batch that is fed straight
                  to ``self.forward_classifier`` (it does NOT go through
                  ``self.forward_features``). Its logits and the data itself
                  are appended along dim 0 to ``logits`` and the features.
                * ``return_features`` -- when truthy (or when
                  ``self.return_features`` is truthy) the features are
                  returned in addition to the main result.

        Returns:
            * ``'forward_only'``: ``logits``, or ``(logits, features)``. When
              ``ood_data`` was given, both include the appended OOD rows.
            * ``'calc_loss'``: the result of ``self.calc_loss``, or
              ``(*result, features)`` where ``features`` holds only the rows
              computed from ``x`` (OOD rows are stripped).

        Raises:
            NotImplementedError: If ``mode`` is not one of the values above.
        """
        p4 = self.forward_features(x)
        logits = self.forward_classifier(p4)

        ood_p4 = kwargs.pop('ood_data', None)
        num_in_dist = None
        if ood_p4 is not None:
            # Remember how many rows came from ``x`` so they can be recovered
            # after concatenation. Slicing with ``[:-len(ood_p4)]`` would
            # yield an EMPTY result for an empty OOD batch because -0 == 0.
            num_in_dist = len(p4)
            ood_logits = self.forward_classifier(ood_p4)
            logits = torch.cat((logits, ood_logits), dim=0)
            p4 = torch.cat((p4, ood_p4), dim=0)

        return_features = kwargs.pop('return_features', False) or self.return_features
        if mode == 'forward_only':
            return (logits, p4) if return_features else logits
        elif mode == 'calc_loss':
            # The loss sees the combined (in-distribution + OOD) batch.
            res = self.calc_loss(logits, p4, **kwargs)
            ret_p4 = p4 if num_in_dist is None else p4[:num_in_dist]
            # NOTE(review): ``*res`` assumes calc_loss returns a tuple-like
            # value -- confirm against its implementation.
            return (*res, ret_p4) if return_features else res
        else:
            raise NotImplementedError(mode)