in trainers/catex.py [0:0]
def forward(self, image, perturb='none', ood_prompt=False,
            return_feat=False, return_norm=True, posthoc=None):
    """Compute logits between image features and text features.

    Args:
        image: Either a rank-2 input, which is used directly as image
            features, or a rank-4 input, which is cast to ``self.dtype``
            and passed through ``self.image_encoder``. Any other rank
            trips the assertion below. (Presumably rank-4 means a batch
            of images -- confirm against callers.)
        perturb: Forwarded positionally to ``self.get_text_features``.
        ood_prompt: Forwarded as a keyword to ``self.get_text_features``.
        return_feat: When truthy, a detached copy of the image features
            and the text features are returned alongside the logits.
        return_norm: Selects which image features are copied when
            ``return_feat`` is truthy: the L2-normalised ones (truthy) or
            the ones taken before any post-hoc step and normalisation
            (falsy).
        posthoc: One of ``'apply_react'``, ``'apply_bats'``,
            ``'apply_ash'`` to run the matching helper on the features
            before normalisation; any other value skips this step.

    Returns:
        ``logits`` alone, or ``(image_feature_copy, text_features, logits)``
        when ``return_feat`` is truthy.
    """
    n_dims = len(image.shape)
    if n_dims != 2:
        assert n_dims == 4
        feats = self.image_encoder(image.type(self.dtype))
    else:
        # Rank-2 input is taken as already-extracted features.
        feats = image

    want_raw_copy = return_feat and not return_norm
    want_normed_copy = return_feat and return_norm

    feat_copy = None
    if want_raw_copy:
        # Snapshot before the post-hoc step and normalisation touch it.
        feat_copy = feats.detach().clone()

    # Kept as an explicit chain so each helper name is only looked up on
    # the branch that actually uses it.
    if posthoc == 'apply_react':
        feats = applyReAct(feats)
    elif posthoc == 'apply_bats':
        feats = applyBATS(feats)
    elif posthoc == 'apply_ash':
        feats = applyASH(feats)

    # L2-normalise along the last dimension.
    lengths = feats.norm(dim=-1, keepdim=True)
    feats = feats / lengths

    if want_normed_copy:
        feat_copy = feats.detach().clone()

    text_features = self.get_text_features(perturb, ood_prompt=ood_prompt)
    logits = self.get_logits(feats, text_features)

    if not return_feat:
        return logits
    return feat_copy, text_features, logits