def get_datasets()

in blink/candidate_retrieval/utils.py [0:0]


def get_datasets(get_test_dataset=False, get_pregenereted_candidates_wikidata_id=False):
    """Load the CoNLL/AIDA benchmark datasets and annotate mentions with Wikidata IDs.

    Builds a ``D.CoNLLDataset`` from fixed paths under
    ``data/train_and_benchmark_data`` (candidate generation disabled, p(e|m)
    candidate generator) and annotates every mention of every selected
    dataset in place.

    Args:
        get_test_dataset: If True, also include ``conll.train`` under the name
            ``"aida-train"``. Note: despite the parameter name, this adds the
            AIDA *training* split.
        get_pregenereted_candidates_wikidata_id: If True, also resolve each
            mention's pre-generated candidates (``m["candidates"]``, unpacked
            as ``(link, prob)`` pairs) and store them in
            ``m["candidates_wikidata_ids"]`` as a list of
            ``(wikidata_id, link, prob)`` tuples.

    Returns:
        A list of ``(name, dataset)`` tuples. Each dataset is mutated in place:
        every mention gets a ``"gold_wikidata_id"`` key, which is ``None`` when
        ``get_wikidata_id_from_link_name`` could not resolve the gold link.
        Datasets are presumably dicts mapping document name -> list of mention
        dicts (only ``.values()`` is used here) -- TODO confirm against
        ``D.CoNLLDataset``.
    """
    train_and_benchmarking_data_dir = "data/train_and_benchmark_data"
    datadir = os.path.join(
        train_and_benchmarking_data_dir, "generated/test_train_data/"
    )
    conll_path = os.path.join(
        train_and_benchmarking_data_dir, "basic_data/test_datasets/"
    )
    person_path = os.path.join(
        train_and_benchmarking_data_dir, "basic_data/p_e_m_data/persons.txt"
    )
    p_e_m_path = os.path.join(train_and_benchmarking_data_dir, "basic_data/p_e_m_data/")

    added_params = {
        "generate_cands": False,
        "generate_ments_and_cands": False,
        "candidate_generator_type": "p_e_m",
        "p_e_m_data_path": p_e_m_path,
    }
    conll = D.CoNLLDataset(datadir, person_path, conll_path, added_params)

    dev_datasets = [
        ("aida-A", conll.testA),
        ("aida-B", conll.testB),
        ("msnbc", conll.msnbc),
        ("aquaint", conll.aquaint),
        ("ace2004", conll.ace2004),
        ("clueweb", conll.clueweb),
        ("wikipedia", conll.wikipedia),
    ]

    if get_test_dataset:
        dev_datasets.append(("aida-train", conll.train))

    not_found_count = 0
    total = 0
    for ds_name, dataset in dev_datasets:
        print("Processing dataset:", ds_name)
        # Document names are not needed; only the per-document mention lists.
        for mentions in dataset.values():
            for m in mentions:
                total += 1
                gold_link = m["gold"][0]
                gold_wikidata_id = get_wikidata_id_from_link_name(gold_link)

                if gold_wikidata_id is None:
                    not_found_count += 1

                m["gold_wikidata_id"] = gold_wikidata_id

                if get_pregenereted_candidates_wikidata_id:
                    m["candidates_wikidata_ids"] = [
                        (get_wikidata_id_from_link_name(cand_link), cand_link, prob)
                        for cand_link, prob in m["candidates"]
                    ]

    print("Number of entities:", total)
    # Scale by 100 so the number matches the "%" label, and guard against an
    # empty corpus, which would otherwise raise ZeroDivisionError.
    not_found_pct = 100.0 * not_found_count / total if total else 0.0
    print(
        "Wikidata ID not found for:",
        not_found_count,
        "({:.3f} %)".format(not_found_pct),
    )

    return dev_datasets