def forward()

in trainers/catex.py [0:0]


    def forward(self, image, perturb='none', ood_prompt=False, 
                return_feat=False, return_norm=True, posthoc=None):
        """Score images (or precomputed image features) against text features.

        Args:
            image: Either a 2-D tensor, used directly as image features, or a
                4-D tensor, which is cast to ``self.dtype`` and encoded with
                ``self.image_encoder``. Any other rank fails the assertion.
            perturb: Forwarded unchanged to ``self.get_text_features``.
            ood_prompt: Forwarded unchanged to ``self.get_text_features``.
            return_feat: When truthy, return ``(features, text_features,
                logits)`` instead of just ``logits``.
            return_norm: Chooses which image features are returned when
                ``return_feat`` is truthy. If truthy, the returned features are
                taken after the posthoc step and L2 normalisation. If falsy,
                they are the encoder output (or the 2-D input), taken before
                posthoc and normalisation. Either way a detached clone is
                returned.
            posthoc: ``'apply_react'``, ``'apply_bats'`` or ``'apply_ash'``
                applies ``applyReAct``, ``applyBATS`` or ``applyASH`` to the
                features before normalisation. Any other value, including
                ``None``, leaves the features unchanged.

        Returns:
            ``self.get_logits(normalised_features, text_features)``, or the
            3-tuple described under ``return_feat``.
        """
        ndim = len(image.shape)
        if ndim == 2:
            # Caller already supplied image features; skip the encoder.
            feats = image
        else:
            assert ndim == 4
            feats = self.image_encoder(image.type(self.dtype))

        # Decide up front where (if anywhere) the returned feature snapshot
        # is taken: before posthoc/normalisation, or after both.
        snap_raw = return_feat and not return_norm
        snap_normed = return_feat and return_norm

        if snap_raw:
            snapshot = feats.detach().clone()

        if posthoc == 'apply_react':
            feats = applyReAct(feats)
        elif posthoc == 'apply_bats':
            feats = applyBATS(feats)
        elif posthoc == 'apply_ash':
            feats = applyASH(feats)

        # L2-normalise along the last dimension. NOTE(review): a zero-norm
        # row would produce NaN/inf here; no epsilon is added.
        feats = feats / feats.norm(dim=-1, keepdim=True)

        if snap_normed:
            snapshot = feats.detach().clone()

        text_feats = self.get_text_features(perturb, ood_prompt=ood_prompt)
        logits = self.get_logits(feats, text_feats)

        return (snapshot, text_feats, logits) if return_feat else logits