def get_list_of_mentions()

in blink/candidate_retrieval/utils.py [0:0]


def _join_query(left, mention, right):
    """Join left context, mention and right context with single spaces, trimmed."""
    return "{} {} {}".format(left, mention, right).strip()


def _truncated_query(left, mention, right, window):
    """Build a query keeping at most `window` space-separated tokens on each side."""
    truncated_left = " ".join(left.split(" ")[-window:])
    truncated_right = " ".join(right.split(" ")[:window])
    return _join_query(truncated_left, mention, truncated_right)


def get_list_of_mentions(dev_datasets):
    """Collect the mentions of several datasets, enriched with query fields.

    Every mention dict is updated in place (including mentions that are
    later rejected) with:

    * ``mention_orig`` / ``mention``: the raw surface form and the result of
      passing it through ``solr_escape``;
    * ``left_context_orig`` / ``right_context_orig``: the raw contexts, or
      ``""`` when the context is the ``"EMPTYCTXT"`` placeholder;
    * ``query_context`` / ``query_context_orig``: escaped and raw
      "left mention right" strings;
    * ``query_truncated_25_context`` / ``query_truncated_10_context``:
      escaped queries limited to 25 / 10 space-separated tokens per side;
    * ``dataset_name`` / ``doc_name``;
    * ``left_query_sent_context_orig`` / ``right_query_sent_context_orig``:
      tokens of the mention's sentence before ``start`` / from ``end`` on;
    * ``sent_context`` / ``sent_context_orig``: ``(previous, current, next)``
      sentence triples (escaped / raw), with ``None`` for a missing neighbour.

    Args:
        dev_datasets: iterable of ``(ds_name, dataset)`` pairs, where
            ``dataset`` maps a document name to a list of mention dicts.
            Each mention must provide ``"gold_wikidata_id"``, ``"context"``
            (a ``(left, right)`` pair), ``"mention"`` and ``"conll_m"``
            (with ``"sent_id"``, ``"start"``, ``"end"``); the first mention
            of a document must provide ``"conll_doc"["sentences"]``, a
            sequence of token sequences.

    Returns:
        The list of mention dicts whose ``"gold_wikidata_id"`` is not
        ``None``, in iteration order.

    Side effects:
        Prints the dataset name and the per-dataset invalid/valid counts.
    """
    mentions = []

    for ds_name, dataset in dev_datasets:
        invalid = 0
        valid = 0

        print("Processing dataset:", ds_name)
        for doc_name, content in dataset.items():
            if not content:
                # No mentions in this document: nothing to enrich, and
                # content[0] below would raise IndexError.
                continue

            # The document's sentences are read from its first mention.
            sentences = content[0]["conll_doc"]["sentences"]
            for m in content:
                gold_wikidata_id = m["gold_wikidata_id"]
                left_context, right_context = m["context"]

                m["mention_orig"] = m["mention"]
                m["mention"] = solr_escape(m["mention"])

                # Always bind the *_orig values: leaving them unset for the
                # "EMPTYCTXT" placeholder either raised UnboundLocalError or
                # silently reused the previous mention's context.
                if left_context != "EMPTYCTXT":
                    left_context_orig = left_context
                    left_context = solr_escape(left_context)
                else:
                    left_context_orig = ""
                    left_context = ""

                if right_context != "EMPTYCTXT":
                    right_context_orig = right_context
                    right_context = solr_escape(right_context)
                else:
                    right_context_orig = ""
                    right_context = ""

                m["left_context_orig"] = left_context_orig
                m["right_context_orig"] = right_context_orig

                m["query_context"] = _join_query(
                    left_context, m["mention"], right_context
                )
                m["query_context_orig"] = _join_query(
                    left_context_orig, m["mention_orig"], right_context_orig
                )

                m["query_truncated_25_context"] = _truncated_query(
                    left_context, m["mention"], right_context, 25
                )
                m["query_truncated_10_context"] = _truncated_query(
                    left_context, m["mention"], right_context, 10
                )

                m["dataset_name"] = ds_name
                m["doc_name"] = doc_name

                conll_mention = m["conll_m"]
                sent_id = conll_mention["sent_id"]
                start = conll_mention["start"]
                end = conll_mention["end"]
                prev_sent_id = sent_id - 1
                next_sent_id = sent_id + 1

                sent_tokens = sentences[sent_id]
                sent_orig = " ".join(sent_tokens).strip()
                m["left_query_sent_context_orig"] = " ".join(sent_tokens[:start])
                m["right_query_sent_context_orig"] = " ".join(sent_tokens[end:])
                sent = solr_escape(sent_orig)

                # sent_id indexes `sentences` from 0, so index 0 is a valid
                # previous sentence (the former `> 0` test dropped it).
                if prev_sent_id >= 0:
                    prev_sent_orig = " ".join(sentences[prev_sent_id])
                    prev_sent = solr_escape(prev_sent_orig)
                else:
                    prev_sent_orig = None
                    prev_sent = None

                if next_sent_id < len(sentences):
                    next_sent_orig = " ".join(sentences[next_sent_id])
                    next_sent = solr_escape(next_sent_orig)
                else:
                    next_sent_orig = None
                    next_sent = None

                m["sent_context"] = (prev_sent, sent, next_sent)
                m["sent_context_orig"] = (prev_sent_orig, sent_orig, next_sent_orig)

                # Mentions without a gold id are still enriched above but
                # are not returned.
                if gold_wikidata_id is None:
                    invalid += 1
                    continue

                mentions.append(m)
                valid += 1

        print("Invalid: ", invalid)
        print("Valid: ", valid)

    return mentions