def __init__()

in trainers/catex.py [0:0]


    def __init__(self, cfg):
        super().__init__(cfg)
        if cfg.TRAINER.OOD_TRAIN or self.is_large_ID():
            if self.is_large_ID():
                nsample = 1200  # ID features cached per class in the pool
                self.id_pool = IDFeatPool(self.model.prompt_learner.n_cls, nsample, self.model.feat_dim, mode='npos', device='cuda:0')
                if cfg.TRAINER.ID_FEAT_PRELOAD != '':
                    # warm-start the pool with precomputed ID features (keep the first nsample per class)
                    queue = torch.load(cfg.TRAINER.ID_FEAT_PRELOAD).to(self.id_pool.queue.device)
                    self.id_pool.queue = queue[:, :nsample, :]
                    self.id_pool.class_ptr += nsample
            else:
                from ood.datasets import TinyImages, InfiniteDataLoader

                # CIFAR-scale ID data: stream auxiliary outlier images from TinyImages
                assert 'cifar' in self.dm.dataset.dataset_name
                data_root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
                ood_set = TinyImages(data_root, transform=self.train_loader_x.dataset.transform)
                self.ood_loader = InfiniteDataLoader(ood_set, batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
                                                     shuffle=False, num_workers=self.train_loader_x.num_workers,
                                                     pin_memory=True)
            
            from ood.losses import LogitNormLoss
            self.ce_criterion = LogitNormLoss() if cfg.TRAINER.LOGIT_NORM else nn.CrossEntropyLoss()
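
The preload branch assumes the saved file holds a per-class feature tensor that can be sliced to [n_cls, nsample, feat_dim]. A minimal sketch of producing such a file; the sizes, the fill step, and the file name are illustrative, not taken from the repository:

    import torch

    # Illustrative sizes; the real values come from the prompt learner and the image encoder.
    n_cls, nsample, feat_dim = 100, 1200, 512
    queue = torch.zeros(n_cls, nsample, feat_dim)
    # ... fill queue[c, i] with the i-th cached feature of class c ...
    torch.save(queue, 'id_feat_preload.pt')  # pass this path via cfg.TRAINER.ID_FEAT_PRELOAD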
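
The auxiliary outlier loader is consumed without epoch boundaries during OOD training. A minimal sketch of what a wrapper such as ood.datasets.InfiniteDataLoader typically provides, assuming it simply restarts its iterator when the underlying dataset is exhausted (this is an assumption, not the repository's implementation):

    from torch.utils.data import DataLoader

    class InfiniteDataLoader(DataLoader):
        """DataLoader that restarts its iterator once the dataset is exhausted (sketch)."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._iterator = super().__iter__()

        def __iter__(self):
            return self

        def __next__(self):
            try:
                return next(self._iterator)
            except StopIteration:
                self._iterator = super().__iter__()  # wrap around to a new pass
                return next(self._iterator)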
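
LogitNormLoss refers to the logit-normalization loss: cross-entropy applied to logits divided by their L2 norm and a temperature. A minimal sketch with a default temperature; it is not necessarily identical to ood.losses.LogitNormLoss:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class LogitNormLoss(nn.Module):
        """Cross-entropy on temperature-scaled, L2-normalized logits (sketch)."""

        def __init__(self, t=1.0):
            super().__init__()
            self.t = t

        def forward(self, logits, target):
            norms = torch.norm(logits, p=2, dim=-1, keepdim=True) + 1e-7
            return F.cross_entropy(logits / (norms * self.t), target)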