in vilbert/datasets/vqa_dataset.py [0:0]
def _read_sorted_questions(dataroot, split, year="2014"):
    """Read one VQA v2 question file and return its questions sorted by question_id."""
    question_path = os.path.join(
        dataroot, "v2_OpenEnded_mscoco_%s%s_questions.json" % (split, year)
    )
    # Context manager so the handle is closed even if parsing fails.
    with open(question_path) as handle:
        questions = json.load(handle)["questions"]
    return sorted(questions, key=lambda x: x["question_id"])


def _read_sorted_targets(dataroot, split):
    """Read the cached answer targets for one split, sorted by question_id."""
    answer_path = os.path.join(dataroot, "cache", "%s_target.pkl" % split)
    # NOTE(review): unpickling executes arbitrary code; the cache directory
    # must only ever contain files produced by a trusted preprocessing step.
    with open(answer_path, "rb") as handle:
        answers = cPickle.load(handle)
    return sorted(answers, key=lambda x: x["question_id"])


def _read_coco_test_ids(dataroot):
    """Read cache/coco_test_ids.npy and return its values as a set of ints."""
    raw_ids = np.load(os.path.join(dataroot, "cache", "coco_test_ids.npy"))
    # A set keeps the per-question membership test O(1) instead of O(len(ids)).
    return {int(x) for x in raw_ids}


def _load_dataset(dataroot, name, clean_datasets):
    """Load the entries of one data split.

    dataroot: root path of dataset. Question files are read from here and
        answer targets / image-id lists from its "cache" sub-directory.
    name: one of 'train', 'val', 'trainval', 'minval', 'test', 'mteval'.
        - 'train' / 'val': the corresponding 2014 split.
        - 'trainval': all of train plus val without its last 3000 questions
          (after sorting by question_id).
        - 'minval': the last 3000 val questions (after sorting).
        - 'test': the 2015 test questions; no answers are loaded.
        - 'mteval': the train questions whose image_id IS listed in
          cache/coco_test_ids.npy.
    clean_datasets: if truthy, and the split name contains "train", questions
        whose image_id is listed in cache/coco_test_ids.npy are dropped.
        Ignored for 'test' and 'mteval'.

    Returns a list of entries. For 'test' these are the raw question dicts;
    for every other split each entry is `_create_entry(question, answer)`.

    Raises AssertionError if `name` is not a recognized split, or (via
    `assert_eq`) if questions and answers do not line up.
    """
    minval_size = 3000  # number of trailing val questions held out as 'minval'

    answers = None
    if name == "train" or name == "val":
        questions = _read_sorted_questions(dataroot, name)
        answers = _read_sorted_targets(dataroot, name)
    elif name == "trainval":
        questions_train = _read_sorted_questions(dataroot, "train")
        answers_train = _read_sorted_targets(dataroot, "train")
        questions_val = _read_sorted_questions(dataroot, "val")
        answers_val = _read_sorted_targets(dataroot, "val")
        questions = questions_train + questions_val[:-minval_size]
        answers = answers_train + answers_val[:-minval_size]
    elif name == "minval":
        questions = _read_sorted_questions(dataroot, "val")[-minval_size:]
        answers = _read_sorted_targets(dataroot, "val")[-minval_size:]
    elif name == "test":
        questions = _read_sorted_questions(dataroot, "test", year="2015")
    elif name == "mteval":
        # Only the train split is used here; the val files are not needed.
        questions = _read_sorted_questions(dataroot, "train")
        answers = _read_sorted_targets(dataroot, "train")
    else:
        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        raise AssertionError("data split is not recognized.")

    if name == "test":
        # Test questions have no targets, so they are returned as-is.
        return list(questions)

    if name == "mteval":
        keep_ids = _read_coco_test_ids(dataroot)
        return [
            _create_entry(question, answer)
            for question, answer in zip(questions, answers)
            if int(question["image_id"]) in keep_ids
        ]

    assert_eq(len(questions), len(answers))
    remove_ids = _read_coco_test_ids(dataroot) if clean_datasets else set()
    # Filtering only applies to splits whose name contains "train".
    filter_split = "train" in name

    entries = []
    for question, answer in zip(questions, answers):
        if filter_split and int(question["image_id"]) in remove_ids:
            continue
        # Both lists were sorted by question_id, so they must pair up exactly.
        assert_eq(question["question_id"], answer["question_id"])
        assert_eq(question["image_id"], answer["image_id"])
        entries.append(_create_entry(question, answer))
    return entries