# vilbert/datasets/gqa_dataset.py [0:0]
def _load_dataset(dataroot, name, clean_datasets):
"""Load entries
dataroot: root path of dataset
name: 'train', 'val', 'trainval', 'test'
"""
if name == "train" or name == "val":
items_path = os.path.join(dataroot, "cache", "%s_target.pkl" % name)
items = cPickle.load(open(items_path, "rb"))
items = sorted(items, key=lambda x: x["question_id"])
elif name == "trainval":
items_path = os.path.join(dataroot, "cache", "%s_target.pkl" % name)
items = cPickle.load(open(items_path, "rb"))
items = sorted(items, key=lambda x: x["question_id"])
items = items[:-3000]
elif name == "minval":
items_path = os.path.join(dataroot, "cache", "trainval_target.pkl")
items = cPickle.load(open(items_path, "rb"))
items = sorted(items, key=lambda x: x["question_id"])
items = items[-3000:]
elif name == "test":
items_path = os.path.join(dataroot, "testdev_balanced_questions.json")
items = json.load(open(items_path, "rb"))
else:
assert False, "data split is not recognized."
if "test" in name:
entries = []
for item in items:
it = items[item]
entry = {
"question_id": int(item),
"image_id": it["imageId"],
"question": it["question"],
}
entries.append(entry)
else:
entries = []
remove_ids = []
if clean_datasets:
remove_ids = np.load(os.path.join(dataroot, "cache", "genome_test_ids.npy"))
remove_ids = [int(x) for x in remove_ids]
for item in items:
if "train" in name and int(item["image_id"]) in remove_ids:
continue
entries.append(_create_entry(item))
return entries