in blink/candidate_retrieval/perform_and_evaluate_candidate_retrieval_multithreaded.py [0:0]
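# --- Context (not part of the original excerpt) ---
# main() below relies on module-level imports plus the helpers get_parameters,
# run_thread, and split, which are defined elsewhere in this file. The imports
# and the split() sketch here are a minimal, assumed reconstruction: the
# project-local module paths (utils, evaluator, candidate_generators) are
# guesses and may differ in the actual repository.
import os
import pickle
import time
from multiprocessing.dummy import Pool as ThreadPool  # thread-backed Pool

import utils  # assumed project-local module providing get_datasets / get_list_of_mentions
from evaluator import Evaluator  # assumed location
from candidate_generators import (  # assumed location
    Simple_Candidate_Generator,
    Pregenerated_Candidates_Data_Fetcher,
)


def split(a, n):
    # Sketch of the chunking helper used by main(): splits list `a` into `n`
    # contiguous blocks whose sizes differ by at most one element.
    k, m = divmod(len(a), n)
    return [a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)]
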
def main(args):
    wall_start = time.time()
    parameters = get_parameters(args)
    print("Candidate generator parameters:", parameters)

    datasets = utils.get_datasets(
        args.include_aida_train, args.keep_pregenerated_candidates
    )
    if args.single_dataset:
        datasets = [datasets[0]]
    mentions = utils.get_list_of_mentions(datasets)
    # NUM_THREADS = multiprocessing.cpu_count()
    NUM_THREADS = args.num_threads
    # ThreadPool comes from multiprocessing.dummy, so the workers are threads
    # sharing memory, not separate processes.
    pool = ThreadPool(NUM_THREADS)

    # Split the data into approximately equal parts and give one block to each thread
    data_per_thread = split(mentions, NUM_THREADS)
    # Each thread gets its own candidate generator. The misspelled key
    # "pregenereted_cands_data_fetcher" is kept as-is because run_thread
    # (defined elsewhere in this module) expects that exact key.
    if args.keep_pregenerated_candidates:
        arguments = [
            {
                "id": idx,
                "data": data_bloc,
                "args": args,
                "candidate_generator": Simple_Candidate_Generator(parameters),
                "pregenereted_cands_data_fetcher": Pregenerated_Candidates_Data_Fetcher(
                    parameters
                ),
            }
            for idx, data_bloc in enumerate(data_per_thread)
        ]
    else:
        arguments = [
            {
                "id": idx,
                "data": data_bloc,
                "args": args,
                "candidate_generator": Simple_Candidate_Generator(parameters),
            }
            for idx, data_bloc in enumerate(data_per_thread)
        ]
    results = pool.map(run_thread, arguments)

    # Merge the per-thread results; each entry is an (id, processed_mentions) pair
    processed_mentions = []
    for _id, thread_mentions in results:
        processed_mentions.extend(thread_mentions)

    # All map tasks have completed, so shut the pool down gracefully
    pool.close()
    pool.join()

    execution_time = (time.time() - wall_start) / 60
    print("The execution took:", execution_time, "minutes")
    # Evaluate the generation
    evaluator = Evaluator(processed_mentions)
    evaluator.candidate_generation(
        save_gold_pos=True,
        save_pregenerated_gold_pos=args.keep_pregenerated_candidates,
    )
    # Dump the data if the dump_mentions flag was set
    if args.dump_mentions:
        print("Dumping processed mentions")
        # Create the directory for the mention dumps if it does not exist
        dump_folder = args.dump_mentions_folder
        os.makedirs(dump_folder, exist_ok=True)

        dump_object = {
            "mentions": processed_mentions,
            "total_per_dataset": evaluator.total_per_dataset,
            "has_gold_per_dataset": evaluator.has_gold_per_dataset,
            "parameters": parameters,
            "args": args,
            "execution_time": execution_time,
        }

        # Use a context manager so the file handle is closed even if dump fails
        with open(os.path.join(dump_folder, args.dump_file_id), "wb") as f:
            pickle.dump(dump_object, f, protocol=4)
    # evaluator.candidate_generation(max_rank=100)
    return evaluator.recall
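
# --- Usage sketch (not part of the original excerpt) ---
# A minimal, assumed command-line driver showing how main() could be invoked.
# The flag names mirror the args attributes read above, but the real script's
# argparse setup (and whatever get_parameters consumes) may define more
# options; the defaults here are purely illustrative.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_threads", type=int, default=8)
    parser.add_argument("--include_aida_train", action="store_true")
    parser.add_argument("--keep_pregenerated_candidates", action="store_true")
    parser.add_argument("--single_dataset", action="store_true")
    parser.add_argument("--dump_mentions", action="store_true")
    parser.add_argument("--dump_mentions_folder", type=str, default="mention_dumps")
    parser.add_argument("--dump_file_id", type=str, default="dump.pickle")
    recall = main(parser.parse_args())
    print("Candidate generation recall:", recall)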