in utils_ranking.py [0:0]
def _parse_data(self, all_data, num_cand, index='both'):
    """Build ranking examples: one positive candidate plus sampled negatives.

    For each record in ``all_data`` the candidate pool is taken from
    ``data['passages']``, ``data['tables']`` or their concatenation
    (passages first), depending on ``index``.  One positive index is picked
    at random and ``num_cand - 1`` negatives are drawn uniformly, without
    replacement, from the non-positive candidates.

    Args:
        all_data: Sequence of per-question dicts.  The code reads the keys
            ``qid``, ``question``, ``passages``/``tables``,
            ``pos_index_passages``/``pos_index_tables`` and, for logging
            only, ``answers`` or ``denotation``.
        num_cand: Number of candidates per emitted example, including the
            positive one.
        index: ``'both'``, ``'passages'`` or ``'tables'``.

    Returns:
        A list of dicts with keys ``qid``, ``question`` and ``candidates``;
        the positive candidate is always ``candidates[0]``.  If any record
        already contains a ``'candidates'`` key, ``all_data`` itself is
        returned unchanged and anything converted so far is discarded.

    Raises:
        NotImplementedError: If ``index`` is not one of the supported values.
    """
    outputs = []
    for data in tqdm(all_data, total=len(all_data), desc="converting aes candidates"):
        if 'candidates' in data:
            # Input is already in the output format: hand it back as-is.
            logger.info("Data has been processed")
            outputs = all_data
            break

        if index == 'both':
            all_aes = data['passages'] + data['tables']
        elif index in ('passages', 'tables'):
            all_aes = data[index]
        else:
            raise NotImplementedError(f"unsupported index: {index!r}")

        if len(all_aes) < num_cand:
            logger.info("you do not have enough aes retrievals")
            continue

        # Positive indices into ``all_aes``.  Table indices are shifted past
        # the passages when both pools are concatenated.
        if index == 'both':
            table_offset = len(data['passages'])
            pos_indices = data['pos_index_passages'] + [
                int(x) + table_offset for x in data['pos_index_tables']
            ]
        elif index == 'passages':
            # Copy so the shuffle below does not reorder the caller's list.
            pos_indices = list(data['pos_index_passages'])
        else:
            pos_indices = list(data['pos_index_tables'])
        random.shuffle(pos_indices)
        if not pos_indices:
            continue

        # Uniform probability over negatives, zero for positives.  The
        # division stays inside the comprehension so it is never evaluated
        # when every candidate is positive.
        pos_lookup = set(pos_indices)
        num_neg = len(all_aes) - len(pos_indices)
        prob = [
            0. if i in pos_lookup else 1. / num_neg
            for i in range(len(all_aes))
        ]

        num_pos = 1
        for pos_idx in pos_indices[:num_pos]:
            try:
                neg_indices = list(
                    np.random.choice(len(all_aes), num_cand - 1, replace=False, p=prob)
                )
            except ValueError:
                # Raised when there are too few negatives to draw from or the
                # probabilities do not sum to one; log the question and skip it.
                logger.info(f"question: {data['question']}")
                if 'answers' in data:
                    logger.info(f"answer: {data['answers']}")
                else:
                    logger.info(f"answer: {data.get('denotation')}")
                break

            all_indices = [pos_idx] + neg_indices
            outputs.append(
                {
                    'qid': data['qid'],
                    'question': data['question'],
                    'candidates': [all_aes[i] for i in all_indices],
                }
            )
    return outputs