in ood/datasets.py [0:0]
def load_data(self, epoch=None):
    """Load cached image features, labels and (optionally) image paths.

    Files are read from ``self.data_dir`` using the prefix
    ``[{nshot}shot_][ep{epoch}_]{split}`` (the ``shot`` part only when
    ``self.nshot > 0``, the ``ep`` part only when ``epoch`` is not None):

    * ``{prefix}_image_features.pt`` -> ``self.features``
    * ``{prefix}_labels.pt``         -> ``self.targets`` (cast with ``.long()``)
    * ``{prefix}_paths.txt``         -> ``self.paths`` (one entry per line),
      or ``None`` when that file does not exist.

    Args:
        epoch: Optional epoch tag selecting a per-epoch dump; ``None`` selects
            the files without an ``ep{epoch}_`` prefix.

    The call returns immediately when data for the same ``epoch`` has already
    been loaded. Only ``epoch`` is compared; changing ``split``/``nshot`` on
    the instance does not invalidate the cached data.
    """
    # Skip reloading when this epoch's data is already in memory.
    if hasattr(self, 'epoch') and epoch == self.epoch:
        return
    prefix = ('' if epoch is None else f'ep{epoch}_') + self.split
    if self.nshot > 0:
        prefix = f'{self.nshot}shot_' + prefix
    # NOTE(review): torch.load unpickles its input; data_dir must be trusted.
    self.features = torch.load(f'{self.data_dir}/{prefix}_image_features.pt', map_location='cpu')
    self.targets = torch.load(f'{self.data_dir}/{prefix}_labels.pt', map_location='cpu').long()
    path_file = f'{self.data_dir}/{prefix}_paths.txt'
    if os.path.exists(path_file):
        with open(path_file, 'r') as f:
            self.paths = f.read().splitlines()
    else:
        self.paths = None
    # Record the epoch unconditionally (only after both loads succeeded) so
    # the early-return check above also works when no paths file exists.
    self.epoch = epoch
    if 'train' in prefix:
        print(f'\nLoaded dataset from epoch {epoch}\n')