in trainers/catex.py [0:0]
def __init__(self, cfg):
    """Build the trainer plus the auxiliary state used for OOD training.

    When ``cfg.TRAINER.OOD_TRAIN`` is set or ``self.is_large_ID()`` is true,
    one of two auxiliary sources is prepared:

    * ``self.is_large_ID()`` true: an ``IDFeatPool`` holding ``nsample``
      features per class (``mode='npos'``), optionally pre-filled from the
      tensor saved at ``cfg.TRAINER.ID_FEAT_PRELOAD``;
    * otherwise: ``self.ood_loader``, an ``InfiniteDataLoader`` over
      TinyImages that reuses the training transform and batch size. This is
      only supported for CIFAR datasets.

    Finally ``self.ce_criterion`` is set to ``LogitNormLoss`` when
    ``cfg.TRAINER.LOGIT_NORM`` is truthy, or to ``nn.CrossEntropyLoss``
    otherwise.

    Args:
        cfg: experiment config node (reads ``TRAINER.*``, ``DATASET.ROOT``
            and ``DATALOADER.TRAIN_X.BATCH_SIZE``).

    Raises:
        ValueError: if the preloaded feature queue holds fewer than
            ``nsample`` entries along dim 1, or if TinyImages outlier data is
            requested for a dataset whose name does not contain ``'cifar'``.
    """
    super().__init__(cfg)
    large_id = self.is_large_ID()
    if cfg.TRAINER.OOD_TRAIN or large_id:
        if large_id:
            nsample = 1200  # feature slots kept per class in the ID pool
            # NOTE(review): device is hard-coded to the first GPU; consider
            # deriving it from the trainer's device instead.
            self.id_pool = IDFeatPool(
                self.model.prompt_learner.n_cls,
                nsample,
                self.model.feat_dim,
                mode='npos',
                device='cuda:0',
            )
            if cfg.TRAINER.ID_FEAT_PRELOAD != '':
                # map_location deserializes straight onto the pool's device, so a
                # file saved from another GPU (or CPU) loads without needing that
                # device to exist and without an intermediate copy.
                queue = torch.load(
                    cfg.TRAINER.ID_FEAT_PRELOAD,
                    map_location=self.id_pool.queue.device,
                )
                # The pointer below is advanced by nsample, so the preloaded
                # queue must actually provide that many entries per class.
                if queue.shape[1] < nsample:
                    raise ValueError(
                        f"preloaded ID feature queue has {queue.shape[1]} entries "
                        f"per class, expected at least {nsample}"
                    )
                self.id_pool.queue = queue[:, :nsample, :]
                # Advance the per-class pointer past the preloaded entries
                # (presumably marks every slot as filled -- confirm against
                # IDFeatPool).
                self.id_pool.class_ptr += nsample
        else:
            from ood.datasets import TinyImages, InfiniteDataLoader
            dataset_name = self.dm.dataset.dataset_name
            # Explicit check rather than `assert`, which `python -O` strips.
            if 'cifar' not in dataset_name:
                raise ValueError(
                    f"TinyImages outlier data is only supported for CIFAR "
                    f"datasets, got {dataset_name!r}"
                )
            data_root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
            ood_set = TinyImages(data_root, transform=self.train_loader_x.dataset.transform)
            self.ood_loader = InfiniteDataLoader(
                ood_set,
                batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
                shuffle=False,
                num_workers=self.train_loader_x.num_workers,
                pin_memory=True,
            )  # drop_last=True,
    from ood.losses import LogitNormLoss
    self.ce_criterion = LogitNormLoss() if cfg.TRAINER.LOGIT_NORM else nn.CrossEntropyLoss()