# scripts/adapet/ADAPET/src/data/RecordReader.py
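# Note: this excerpt shows a single method of the reader class in this module; it
# assumes module-level imports of json, random, and numpy (as np), and relies on
# self._get_file(split) and self.config.max_num_lbl being defined elsewhere.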
def read_dataset(self, split=None, is_eval=False):
    '''
    Read the ReCoRD dataset from a JSONL file.

    :param split: partition of the dataset (e.g. "train", "val", "test")
    :param is_eval: if True, build evaluation-style datapoints even for the train split
    :return: numpy array of datapoints, each a dict with "input" and "output" keys
    '''
    file = self._get_file(split)

    data = []
    with open(file, 'r') as f_in:
        for line in f_in.readlines():
            json_string = json.loads(line)
            json_string_passage = json_string["passage"]

            idx = json_string["idx"]
            passage = json_string_passage["text"]

            # Get dictionary mapping entity idx to entity
            dict_entity_idx_2_name = {}
            for entity in json_string_passage["entities"]:
                start = entity["start"]
                end = entity["end"]
                word = passage[start:end + 1]
                dict_entity_idx_2_name[(entity["start"], entity["end"])] = word
            for qas in json_string["qas"]:
                question = qas["query"]
                qas_idx = qas["idx"]

                # If data has solution
                if "answers" in qas:
                    list_answers = qas["answers"]

                    # Get dictionary of entities in answers
                    dict_entity_idx_2_sol = {}
                    for answer in list_answers:
                        start = answer["start"]
                        end = answer["end"]
                        text = answer["text"]
                        dict_entity_idx_2_sol[(start, end)] = text

                    # Get all unique false entities for margin
                    set_false_entities = set()
                    for (enty_idx, enty) in dict_entity_idx_2_name.items():
                        if enty_idx not in dict_entity_idx_2_sol.keys():
                            set_false_entities.add(enty)
                    list_false_entities = list(set_false_entities)
                    # PET ensures each datapoint gets exactly 1 true entity and the rest false during training
                    if split == "train" and not is_eval:
                        set_seen_enty = set()
                        for enty_idx, enty in dict_entity_idx_2_name.items():
                            # Create datapoints with each unique correct entity
                            if enty_idx in dict_entity_idx_2_sol:
                                # Only see each entity once
                                if enty not in set_seen_enty:
                                    list_sample_false_entities = random.sample(
                                        list_false_entities,
                                        min(len(list_false_entities), self.config.max_num_lbl - 1))

                                    # Replace the @placeholder token with [MASK]
                                    masked_question = question.replace("@placeholder", "[MASK]")
                                    set_seen_enty.add(enty)

                                    dict_input = {"idx": idx, "passage": passage, "question": masked_question,
                                                  "true_entity": enty, "false_entities": list_sample_false_entities}
                                    dict_output = {"lbl": 0}
                                    dict_input_output = {"input": dict_input, "output": dict_output}
                                    data.append(dict_input_output)
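                        # Each training datapoint thus pairs one true entity with up to
                        # (config.max_num_lbl - 1) randomly sampled false entities; the fixed
                        # "lbl": 0 presumably marks the true entity's position in the candidate
                        # list that downstream code builds from true_entity + false_entities.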
                    # Construct evaluation sets with the lbl
                    else:
                        set_seen_enty = set()
                        for (enty_idx, enty) in dict_entity_idx_2_name.items():
                            if enty not in set_seen_enty:
                                set_seen_enty.add(enty)

                        # Replace the @placeholder token with [MASK]
                        masked_question = question.replace("@placeholder", "[MASK]")

                        # Compute one label per candidate entity for evaluation
                        label = [0 if enty not in list(dict_entity_idx_2_sol.values()) else 1
                                 for enty in set_seen_enty]

                        dict_input = {"idx": idx, "passage": passage, "question": masked_question,
                                      "candidate_entity": list(set_seen_enty)}
                        dict_output = {"lbl": label}
                        dict_input_output = {"input": dict_input, "output": dict_output}
                        data.append(dict_input_output)
                else:
                    # Test set without labels
                    set_seen_enty = set()
                    for (enty_idx, enty) in dict_entity_idx_2_name.items():
                        if enty not in set_seen_enty:
                            set_seen_enty.add(enty)

                    # Replace the @placeholder token with [MASK]
                    masked_question = question.replace("@placeholder", "[MASK]")

                    dict_input = {"idx": idx, "passage": passage, "question": masked_question,
                                  "candidate_entity": list(set_seen_enty), "qas_idx": qas_idx}
                    dict_output = {"lbl": -1}
                    dict_input_output = {"input": dict_input, "output": dict_output}
                    data.append(dict_input_output)
    data = np.asarray(data)
    return data
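
# Illustrative sketch (not part of the original source): a minimal ReCoRD-style JSONL
# record of the shape this reader expects, and the training datapoint it would yield.
# All concrete values below are hypothetical; only the field names mirror the code above.
_EXAMPLE_RECORD = {
    "idx": 0,
    "passage": {
        "text": "Barack Obama visited Paris last week.",
        "entities": [{"start": 0, "end": 11}, {"start": 21, "end": 25}],
    },
    "qas": [{
        "idx": 0,
        "query": "@placeholder met with French officials.",
        "answers": [{"start": 0, "end": 11, "text": "Barack Obama"}],
    }],
}

# With split == "train" and is_eval == False, the record above would produce one
# datapoint of roughly this form (false_entities are randomly sampled, so their
# exact content depends on config.max_num_lbl and the random seed):
_EXAMPLE_TRAIN_DATAPOINT = {
    "input": {
        "idx": 0,
        "passage": "Barack Obama visited Paris last week.",
        "question": "[MASK] met with French officials.",
        "true_entity": "Barack Obama",
        "false_entities": ["Paris"],
    },
    "output": {"lbl": 0},
}

# For evaluation splits, the same record would instead yield a single datapoint per
# question, carrying all unique candidate entities and a multi-hot label list, e.g.
# {"input": {..., "candidate_entity": ["Barack Obama", "Paris"]}, "output": {"lbl": [1, 0]}}.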