in datasets.py [0:0]
def __init__(self, dataset, root, metadata):
    """Load the sample index for one of the supported caption datasets.

    Args:
        dataset: Name selecting how ``metadata`` is parsed. Recognised
            values are ``'yfcc15m'``, ``'coco'``, ``'cc12m'``, ``'cc3m'``
            and ``'redcaps'``.
        root: Stored as ``self.root``; not otherwise used in this method
            (presumably the image directory -- confirm against the
            methods that read it).
        metadata: Path of the metadata file. How it is read depends on
            ``dataset``:

            * ``'yfcc15m'``: a pickle file; ``self.samples`` is whatever
              object it contains.
            * ``'coco'``: a JSON object with an ``"annotations"`` list
              whose entries carry ``"image_id"`` and ``"caption"``.
              ``self.samples`` becomes a list of
              ``(image_id, [caption, ...])`` pairs, one per image id, in
              order of first appearance.
            * ``'cc12m'`` / ``'cc3m'``: a file read with ``np.load``;
              ``self.samples`` is the loaded object.
            * ``'redcaps'``: a JSON list whose entries carry
              ``"image_id"``, ``"subreddit"`` and ``"caption"``.
              ``self.samples`` becomes a list of
              ``(image_id, subreddit, caption)`` triples.

    Note:
        For any other ``dataset`` value no branch runs and
        ``self.samples`` is left unset.
    """
    self.dataset = dataset
    self.root = root

    if self.dataset == 'yfcc15m':
        # SECURITY: unpickling executes arbitrary code embedded in the
        # file; only point this at metadata from a trusted source.
        with open(metadata, 'rb') as f:
            self.samples = pickle.load(f)
    elif self.dataset == 'coco':
        # JSON is UTF-8; decode explicitly instead of relying on the
        # platform's locale encoding, which breaks non-ASCII captions.
        with open(metadata, encoding='utf-8') as f:
            annotations = json.load(f)['annotations']
        # Gather every caption that belongs to the same image.
        captions_by_image = defaultdict(list)
        for ann in annotations:
            captions_by_image[ann['image_id']].append(ann['caption'])
        self.samples = list(captions_by_image.items())
    elif self.dataset in ('cc12m', 'cc3m'):
        # SECURITY: allow_pickle=True lets the file carry pickled objects,
        # with the same trusted-source caveat as pickle.load above.
        self.samples = np.load(metadata, allow_pickle=True)
    elif self.dataset == 'redcaps':
        # Same explicit UTF-8 decoding as the COCO branch.
        with open(metadata, encoding='utf-8') as f:
            annotations = json.load(f)
        self.samples = [
            (ann['image_id'], ann['subreddit'], ann['caption'])
            for ann in annotations
        ]