in elq/main_dense.py [0:0]
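# Note: this function relies on module-level imports elsewhere in
# elq/main_dense.py: json, torch, load_biencoder, and the local helper
# _load_candidates.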
def load_models(args, logger):
    # load biencoder model
    if logger: logger.info("Loading biencoder model")
    try:
        with open(args.biencoder_config) as json_file:
            biencoder_params = json.load(json_file)
    except json.decoder.JSONDecodeError:
        # Fallback for configs saved as a Python dict repr rather than strict JSON
        # (e.g. {'lowercase': True, 'path_to_model': None}): rewrite the Python
        # literals so the first line parses as JSON.
        with open(args.biencoder_config) as json_file:
            for line in json_file:
                line = line.replace("'", "\"")
                line = line.replace("True", "true")
                line = line.replace("False", "false")
                line = line.replace("None", "null")
                biencoder_params = json.loads(line)
                break
biencoder_params["path_to_model"] = args.biencoder_model
biencoder_params["cand_token_ids_path"] = args.cand_token_ids_path
biencoder_params["eval_batch_size"] = getattr(args, 'eval_batch_size', 8)
biencoder_params["no_cuda"] = (not getattr(args, 'use_cuda', False) or not torch.cuda.is_available())
if biencoder_params["no_cuda"]:
biencoder_params["data_parallel"] = False
biencoder_params["load_cand_enc_only"] = False
if getattr(args, 'max_context_length', None) is not None:
biencoder_params["max_context_length"] = args.max_context_length
    biencoder = load_biencoder(biencoder_params)
    # keep the model wrapper consistent with the device setting:
    # unwrap DataParallel when running on CPU, wrap when CUDA is available
    if biencoder_params["no_cuda"] and type(biencoder.model).__name__ == 'DataParallel':
        biencoder.model = biencoder.model.module
    elif not biencoder_params["no_cuda"] and type(biencoder.model).__name__ != 'DataParallel':
        biencoder.model = torch.nn.DataParallel(biencoder.model)
    # load candidate entities
    if logger: logger.info("Loading candidate entities")
    (
        candidate_encoding,
        indexer,
        id2title,
        id2text,
        id2wikidata,
    ) = _load_candidates(
        args.entity_catalogue, args.entity_encoding,
        args.faiss_index, args.index_path, logger=logger,
    )
    return (
        biencoder,
        biencoder_params,
        candidate_encoding,
        indexer,
        id2title,
        id2text,
        id2wikidata,
    )
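
# Usage sketch (an assumption, not part of the original file): load_models only
# reads attributes off `args`, so any object exposing the fields below works;
# the paths are hypothetical placeholders. Optional attributes (eval_batch_size,
# use_cuda, max_context_length) fall back to defaults via getattr.
#
#     from argparse import Namespace
#
#     args = Namespace(
#         biencoder_config="models/elq_biencoder/config.json",
#         biencoder_model="models/elq_biencoder/pytorch_model.bin",
#         cand_token_ids_path="models/entity_token_ids.t7",
#         entity_catalogue="models/entity.jsonl",
#         entity_encoding="models/all_entities.t7",
#         faiss_index=None,  # assumption: no FAISS index configured
#         index_path=None,
#         use_cuda=False,
#     )
#     models = load_models(args, logger=None)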