in src/resnet50.py [0:0]
def forward_head(self, x):
    """Run features through the optional head components of the network.

    Three stages are applied in order, each skipped when not configured:

    1. ``self.projection_head`` (skipped when ``None``).
    2. L2 normalization along dimension 1 (only when ``self.l2norm`` is truthy).
    3. ``self.prototypes`` (skipped when ``None``).

    Args:
        x: Input features. Presumably ``(batch, features)`` given the
            ``dim=1`` normalization -- TODO confirm against callers.

    Returns:
        The (possibly projected / normalized) embedding when no prototype
        layer is configured. Otherwise a 2-tuple
        ``(embedding, self.prototypes(embedding))``.
    """
    head = self.projection_head
    z = x if head is None else head(x)

    if self.l2norm:
        z = nn.functional.normalize(z, dim=1, p=2)

    protos = self.prototypes
    if protos is None:
        return z
    # Callers get both the embedding and the prototype scores.
    return z, protos(z)