in scripts/prepare_hybrid_ranking_data.py [0:0]
def parse_data(all_data, num_cand, index='both'):
    """Build ranking examples: one positive candidate followed by sampled negatives.

    For each question, passages and tables are concatenated into a single
    candidate pool (``all_aes``). Positive indices are shuffled and the first
    ``num_pos`` of them are used. For each chosen positive, ``num_cand - 1``
    distinct negatives are drawn uniformly, without replacement, from the
    non-positive candidates.

    Args:
        all_data: sized iterable of question dicts. Each dict is read for
            ``'qid'``, ``'question'``, ``'passages'``, ``'tables'``,
            ``'pos_index_passages'``, ``'pos_index_tables'`` and, for
            diagnostic logging only, ``'answers'`` or ``'denotation'``.
        num_cand: total number of candidates per example
            (1 positive + ``num_cand - 1`` negatives).
        index: unused; kept for backward compatibility with callers.

    Returns:
        A list of dicts with keys ``'qid'``, ``'question'`` and
        ``'candidates'``. The positive candidate is always ``candidates[0]``.
        Questions are skipped when they have fewer than ``num_cand``
        candidates, no positives, or too few distinct negatives.
    """
    outputs = []
    for data in tqdm(all_data, total=len(all_data), desc="converting aes candidates"):
        # process one question
        all_aes = data['passages'] + data['tables']
        if len(all_aes) < num_cand:
            logger.info("you do not have enough aes retrievals")
            continue

        # Tables follow passages in ``all_aes``, so table indices are shifted
        # by the number of passages. Both lists are normalised to int.
        num_passages = len(data['passages'])
        pos_indices = [int(x) for x in data['pos_index_passages']] + [
            int(x) + num_passages for x in data['pos_index_tables']
        ]
        if not pos_indices:
            continue
        random.shuffle(pos_indices)

        # Use a set so duplicated gold indices are counted once; otherwise
        # the sampling distribution below would not sum to 1.
        pos_set = set(pos_indices)
        is_neg = [i not in pos_set for i in range(len(all_aes))]
        num_neg = sum(is_neg)
        num_needed = num_cand - 1

        if num_needed < 0 or num_needed > num_neg:
            # Not enough distinct negatives to fill the example: report it
            # and skip the question.
            answer = data['answers'] if 'answers' in data else data.get('denotation')
            logger.info(f"question: {data['question']}")
            logger.info(f"answer: {answer}")
            continue

        # Uniform probability over negatives, zero on positives.
        prob = [1. / num_neg if neg else 0. for neg in is_neg] if num_neg else None

        num_pos = 1
        for pos_idx in pos_indices[:num_pos]:
            # one example
            if num_needed == 0:
                neg_indices = []
            else:
                neg_indices = list(
                    np.random.choice(len(all_aes), num_needed, replace=False, p=prob)
                )
            all_indices = [pos_idx] + neg_indices
            outputs.append(
                {'qid': data['qid'],
                 'question': data['question'],
                 'candidates': [all_aes[i] for i in all_indices],
                 }
            )
    return outputs