in blink/biencoder/eval_biencoder.py [0:0]
def main(params):
    """Run bi-encoder retrieval over a dataset split and optionally save results.

    Steps, in order:
      1. Create the output directory and a logger that writes there.
      2. Build a ``BiEncoderRanker`` from ``params``.
      3. Load or generate the candidate pool.
      4. Load the candidate encoding from ``cand_encode_path`` if possible,
         otherwise compute it (and cache it back to that path when one is set).
      5. Read and tensorize the mentions of the requested split.
      6. Retrieve top-k predictions and, if requested, save them with
         ``torch.save`` under ``<output_path>/top<top_k>_candidates/<mode>.t7``.

    Args:
        params: Dict-like configuration. Required keys read here:
            ``output_path``, ``encode_batch_size``, ``silent``, ``mode``,
            ``data_path``, ``max_context_length``, ``max_cand_length``,
            ``context_key``, ``debug``, ``eval_batch_size``, ``top_k``.
            Optional keys (read with ``.get``): ``cand_encode_path``,
            ``cand_pool_path``, ``zeshel``, ``save_topk_result``.

    Returns:
        None. Results are only persisted when ``save_topk_result`` is truthy.
    """
    output_path = params["output_path"]
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_path, exist_ok=True)
    logger = utils.get_logger(output_path)

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer

    # The candidate pool is always needed: it is used both to compute the
    # candidate encoding (when no cached encoding is available) and by the
    # top-k retrieval step below.
    cand_pool_path = params.get("cand_pool_path", None)
    candidate_pool = load_or_generate_candidate_pool(
        tokenizer,
        params,
        logger,
        cand_pool_path,
    )

    cand_encode_path = params.get("cand_encode_path", None)
    candidate_encoding = None
    if cand_encode_path is not None:
        # Try to reuse a pre-computed candidate encoding; on any load error,
        # fall back to re-computing it below.
        try:
            logger.info("Loading pre-generated candidate encode path.")
            candidate_encoding = torch.load(cand_encode_path)
        except Exception as err:
            # `except Exception` (not a bare `except`) so that
            # KeyboardInterrupt / SystemExit still propagate.
            logger.info("Loading failed. Generating candidate encoding.")
            logger.info(
                "Could not load %s (%s: %s)",
                cand_encode_path,
                type(err).__name__,
                err,
            )

    if candidate_encoding is None:
        candidate_encoding = encode_candidate(
            reranker,
            candidate_pool,
            params["encode_batch_size"],
            silent=params["silent"],
            logger=logger,
            is_zeshel=params.get("zeshel", None),
        )

        if cand_encode_path is not None:
            # Save candidate encoding to avoid re-compute
            logger.info("Saving candidate encoding to file %s", cand_encode_path)
            torch.save(candidate_encoding, cand_encode_path)

    test_samples = utils.read_dataset(params["mode"], params["data_path"])
    logger.info("Read %d test samples.", len(test_samples))

    # Only the tensorized form is consumed below; the first return value of
    # process_mention_data is not needed here.
    _, test_tensor_data = data.process_mention_data(
        test_samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params["context_key"],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )
    # Sequential sampling keeps the dataloader order equal to the dataset order.
    test_sampler = SequentialSampler(test_tensor_data)
    test_dataloader = DataLoader(
        test_tensor_data,
        sampler=test_sampler,
        batch_size=params["eval_batch_size"],
    )

    save_results = params.get("save_topk_result")
    new_data = nnquery.get_topk_predictions(
        reranker,
        test_dataloader,
        candidate_pool,
        candidate_encoding,
        params["silent"],
        logger,
        params["top_k"],
        params.get("zeshel", None),
        save_results,
    )

    if save_results:
        save_data_dir = os.path.join(
            output_path,
            "top%d_candidates" % params["top_k"],
        )
        os.makedirs(save_data_dir, exist_ok=True)
        save_data_path = os.path.join(save_data_dir, "%s.t7" % params["mode"])
        torch.save(new_data, save_data_path)