def load_models()

in elq/main_dense.py [0:0]


def _load_biencoder_params(config_path):
    """Read the biencoder configuration file into a dict.

    The whole file is first parsed as JSON. If that fails, only its first
    line is used (as the original loader did): it is parsed as a Python
    literal with ``ast.literal_eval`` (e.g. a config saved via ``str(dict)``),
    and, failing that, with the legacy textual substitution
    (``'`` -> ``"``, ``True``/``False``/``None`` -> ``true``/``false``/``null``)
    followed by ``json.loads``.

    Args:
        config_path: path of the configuration file.

    Returns:
        dict: the parsed parameters.

    Raises:
        ValueError: if the file is empty or does not hold a mapping.
        json.decoder.JSONDecodeError: if every fallback fails to parse the line.
    """
    import ast  # local import: only needed for the non-JSON fallback

    # Read once instead of re-opening the file on the error path.
    with open(config_path) as json_file:
        text = json_file.read()

    try:
        params = json.loads(text)
    except json.decoder.JSONDecodeError:
        if not text:
            # Previously this left the params variable unbound (UnboundLocalError).
            raise ValueError(
                "biencoder config file %r is empty" % (config_path,)
            ) from None
        # Text-mode reads normalise newlines to "\n", so this matches the
        # first line the original `for line in json_file` loop would yield.
        first_line, _, _ = text.partition("\n")
        try:
            # Correctly handles quotes/apostrophes inside strings and string
            # values containing "True"/"False"/"None", which the blind
            # substitution below would corrupt.
            params = ast.literal_eval(first_line)
        except (ValueError, SyntaxError, TypeError):
            legacy = (
                first_line.replace("'", "\"")
                .replace("True", "true")
                .replace("False", "false")
                .replace("None", "null")
            )
            params = json.loads(legacy)

    if not isinstance(params, dict):
        raise ValueError(
            "biencoder config file %r does not contain a mapping (got %s)"
            % (config_path, type(params).__name__)
        )
    return params


def load_models(args, logger):
    """Load the biencoder model together with the candidate-entity resources.

    Args:
        args: namespace-like object.
            Required attributes: ``biencoder_config``, ``biencoder_model``,
            ``cand_token_ids_path``, ``entity_catalogue``, ``entity_encoding``,
            ``faiss_index``, ``index_path``.
            Optional attributes (read with ``getattr``):
            ``eval_batch_size`` (default 8); ``use_cuda`` (default False);
            ``max_context_length`` (when absent or None, the config value is kept).
        logger: logger used for progress messages, or a falsy value to skip logging.

    Returns:
        tuple: ``(biencoder, biencoder_params, candidate_encoding, indexer,
        id2title, id2text, id2wikidata)``. The last five items are returned
        exactly as produced by ``_load_candidates``.
    """
    # load biencoder model
    if logger:
        logger.info("Loading biencoder model")
    biencoder_params = _load_biencoder_params(args.biencoder_config)

    # Command-line arguments override values stored in the config file.
    biencoder_params["path_to_model"] = args.biencoder_model
    biencoder_params["cand_token_ids_path"] = args.cand_token_ids_path
    biencoder_params["eval_batch_size"] = getattr(args, 'eval_batch_size', 8)
    # CUDA is used only when it was requested AND is actually available.
    biencoder_params["no_cuda"] = (
        not getattr(args, 'use_cuda', False) or not torch.cuda.is_available()
    )
    if biencoder_params["no_cuda"]:
        biencoder_params["data_parallel"] = False
    biencoder_params["load_cand_enc_only"] = False
    max_context_length = getattr(args, 'max_context_length', None)
    if max_context_length is not None:
        biencoder_params["max_context_length"] = max_context_length
    biencoder = load_biencoder(biencoder_params)

    # Make the DataParallel wrapping agree with the device choice, whatever
    # load_biencoder returned: unwrap on CPU, wrap when CUDA is in use.
    is_data_parallel = isinstance(biencoder.model, torch.nn.DataParallel)
    if biencoder_params["no_cuda"] and is_data_parallel:
        biencoder.model = biencoder.model.module
    elif not biencoder_params["no_cuda"] and not is_data_parallel:
        biencoder.model = torch.nn.DataParallel(biencoder.model)

    # load candidate entities
    if logger:
        logger.info("Loading candidate entities")

    (
        candidate_encoding,
        indexer,
        id2title,
        id2text,
        id2wikidata,
    ) = _load_candidates(
        args.entity_catalogue, args.entity_encoding,
        args.faiss_index, args.index_path, logger=logger,
    )

    return (
        biencoder,
        biencoder_params,
        candidate_encoding,
        indexer,
        id2title,
        id2text,
        id2wikidata,
    )