def prompt_ensemble(self, learned_text_features=None)

in trainers/catex.py


    def prompt_ensemble(self, learned_text_features=None):
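        """Build one ensembled CLIP text feature per class.

        Each prompt template is formatted with the class name, tokenized, and
        embedded with the text encoder; the L2-normalized embeddings are then
        averaged and renormalized. When `learned_text_features` is given
        (MCM-style), the learned per-class feature, scaled by `lambd`, is
        included in the average; otherwise only the larger NPOS template set
        is used.
        """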
        if learned_text_features is None:
            imagenet_templates = [  # for NPOS
                'a photo of a {}.',
                'a blurry photo of a {}.',
                'a black and white photo of a {}.',
                'a low contrast photo of a {}.',
                'a high contrast photo of a {}.',
                'a bad photo of a {}.',
                'a good photo of a {}.',
                'a photo of a small {}.',
                'a photo of a big {}.',
                'a photo of the {}.',
                'a blurry photo of the {}.',
                'a black and white photo of the {}.',
                'a low contrast photo of the {}.',
                'a high contrast photo of the {}.',
                'a bad photo of the {}.',
                'a good photo of the {}.',
                'a photo of the small {}.',
                'a photo of the big {}.',
            ]
        else:
            imagenet_templates = [  # for MCM
                'a photo of a {}.',
                'a blurry photo of a {}.',
                'a photo of many {}.',
                'a black and white photo of a {}.',
                'a photo of the large {}.',
                'a photo of the small {}.',
            ]
            lambd = 0.5  # weight applied to the learned text feature before it is averaged into the ensemble

        dtype = self.text_encoder.dtype
        self.text_encoder = self.text_encoder.cuda()        # move the CLIP text encoder to GPU
        self.token_embedding = self.token_embedding.cuda()  # move the token embedding layer to GPU

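        # Build one ensembled text feature per class by averaging its template embeddings.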
        text_feature = []
        for ci, classname in enumerate(self.classnames):
            texts = [template.format(classname) for template in imagenet_templates]  # format with class
            texts = clip.tokenize(texts).cuda()  # tokenize
            embedding = self.token_embedding(texts).type(dtype)
            class_embeddings = self.text_encoder(embedding, texts) # embed with text encoder
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)  # L2-normalize each template embedding

            if learned_text_features is not None:
                # Append the learned (soft-prompt) feature for this class, scaled by lambd,
                # so it contributes to the ensemble average below.
                class_embeddings = torch.cat((class_embeddings, lambd * learned_text_features[ci:ci+1]))

            class_embedding = class_embeddings.mean(dim=0)               # average over templates (and the learned feature, if any)
            class_embedding = class_embedding / class_embedding.norm()   # renormalize the ensembled feature
            text_feature.append(class_embedding)
        text_feature = torch.stack(text_feature, dim=0).type(dtype)

        return text_feature
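
For reference, the same template-ensembling recipe can be sketched standalone with the public OpenAI `clip` package instead of the trainer's `text_encoder`/`token_embedding` path. The model name, class list, and template subset below are illustrative assumptions, not the values used by the trainer above.

    import torch
    import clip

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _ = clip.load("ViT-B/32", device=device)  # assumed backbone, for illustration only

    classnames = ["cat", "dog"]  # illustrative class list
    templates = ["a photo of a {}.", "a blurry photo of a {}.", "a photo of the small {}."]

    features = []
    with torch.no_grad():
        for name in classnames:
            texts = clip.tokenize([t.format(name) for t in templates]).to(device)
            emb = model.encode_text(texts)              # (num_templates, dim)
            emb = emb / emb.norm(dim=-1, keepdim=True)  # normalize each template embedding
            mean = emb.mean(dim=0)                      # average over templates
            features.append(mean / mean.norm())         # renormalize the ensemble
    text_feature = torch.stack(features, dim=0)         # (num_classes, dim)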