# trainers/catex.py
import clip
import torch


def prompt_ensemble(self, learned_text_features=None):
    if learned_text_features is None:
        # Hand-crafted prompt templates used for NPOS-style evaluation.
        imagenet_templates = [
            'a photo of a {}.',
            'a blurry photo of a {}.',
            'a black and white photo of a {}.',
            'a low contrast photo of a {}.',
            'a high contrast photo of a {}.',
            'a bad photo of a {}.',
            'a good photo of a {}.',
            'a photo of a small {}.',
            'a photo of a big {}.',
            'a photo of the {}.',
            'a blurry photo of the {}.',
            'a black and white photo of the {}.',
            'a low contrast photo of the {}.',
            'a high contrast photo of the {}.',
            'a bad photo of the {}.',
            'a good photo of the {}.',
            'a photo of the small {}.',
            'a photo of the big {}.',
        ]
    else:
        # Smaller template set used for MCM-style evaluation.
        imagenet_templates = [
            'a photo of a {}.',
            'a blurry photo of a {}.',
            'a photo of many {}.',
            'a black and white photo of a {}.',
            'a photo of the large {}.',
            'a photo of the small {}.',
        ]
    lambd = 0.5  # weight on the learned text feature when mixed into the ensemble
    dtype = self.text_encoder.dtype
    self.text_encoder = self.text_encoder.cuda()
    self.token_embedding = self.token_embedding.cuda()
    text_features = []
    for ci, classname in enumerate(self.classnames):
        texts = [template.format(classname) for template in imagenet_templates]  # fill each template with the class name
        texts = clip.tokenize(texts).cuda()  # tokenize
        embedding = self.token_embedding(texts).type(dtype)
        class_embeddings = self.text_encoder(embedding, texts)  # encode prompts with the text encoder
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        if learned_text_features is not None:
            # Append the down-weighted learned feature for this class before averaging.
            class_embeddings = torch.cat((class_embeddings, lambd * learned_text_features[ci:ci + 1]))
        class_embedding = class_embeddings.mean(dim=0)  # average over templates (plus the learned feature, if any)
        class_embedding = class_embedding / class_embedding.norm()  # re-normalize the ensembled feature
        text_features.append(class_embedding)
    text_features = torch.stack(text_features, dim=0).type(dtype)
    return text_features
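

# Minimal usage sketch (assumptions: `trainer` is an instance of the class this
# method belongs to, exposing `text_encoder`, `token_embedding`, and `classnames`;
# `learned_feats` is a [num_classes, feat_dim] CUDA tensor of learned per-class
# text features. Both names are hypothetical, for illustration only.)
#
# Hand-crafted-only ensemble (NPOS-style templates):
#     zeroshot_weights = trainer.prompt_ensemble()
#
# Mixing learned features into the ensemble (MCM-style templates):
#     zeroshot_weights = trainer.prompt_ensemble(learned_text_features=learned_feats)
#
# The returned tensor has shape [num_classes, feat_dim] with unit-norm rows, so it
# can be matched against L2-normalized image features via cosine similarity, e.g.:
#     logits = 100.0 * image_features @ zeroshot_weights.t()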