trainers/catex.py (834 lines of code) (raw):

# Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp import random import json import os from glob import glob from tqdm import tqdm from contextlib import nullcontext import shutil import torch import torch.nn as nn from torch.nn import functional as F from torch.cuda.amp import GradScaler, autocast from dassl.engine import TRAINER_REGISTRY, TrainerX from dassl.metrics import compute_accuracy from dassl.utils import load_pretrained_weights, load_checkpoint from dassl.optim import build_optimizer, build_lr_scheduler from synthesis.feature_sample import IDFeatPool from ood.posthoc import applyReAct, applyBATS, applyASH from clip import clip from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer _tokenizer = _Tokenizer() perturb_methods = ['neg', 'zero', 'randn', 'randn_add', 'swap'] # def load_clip_to_cpu(cfg): backbone_name = cfg.MODEL.BACKBONE.NAME url = clip._MODELS[backbone_name] model_path = clip._download(url) try: # loading JIT archive model = torch.jit.load(model_path, map_location="cpu").eval() state_dict = None except RuntimeError: state_dict = torch.load(model_path, map_location="cpu") model = clip.build_model(state_dict or model.state_dict()) return model class TextEncoder(nn.Module): def __init__(self, clip_model): super().__init__() self.transformer = clip_model.transformer self.positional_embedding = clip_model.positional_embedding self.ln_final = clip_model.ln_final self.text_projection = clip_model.text_projection self.dtype = clip_model.dtype def forward(self, prompts, tokenized_prompts): ctx_len = prompts.shape[1] # TODO: compatible for dynamic context length x = prompts + self.positional_embedding.type(self.dtype)[:ctx_len] x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x).type(self.dtype) # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection return x class PromptLearner(nn.Module): def __init__(self, cfg, classnames, clip_model): super().__init__() n_cls = len(classnames) n_ctx = cfg.TRAINER.CATEX.N_CTX ctx_init = cfg.TRAINER.CATEX.CTX_INIT dtype = clip_model.dtype ctx_dim = clip_model.ln_final.weight.shape[0] clip_imsize = clip_model.visual.input_resolution cfg_imsize = cfg.INPUT.SIZE[0] assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" self.adjust_cls_promt = False self.cfg = cfg ctx_common = None if ctx_init and 'ensemble' not in ctx_init: # use given words to initialize context vectors ctx_init = ctx_init.replace("_", " ") n_ctx = len(ctx_init.split(" ")) prompt = clip.tokenize(ctx_init) with torch.no_grad(): embedding = clip_model.token_embedding(prompt).type(dtype) ctx_vectors = embedding[0, 1 : 1 + n_ctx, :] prompt_prefix = ctx_init else: # random initialization if cfg.TRAINER.CATEX.CSC: print("Initializing class-specific contexts") ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) nn.init.normal_(ctx_vectors, std=0.02) else: print("Initializing a generic context") ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) nn.init.normal_(ctx_vectors, std=0.02) prompt_prefix = " ".join(["X"] * n_ctx) print(f'Initial context: "{prompt_prefix}"') print(f"Number of context words (tokens): {n_ctx}") self.ctx = nn.Parameter(ctx_vectors) # to be optimized self.ctx_cm = nn.Parameter(ctx_common) if ctx_common is not None else None if cfg.TRAINER.OOD_PROMPT: if cfg.TRAINER.OOD_PROMPT_NUM > 1: self.ctx_ood = [] for _ in range(cfg.TRAINER.OOD_PROMPT_NUM): ctx_ood = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) nn.init.normal_(ctx_ood, std=0.02) self.ctx_ood.append(nn.Parameter(ctx_ood)) self.ctx_ood = nn.ParameterList(self.ctx_ood) else: ## TODO: compatible for pre-trained weights ctx_ood = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) nn.init.normal_(ctx_ood, std=0.02) self.ctx_ood = nn.Parameter(ctx_ood) classnames = [name.replace("_", " ") for name in classnames] name_lens = [len(_tokenizer.encode(name)) for name in classnames] prompts = [prompt_prefix + " " + name + "." for name in classnames] tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) with torch.no_grad(): embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) # These token vectors will be saved when in save_model(), # but they should be ignored in load_model() as we want to use # those computed using the current class names self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS if self.adjust_cls_promt: self.token_suffix = nn.Parameter(embedding[:, 1 + n_ctx :, :]) else: self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS self.n_cls = n_cls self.n_ctx = n_ctx self.tokenized_prompts = tokenized_prompts # torch.Tensor self.name_lens = name_lens self.class_token_position = cfg.TRAINER.CATEX.CLASS_TOKEN_POSITION def forward(self, perturb='none', ood_prompt=False, ood_prompt_idx=None): # ctx = self.ctx if ood_prompt: assert perturb == 'none', perturb if ood_prompt_idx is None: assert self.cfg.TRAINER.OOD_PROMPT_NUM == 1 ctx = self.ctx_ood else: ctx = self.ctx_ood[ood_prompt_idx] else: ctx = self.perturb_prompt(perturb) if ctx.dim() == 2: ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) if self.ctx_cm is not None: ctx_cm = self.ctx_cm.expand(self.n_cls, -1, -1) ctx = torch.cat((ctx_cm, ctx), dim=1) prefix = self.token_prefix suffix = self.token_suffix if self.class_token_position == "end": prompts = torch.cat( [ prefix, # (n_cls, 1, dim) ctx, # (n_cls, n_ctx, dim) suffix, # (n_cls, *, dim) ], dim=1, ) elif self.class_token_position == "middle": half_n_ctx = self.n_ctx // 2 prompts = [] for i in range(self.n_cls): name_len = self.name_lens[i] prefix_i = prefix[i : i + 1, :, :] class_i = suffix[i : i + 1, :name_len, :] suffix_i = suffix[i : i + 1, name_len:, :] ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :] ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :] prompt = torch.cat( [ prefix_i, # (1, 1, dim) ctx_i_half1, # (1, n_ctx//2, dim) class_i, # (1, name_len, dim) ctx_i_half2, # (1, n_ctx//2, dim) suffix_i, # (1, *, dim) ], dim=1, ) prompts.append(prompt) prompts = torch.cat(prompts, dim=0) elif self.class_token_position == "front": prompts = [] for i in range(self.n_cls): name_len = self.name_lens[i] prefix_i = prefix[i : i + 1, :, :] class_i = suffix[i : i + 1, :name_len, :] suffix_i = suffix[i : i + 1, name_len:, :] ctx_i = ctx[i : i + 1, :, :] prompt = torch.cat( [ prefix_i, # (1, 1, dim) class_i, # (1, name_len, dim) ctx_i, # (1, n_ctx, dim) suffix_i, # (1, *, dim) ], dim=1, ) prompts.append(prompt) prompts = torch.cat(prompts, dim=0) else: raise ValueError return prompts def perturb_prompt(self, method='none'): if method == 'none': return self.ctx coef_dict = { 'neg': [-1., 0.], 'zero': [0., 0.], 'randn': [0., 1.], 'randn_add': [1., 1.], 'swap': [0., 1.] } assert method in coef_dict ctx = self.ctx if ctx.dim() == 2: ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) ncls, nctx, ndim = ctx.shape assert nctx > 1 for i in range(self.cfg.TRAINER.ID_PERTUR_NUM): # perturb one prompt for each class ctx_ind = torch.randint(0, nctx, size=(ncls,)) cls_ind = torch.arange(ncls) src_mask = torch.ones((ncls, nctx, 1)).type_as(ctx) src_mask[cls_ind, ctx_ind] = 0. src_ctx = ctx[cls_ind, ctx_ind].detach() if method == 'swap': ori_ind = torch.arange(ncls) while True: rand_ind = torch.randperm(ncls) if (ori_ind != rand_ind).all(): noise = src_ctx[rand_ind] break else: noise = torch.randn_like(ctx[:, 0, :]) src_coef, noise_coef = coef_dict[method] perturb = torch.zeros_like(ctx) perturb[cls_ind, ctx_ind] = src_coef * src_ctx + noise_coef * noise ctx = ctx * src_mask + perturb * (1. - src_mask) return ctx class CustomCLIP(nn.Module): def __init__(self, cfg, classnames, clip_model): super().__init__() self.classnames = classnames self.token_embedding = clip_model.token_embedding self.prompt_learner = PromptLearner(cfg, classnames, clip_model) self.tokenized_prompts = self.prompt_learner.tokenized_prompts self.image_encoder = clip_model.visual self.text_encoder = TextEncoder(clip_model) self.logit_scale = clip_model.logit_scale if not cfg.TRAINER.OOD_TEST else torch.zeros_like(clip_model.logit_scale) self.dtype = clip_model.dtype self.feat_dim = clip_model.text_projection.data.shape[1] self.text_feature_ensemble = self.prompt_ensemble() if cfg.TRAINER.CATEX.CTX_INIT == 'ensemble' else None @torch.no_grad() def prompt_ensemble(self, learned_text_features=None): if learned_text_features is None: imagenet_templates = [ # for NPOS 'a photo of a {}.', 'a blurry photo of a {}.', 'a black and white photo of a {}.', 'a low contrast photo of a {}.', 'a high contrast photo of a {}.', 'a bad photo of a {}.', 'a good photo of a {}.', 'a photo of a small {}.', 'a photo of a big {}.', 'a photo of the {}.', 'a blurry photo of the {}.', 'a black and white photo of the {}.', 'a low contrast photo of the {}.', 'a high contrast photo of the {}.', 'a bad photo of the {}.', 'a good photo of the {}.', 'a photo of the small {}.', 'a photo of the big {}.', ] else: imagenet_templates = [ # for MCM 'a photo of a {}.', 'a blurry photo of a {}.', 'a photo of many {}.', 'a black and white photo of a {}.', 'a photo of the large {}.', 'a photo of the small {}.', ] lambd = 0.5 dtype = self.text_encoder.dtype self.text_encoder = self.text_encoder.cuda() self.token_embedding = self.token_embedding.cuda() text_feature = [] for ci, classname in enumerate(self.classnames): texts = [template.format(classname) for template in imagenet_templates] # format with class texts = clip.tokenize(texts).cuda() # tokenize embedding = self.token_embedding(texts).type(dtype) class_embeddings = self.text_encoder(embedding, texts) # embed with text encoder class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) if learned_text_features is not None: class_embeddings = torch.cat((class_embeddings, lambd * learned_text_features[ci:ci+1])) class_embedding = class_embeddings.mean(dim=0) class_embedding = class_embedding / class_embedding.norm() text_feature.append(class_embedding) text_feature = torch.stack(text_feature, dim=0).type(dtype) return text_feature def get_text_features(self, perturb='none', ood_prompt=False, ood_prompt_idx=None, return_norm=True): if self.text_feature_ensemble is not None and ood_prompt is False and perturb == 'none': assert return_norm text_features = self.text_feature_ensemble else: prompts = self.prompt_learner(perturb, ood_prompt=ood_prompt, ood_prompt_idx=ood_prompt_idx) tokenized_prompts = self.tokenized_prompts text_features = self.text_encoder(prompts, tokenized_prompts) if return_norm: text_features = text_features / text_features.norm(dim=-1, keepdim=True) return text_features def get_all_ood_text_features(self): x = [] for ood_prompt_idx in range(self.prompt_learner.cfg.TRAINER.OOD_PROMPT_NUM): x.append(self.get_text_features(ood_prompt=True, ood_prompt_idx=ood_prompt_idx)) return torch.stack(x, dim=1) # shape(1000,5,512) def get_logits(self, image_features, text_features, logit_scale=None): if logit_scale is None: logit_scale = self.logit_scale.exp() if text_features.dim() == 2: logits = image_features.float() @ text_features.float().t() else: n = text_features.size(0) logits = torch.bmm(image_features.unsqueeze(0).repeat(n,1,1), text_features.transpose(1,2)).max(dim=0)[0] #.mean(dim=0) return logits * logit_scale def forward(self, image, perturb='none', ood_prompt=False, return_feat=False, return_norm=True, posthoc=None): if len(image.shape) == 2: image_features = image else: assert len(image.shape) == 4 image_features = self.image_encoder(image.type(self.dtype)) if return_feat and not return_norm: ret_feat = image_features.detach().clone() if posthoc == 'apply_react': image_features = applyReAct(image_features) elif posthoc == 'apply_bats': image_features = applyBATS(image_features) elif posthoc == 'apply_ash': image_features = applyASH(image_features) image_features = image_features / image_features.norm(dim=-1, keepdim=True) if return_feat and return_norm: ret_feat = image_features.detach().clone() text_features = self.get_text_features(perturb, ood_prompt=ood_prompt) logits = self.get_logits(image_features, text_features) if return_feat: return ret_feat, text_features, logits else: return logits def calc_prompt_loss(self, clean_feat=None, perturb_feat=None): l_intra, l_inter = 0., 0. # # 1. class-specific prompts should not be similar to class-name prompts # csc_prompts = F.normalize(self.prompt_learner.ctx, p=2, dim=-1) # cls_prompts = F.normalize(self.prompt_learner.token_suffix[:, :max(self.prompt_learner.name_lens), :], p=2, dim=-1) # prompts = torch.cat((csc_prompts, cls_prompts), dim=1) # shape(ncls, nctx, ndim) # similarity = torch.bmm(prompts, prompts.transpose(1,2)) # shape(ncls, nctx, nctx) # diag = torch.arange(similarity.shape[1]) # similarity[:, diag, diag] = -1. # l_intra += similarity.max(dim=-1)[0].relu().mean() # 2. prompts should obviously affect the text-feature if clean_feat is None: clean_feat = self.get_text_features() if perturb_feat is None: # with torch.no_grad(): perturb_feat = self.get_text_features(random.choice(perturb_methods)) similarity = (perturb_feat * clean_feat).sum(dim=1) l_inter += (similarity - 0.8).relu().mean() return l_intra + l_inter def calc_ood_prompt_loss(self, image, logits, label): perturb_logits = self.forward(image, perturb=random.choice(perturb_methods)) bi = torch.arange(image.shape[0]) intra_loss = (perturb_logits[bi, label] - logits[bi, label]).relu().mean() inter_loss = -(perturb_logits.mean(1) - torch.logsumexp(perturb_logits, dim=1)).mean() return intra_loss + inter_loss @TRAINER_REGISTRY.register() class CATEX(TrainerX): """Context Optimization (CATEX). Learning to Prompt for Vision-Language Models https://arxiv.org/abs/2109.01134 """ def __init__(self, cfg): super().__init__(cfg) if cfg.TRAINER.OOD_TRAIN or self.is_large_ID(): if self.is_large_ID(): nsample = 1200 # 1200 self.id_pool = IDFeatPool(self.model.prompt_learner.n_cls, nsample, self.model.feat_dim, mode='npos', device='cuda:0') if cfg.TRAINER.ID_FEAT_PRELOAD != '': queue = torch.load(cfg.TRAINER.ID_FEAT_PRELOAD).to(self.id_pool.queue.device) self.id_pool.queue = queue[:, :nsample, :] self.id_pool.class_ptr += nsample else: from torch.utils.data import DataLoader, Subset from ood.datasets import TinyImages, InfiniteDataLoader assert 'cifar' in self.dm.dataset.dataset_name data_root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) ood_set = TinyImages(data_root, transform=self.train_loader_x.dataset.transform) self.ood_loader = InfiniteDataLoader(ood_set, batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE, shuffle=False, num_workers=self.train_loader_x.num_workers, pin_memory=True) # drop_last=True, from ood.losses import LogitNormLoss self.ce_criterion = LogitNormLoss() if cfg.TRAINER.LOGIT_NORM else nn.CrossEntropyLoss() def is_large_ID(self): # return True return 'imagenet' in self.dm.dataset.dataset_name def check_cfg(self, cfg): assert cfg.TRAINER.CATEX.PREC in ["fp16", "fp32", "amp"] def build_model(self): cfg = self.cfg classnames = self.dm.dataset.classnames print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") clip_model = load_clip_to_cpu(cfg) if cfg.TRAINER.CATEX.PREC == "fp32" or cfg.TRAINER.CATEX.PREC == "amp": # CLIP's default precision is fp16 clip_model.float() print("Building custom CLIP") self.model = CustomCLIP(cfg, classnames, clip_model) print("Turning off gradients in both the image and the text encoder") for name, param in self.model.named_parameters(): if "prompt_learner" not in name: param.requires_grad_(False) if cfg.MODEL.INIT_WEIGHTS: load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS) self.model.to(self.device) # NOTE: only give prompt_learner/image_encoder to the optimizer self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched) self.scaler = GradScaler() if cfg.TRAINER.CATEX.PREC == "amp" else None # # Note that multi-gpu training could be slow because CLIP's size is # # big, which slows down the copy operation in DataParallel # device_count = torch.cuda.device_count() # if device_count > 1: # print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") # self.model = nn.DataParallel(self.model) def build_data_loader(self): """Create essential data-related attributes. A re-implementation of this method must create the same attributes (self.dm is optional). """ super().build_data_loader() if self.cfg.TRAINER.FEAT_AS_INPUT and not self.cfg.TRAINER.OOD_TEST: from ood.datasets import CLIPFeatDataset self.load_shuffle = False self.train_loader_x = torch.utils.data.DataLoader( CLIPFeatDataset(self.dm.dataset.dataset_dir+'/clip_feat', self.start_epoch), batch_size=self.dm.train_loader_x.batch_size, shuffle=self.load_shuffle, num_workers=self.dm.train_loader_x.num_workers, pin_memory=True, drop_last=False, ) def calc_loss(self, logits, label, image=None, image_features=None, text_features=None, return_norm=False): nb, ncls = logits.shape # 1. classification if self.cfg.TRAINER.OOD_PROMPT and self.epoch >= self.cfg.TRAINER.START_EPOCH \ and self.is_large_ID() and self.id_pool.ready(): if not return_norm: image_features = F.normalize(image_features, p=2, dim=1) ood_text_features = self.model.get_text_features(ood_prompt=True) if self.cfg.TRAINER.OOD_PROMPT_CE_LOSS: logits = self.model.get_logits(image_features, torch.cat((text_features, ood_text_features))) logits[torch.arange(nb), label+ncls] = -10. # generally -inf loss = 2. * self.ce_criterion(logits, label) # 2. prompt perturbation perturbed_text_features = None if self.cfg.TRAINER.OOD_PROMPT and self.cfg.TRAINER.ID_PERTURB_LOSS and self.epoch >= self.cfg.TRAINER.START_EPOCH: with torch.no_grad(): perturbed_text_features = self.model.get_text_features(perturb=random.choice(perturb_methods)) loss += 0.1 * self.model.calc_prompt_loss(text_features, perturbed_text_features) # 3. outlier exposure assert text_features is not None if self.is_large_ID(): if self.id_pool.ready() and self.cfg.TRAINER.OOD_PROMPT and self.epoch >= self.cfg.TRAINER.START_EPOCH: if logits.size(0) < self.id_pool.queue.size(0): cls_mask = torch.unique(label).cpu() else: cls_mask = None if self.cfg.TRAINER.OOD_ANCHOR: if self.cfg.TRAINER.ID_PERTURB_LOSS and False: perturbed_text_features = self.model.get_text_features(perturb=random.choice(perturb_methods)) logit_scale = self.model.logit_scale.exp() id_pos_sim = (image_features * text_features[label]).sum(dim=-1) * logit_scale id_neg_sim = (image_features * perturbed_text_features[label]).sum(dim=-1) * logit_scale loss += F.cross_entropy(torch.stack((id_pos_sim, id_neg_sim), dim=1), torch.zeros((len(id_pos_sim),), dtype=torch.long, device=self.device)) * 0.5 elif perturbed_text_features is None: with torch.no_grad(): perturbed_text_features = self.model.get_text_features(perturb=random.choice(perturb_methods)) text_anchors = torch.stack((text_features, perturbed_text_features), dim=1).detach() else: text_anchors = None ood_features, ood_labels = self.id_pool.gen_ood(anchors=text_anchors, device=self.device, cls_mask=cls_mask) if self.cfg.TRAINER.OOD_OE_LOSS: ood_logits = self.model.get_logits(ood_features, text_features, logit_scale=1.) loss += 0.5 * -(ood_logits.mean(1) - torch.logsumexp(ood_logits, dim=1)).mean() if self.cfg.TRAINER.OOD_PROMPT: # ood_text_features = self.model.get_text_features(ood_prompt=True) if self.cfg.TRAINER.OOD_PROMPT_ORTH: assert self.cfg.TRAINER.OOD_PROMPT_NUM > 1 all_ood_text_features = self.model.get_all_ood_text_features() # (1000,5,512) x (1000,512,5) -> (1000,5,5) ood_sim_matrix = torch.bmm(all_ood_text_features, all_ood_text_features.transpose(1,2)) ood_text_num = ood_sim_matrix.shape[-1] zrange = torch.arange(ood_text_num) ood_sim_matrix[:, zrange, zrange] = 0. loss += 0.1 * ood_sim_matrix.mean() if self.cfg.TRAINER.OOD_PROMPT_CE_LOSS: ood_logits = self.model.get_logits(ood_features, torch.cat((ood_text_features, text_features))) ood_logits[torch.arange(ood_logits.shape[0]), ood_labels+ncls] = -10. # generally -inf loss += 0.5 * self.ce_criterion(ood_logits, ood_labels) if self.cfg.TRAINER.OOD_PROMPT_MARGIN_LOSS: if self.cfg.TRAINER.OOD_PROMPT_MARGIN_SOFT_LOSS: logit_scale = self.model.logit_scale.exp() else: logit_scale = 1. id_pos_sim = (image_features * text_features[label]).sum(dim=-1) * logit_scale id_neg_sim = (image_features * ood_text_features[label]).sum(dim=-1) * logit_scale ood_pos_sim = (ood_features * ood_text_features[ood_labels]).sum(dim=-1) * logit_scale ood_neg_sim = (ood_features * text_features[ood_labels]).sum(dim=-1) * logit_scale # id_pos_sim = (image_features @ text_features.T).max(dim=-1)[0] * logit_scale # id_neg_sim = (image_features @ ood_text_features.T).max(dim=-1)[0] * logit_scale # ood_pos_sim = (ood_features @ ood_text_features.T).max(dim=-1)[0] * logit_scale # ood_neg_sim = (ood_features @ text_features.T).max(dim=-1)[0] * logit_scale if self.cfg.TRAINER.OOD_PROMPT_MARGIN_SOFT_LOSS: loss += F.cross_entropy(torch.stack((id_pos_sim, id_neg_sim), dim=1), torch.zeros((len(id_pos_sim),), dtype=torch.long, device=self.device)) + \ F.cross_entropy(torch.stack((ood_pos_sim, ood_neg_sim), dim=1), torch.zeros((len(ood_pos_sim),), dtype=torch.long, device=self.device)) else: loss += (id_neg_sim - id_pos_sim).relu().mean() + (ood_neg_sim - ood_pos_sim).relu().mean() else: ood_data, _ = next(self.ood_loader.__iter__()) ood_data = ood_data.to(self.device) ood_logits = self.model(ood_data) #/ self.model.logit_scale.exp() loss += 0.1 * -(ood_logits.mean(1) - torch.logsumexp(ood_logits, dim=1)).mean() return loss def forward_backward(self, batch): image, label = self.parse_batch_train(batch) prec = self.cfg.TRAINER.CATEX.PREC if prec == "amp": with autocast(): output = self.model(image) loss = F.cross_entropy(output, label) self.optim.zero_grad() self.scaler.scale(loss).backward() self.scaler.step(self.optim) self.scaler.update() else: output = self.model(image) loss = F.cross_entropy(output, label) self.model_backward_and_update(loss) loss_summary = { "loss": loss.item(), "acc": compute_accuracy(output, label)[0].item(), } if (self.batch_idx + 1) == self.num_batches: self.update_lr() return loss_summary def forward_backward_ood(self, batch): image, label = self.parse_batch_train(batch) prec = self.cfg.TRAINER.CATEX.PREC return_norm = False if prec == "amp": with autocast(): img_feat, text_feat, output = \ self.model(image, return_feat=True, return_norm=return_norm) self.id_pool.update(img_feat.detach(), label) loss = self.calc_loss(output, label, image, img_feat, text_feat, return_norm=return_norm) self.optim.zero_grad() self.scaler.scale(loss).backward() self.scaler.step(self.optim) self.scaler.update() else: img_feat, text_feat, output = \ self.model(image, return_feat=True, return_norm=return_norm) self.id_pool.update(img_feat.detach(), label) loss = self.calc_loss(output, label, image, img_feat, text_feat, return_norm=return_norm) self.model_backward_and_update(loss) loss_summary = { "loss": loss.item(), "acc": compute_accuracy(output, label)[0].item(), } if (self.batch_idx + 1) == self.num_batches: self.update_lr() return loss_summary def parse_batch_train(self, batch): input = batch["img"] label = batch["label"] input = input.to(self.device) label = label.to(self.device) return input, label def load_model(self, directory, epoch=None): if not directory: print("Note that load_model() is skipped as no pretrained model is given") return names = self.get_model_names() # By default, the best model is loaded model_file = "model-best.pth.tar" if epoch is not None: model_file = "model.pth.tar-" + str(epoch) for name in names: model_path = osp.join(directory, name, model_file) if not osp.exists(model_path): raise FileNotFoundError('Model not found at "{}"'.format(model_path)) checkpoint = load_checkpoint(model_path) state_dict = checkpoint["state_dict"] epoch = checkpoint["epoch"] # Ignore fixed token vectors model_dict = self._models[name].state_dict() if "token_prefix" in state_dict: del state_dict["token_prefix"] if "token_suffix" in state_dict: del state_dict["token_suffix"] assert all(k in model_dict for k in state_dict) print("Loading weights to {} {} " 'from "{}" (epoch = {})'.format(name, list(state_dict.keys()), model_path, epoch)) # set strict=False self._models[name].load_state_dict(state_dict, strict=False) if self.cfg.TRAINER.CATEX.CTX_INIT: assert self.cfg.TRAINER.CATEX.CTX_INIT == 'ensemble_learned' text_feature = self.model.get_text_features() self.model.text_feature_ensemble = self.model.prompt_ensemble(text_feature) def load_model_vanilla(self, directory, name, epoch=None): if not directory: print("Note that load_model() is skipped as no pretrained model is given") return # By default, the best model is loaded model_file = "model-best.pth.tar" if epoch is not None: model_file = "model.pth.tar-" + str(epoch) model_path = osp.join(directory, name, model_file) if not osp.exists(model_path): raise FileNotFoundError('Model not found at "{}"'.format(model_path)) checkpoint = load_checkpoint(model_path) state_dict = checkpoint["state_dict"] epoch = checkpoint["epoch"] print("Loading vanilla weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) # set strict=False model = self.model.module if torch.cuda.device_count() > 1 else self.model getattr(model, name).load_state_dict(state_dict, strict=False) def before_epoch(self): if self.cfg.TRAINER.FEAT_AS_INPUT: if not self.load_shuffle: self.train_loader_x.dataset.load_data(self.epoch) @torch.no_grad() def test_ood(self, split=None, model_directory=''): """A generic OOD testing pipeline.""" from tqdm import tqdm import os import os.path as osp from torch.utils.data import DataLoader import numpy as np from ood.datasets import CLIPFeatDataset from ood.datasets import SCOODDataset, LargeOODDataset, SemanticOODDataset, ClassOODDataset from ood.metrics import get_msp_scores, get_measures self.set_model_mode("eval") self.evaluator.reset() if split is None: split = self.cfg.TEST.SPLIT if self.cfg.TRAINER.FEAT_AS_INPUT: feat_data_dir = self.dm.dataset.dataset_dir+'/clip_feat' if not osp.exists(feat_data_dir): self.cache_feat(split=split, is_ood=False) data_loader = DataLoader( CLIPFeatDataset(feat_data_dir, self.start_epoch, split='test'), batch_size=self.test_loader.batch_size, shuffle=False, num_workers=self.test_loader.num_workers, pin_memory=True, drop_last=False, ) else: if split == "val" and self.val_loader is not None: data_loader = self.val_loader else: split = "test" # in case val_loader is None data_loader = self.test_loader lab2cname = self.dm.dataset.lab2cname ood_cfg = { 'SCOOD': ['texture', 'svhn', 'cifar', 'tin', 'lsun', 'places365'], 'LargeOOD': ['inaturalist', 'sun', 'places', 'texture'], } data_root = osp.abspath(osp.expanduser(self.cfg.DATASET.ROOT)) ood_type = 'SCOOD' if 'cifar' in self.dm.dataset.dataset_name else 'LargeOOD' # LargeOOD, ClassOOD if 'apply_' in self.cfg.TRAINER.OOD_INFER_OPTION: posthoc = self.cfg.TRAINER.OOD_INFER_OPTION else: posthoc = None if self.cfg.TRAINER.OOD_PROMPT: if self.cfg.TRAINER.OOD_PROMPT_NUM > 1: ood_text_features = torch.stack([self.model.get_text_features(ood_prompt=True, ood_prompt_idx=i) for i in range(self.cfg.TRAINER.OOD_PROMPT_NUM)]) else: ood_text_features = self.model.get_text_features(ood_prompt=True) if self.cfg.TRAINER.CATEX.CTX_INIT: assert self.cfg.TRAINER.CATEX.CTX_INIT == 'ensemble_learned' ood_text_features = self.model.prompt_ensemble(ood_text_features) min_thresh = 0.51 if any(flag in model_directory for flag in ['/imagenet/', '/imagenet100-MCM-SCTX8-Orth/']) else 0.5 self.model.text_feature_ensemble = self.model.get_text_features() print(f"Evaluate on the *{split}* set") if self.cfg.TRAINER.OOD_INFER_OPTION == 'save_res': save_dir = f'{model_directory}/restore' os.makedirs(save_dir, exist_ok=True) with open(f'{save_dir}/lab2cname.json', 'w+') as f: json.dump(lab2cname, f, indent=4) text_features = self.model.get_text_features() torch.save(text_features.cpu(), f'{save_dir}/in_text_features.pt') if self.cfg.TRAINER.OOD_PROMPT: torch.save(ood_text_features.cpu(), f'{save_dir}/ood_text_features.pt') im_feats, im_labels = [], [] if self.cfg.TRAINER.OOD_INFER_OPTION == 'resume_res': resume_dir = 'weights/imagenet100-MCM/CATEX/vit_b16_ep50_-1shots/nctx16_cscTrue_ctpend/seed1/restore' resume_image_features = torch.load(f'{resume_dir}/in_image_features.pt').to(self.device) resume_image_labels = torch.load(f'{resume_dir}/in_labels.pt').to(self.device) resume_text_features = torch.load(f'{resume_dir}/in_text_features.pt').to(self.device) if self.cfg.TRAINER.OOD_PROMPT: resume_ood_text_features = torch.load(f'{resume_dir}/ood_text_features.pt').to(self.device) with open(f'{resume_dir}/lab2cname.json', 'r') as f: resume_lab2cname = json.load(f) resume_lab2cname = {int(k): v for k, v in resume_lab2cname.items()} label_offset = resume_image_labels.max().item() + 1 resume_image_labels += label_offset text_features = self.model.get_text_features() merged_text_features = torch.cat((text_features, resume_text_features), dim=0) if self.cfg.TRAINER.OOD_PROMPT: # TODO: not implemented merged_ood_text_features = torch.cat((ood_text_features, resume_ood_text_features), dim=0) score_list = [] base_acc, novel_acc = [], [] near_ood_flag = [] all_logits = [] for batch_idx, batch in enumerate(tqdm(data_loader)): input, label = self.parse_batch_test(batch) image_features, _, output = \ self.model(input, return_feat=True, return_norm=True, posthoc=posthoc) if self.cfg.TRAINER.OOD_PROMPT: ood_logits = self.model.get_logits(image_features, ood_text_features, logit_scale=1.) # all_logits.append(torch.stack((output, ood_logits), dim=1)) if self.cfg.TRAINER.OOD_INFER_INTEGRATE: id_score = F.softmax(torch.stack((output, ood_logits), dim=1), dim=1)[:, 0, :] output *= id_score.clamp(min=min_thresh) else: ood_logits = None if self.cfg.TRAINER.OOD_INFER_OPTION == 'save_res': im_feats.append(image_features.cpu()) im_labels.append(label.cpu()) if self.cfg.TRAINER.OOD_INFER_OPTION == 'resume_res': start = data_loader.batch_size * batch_idx end = start + input.shape[0] merged_image_features = torch.cat((image_features, resume_image_features[start:end]), dim=0) label = torch.cat((label, resume_image_labels[start:end]), dim=0) output = merged_image_features @ merged_text_features.t() acc = output.argmax(dim=1) == label base_acc.append(acc[input.shape[0]:].cpu()) novel_acc.append(acc[:input.shape[0]].cpu()) if hasattr(self.dm.dataset, 'valid_classes'): # if self.cfg.TRAINER.OOD_PROMPT: # raise NotImplementedError output[:, ~self.dm.dataset.valid_classes] = -1. scores = get_msp_scores(output[:, self.dm.dataset.valid_classes]) else: scores, ood_flag = get_msp_scores(output, ood_logits, self.cfg.TRAINER.OOD_INFER, ret_near_ood=True) near_ood_flag.append(ood_flag) score_list.append(scores.detach().cpu().numpy()) self.evaluator.process(output, label) in_scores = np.concatenate(score_list, axis=0) results = self.evaluator.evaluate() if self.cfg.TRAINER.OOD_PROMPT and len(near_ood_flag) and near_ood_flag[0] is not None: print('NearOOD FPR:', torch.cat(near_ood_flag).sum().item() / len(in_scores)) if self.cfg.TRAINER.OOD_INFER_OPTION == 'save_res': torch.save(torch.cat(im_feats), f'{save_dir}/in_image_features.pt') torch.save(torch.cat(im_labels), f'{save_dir}/in_labels.pt') if len(all_logits): torch.save(torch.cat(all_logits), f'{save_dir}/in_logits_all.pt') if self.cfg.TRAINER.OOD_INFER_OPTION == 'resume_res': print(f'Base: {torch.cat(base_acc).float().mean(): .4f}. Novel: {torch.cat(novel_acc).float().mean(): .4f}') auroc_list, aupr_list, fpr95_list = [], [], [] ood_tpr_list = [] save_lines = [] for ood_name in ood_cfg[ood_type]: ood_set = eval(f'{ood_type}Dataset')(osp.join(data_root, ood_type), id_name=self.dm.dataset.dataset_name, ood_name=ood_name, transform=self.test_loader.dataset.transform) if self.cfg.TRAINER.FEAT_AS_INPUT: feat_data_dir = f'{data_root}/{ood_type}/clip_feat/{ood_name}' if not osp.exists(feat_data_dir): self.cache_feat(split='test', is_ood=True) ood_loader = DataLoader( CLIPFeatDataset(feat_data_dir, epoch=None, split='test'), batch_size=self.cfg.DATALOADER.TEST.BATCH_SIZE, shuffle=False, num_workers=data_loader.num_workers, pin_memory=True, drop_last=False, ) else: ood_loader = DataLoader(ood_set, batch_size=self.cfg.DATALOADER.TEST.BATCH_SIZE, shuffle=False, num_workers=data_loader.num_workers, drop_last=False, pin_memory=True) ood_score_list, sc_labels_list, ood_pred_list = [], [], [] near_ood_flag = [] all_logits = [] for batch_idx, batch in enumerate(tqdm(ood_loader)): if self.cfg.TRAINER.FEAT_AS_INPUT: images, sc_labels = self.parse_batch_test(batch) else: images, sc_labels = batch images = images.to(self.device) image_features, _, output = \ self.model(images, return_feat=True, return_norm=True, posthoc=posthoc) if self.cfg.TRAINER.OOD_PROMPT: ood_logits = self.model.get_logits(image_features, ood_text_features, logit_scale=1.) # all_logits.append(torch.stack((output, ood_logits), dim=1)) if self.cfg.TRAINER.OOD_INFER_INTEGRATE: id_score = F.softmax(torch.stack((output, ood_logits), dim=1), dim=1)[:, 0, :] output *= id_score.clamp(min=min_thresh) else: ood_logits = None if self.cfg.TRAINER.OOD_INFER_OPTION == 'resume_res': output = image_features @ merged_text_features.t() if hasattr(self.dm.dataset, 'valid_classes'): output[:, ~self.dm.dataset.valid_classes] = -1. scores = get_msp_scores(output[:, self.dm.dataset.valid_classes]) else: scores, ood_flag = get_msp_scores(output, ood_logits, self.cfg.TRAINER.OOD_INFER, ret_near_ood=True) near_ood_flag.append(ood_flag) ood_score_list.append(scores.detach().cpu().numpy()) sc_labels_list.append(sc_labels.cpu().numpy()) ood_pred_list.append(output.argmax(dim=1).cpu().numpy()) ood_scores = np.concatenate(ood_score_list, axis=0) sc_labels = np.concatenate(sc_labels_list, axis=0) ood_preds = np.concatenate(ood_pred_list, axis=0) fake_ood_scores = ood_scores[sc_labels>=0] real_ood_scores = ood_scores[sc_labels<0] real_in_scores = np.concatenate([in_scores, fake_ood_scores], axis=0) if 'cifar' in self.dm.dataset.dataset_name: # compatible with SCOOD auroc, aupr, fpr95, thresh = get_measures(real_ood_scores, real_in_scores) else: # compatible with NPOS auroc, aupr, fpr95, thresh = get_measures(-real_in_scores, -real_ood_scores) print('auroc: %.4f, aupr: %.4f, fpr95: %.4f' % (auroc, aupr, fpr95)) save_lines.append('%10s auroc: %.4f, aupr: %.4f, fpr95: %.4f\n' % (ood_name, auroc, aupr, fpr95)) auroc_list.append(auroc) aupr_list.append(aupr) fpr95_list.append(fpr95) if self.cfg.TRAINER.OOD_PROMPT and len(near_ood_flag) and near_ood_flag[0] is not None: ood_tpr = torch.cat(near_ood_flag).sum().item() / len(ood_scores) print('NearOOD TPR: %.4f' % ood_tpr) ood_tpr_list.append(ood_tpr) if self.cfg.TRAINER.OOD_INFER_OPTION == 'save_res' and len(all_logits): torch.save(torch.cat(all_logits), f'{save_dir}/ood_{ood_name}_logits_all.pt') print('\nAverage: auroc: %.4f, aupr: %.4f, fpr95: %.4f' % (np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr95_list))) save_lines.append('%10s auroc: %.4f, aupr: %.4f, fpr95: %.4f\n' % ('nAverage', np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr95_list))) if self.cfg.TRAINER.OOD_PROMPT and len(ood_tpr_list) > 1: print('Average: OOD-TPR: %.4f' % np.mean(ood_tpr_list)) if model_directory != '': if 'ClassOOD' == ood_type: res_list = np.stack((auroc_list, aupr_list, fpr95_list), axis=1).reshape(-1,) * 100 np.savetxt(f'{model_directory}/{ood_type}_results.csv', res_list, fmt='%.2f', delimiter=',') save_path = f'{model_directory}/{ood_type}_results.txt' with open(save_path, 'w+') as f: f.writelines(save_lines) return list(results.values())[0], auroc, aupr, fpr95 @torch.no_grad() def cache_feat(self, split='train', is_ood=True): """A generic OOD testing pipeline.""" self.set_model_mode("eval") self.evaluator.reset() if split == 'train': data_loader = self.train_loader_x max_epoch = self.max_epoch else: data_loader = self.test_loader max_epoch = 1 if is_ood: from ood.datasets import LargeOODDataset from torch.utils.data import DataLoader data_root = osp.join(osp.abspath(osp.expanduser(self.cfg.DATASET.ROOT)), 'LargeOOD') for ood_name in ['inaturalist', 'sun', 'places', 'texture']: ood_set = LargeOODDataset(data_root, id_name=self.dm.dataset.dataset_name, ood_name=ood_name, transform=self.test_loader.dataset.transform) ood_loader = DataLoader(ood_set, batch_size=self.cfg.DATALOADER.TEST.BATCH_SIZE, shuffle=False, num_workers=self.test_loader.num_workers, drop_last=False, pin_memory=True) save_dir = f'{data_root}/clip_feat/{ood_name}' os.makedirs(save_dir, exist_ok=True) features, labels, paths = [], [], [] cnt = 0 for input, label in tqdm(ood_loader, desc='Caching image features'): input = input.to(self.device) label = label.to(self.device) image_features = self.model.image_encoder(input.type(self.model.dtype)).detach() features.append(image_features.cpu()) labels.append(label.cpu()) for i in range(len(input)): paths.append(ood_set.samples[cnt+i][0]) cnt += len(input) torch.save(torch.cat(features).half(), f'{save_dir}/test_image_features.pt') torch.save(torch.cat(labels).half(), f'{save_dir}/test_labels.pt') with open(f'{save_dir}/test_paths.txt', 'w+') as f: f.writelines([p + '\n' for p in paths]) else: save_dir = f'{self.dm.dataset.dataset_dir}/clip_feat' os.makedirs(save_dir, exist_ok=True) for self.epoch in range(self.start_epoch, max_epoch): features, labels, paths = [], [], [] for batch_idx, batch in enumerate(tqdm(data_loader, desc=f"Caching image features: {split} {self.epoch+1}/{max_epoch}: ")): input = batch["img"].to(self.device) label = batch["label"].to(self.device) image_features = self.model.image_encoder(input.type(self.model.dtype)).detach() features.append(image_features.cpu()) labels.append(label.cpu()) paths.extend(batch["impath"]) torch.save(torch.cat(features).half(), f'{save_dir}/ep{self.epoch}_{split}_image_features.pt') torch.save(torch.cat(labels).half(), f'{save_dir}/ep{self.epoch}_{split}_labels.pt') with open(f'{save_dir}/ep{self.epoch}_{split}_paths.txt', 'w+') as f: f.writelines([p + '\n' for p in paths])