in blink/candidate_retrieval/utils.py [0:0]
def get_list_of_mentions(dev_datasets):
    """Collect mentions with a gold id and attach query/context fields to them.

    Every mention dict is enriched IN PLACE (including mentions that end up
    being rejected), then mentions whose ``gold_wikidata_id`` is not ``None``
    are appended to the returned list.

    Args:
        dev_datasets: iterable of ``(ds_name, dataset)`` pairs, where
            ``dataset`` maps a document name to a list of mention dicts.
            Each mention is read for the keys ``"gold_wikidata_id"``,
            ``"mention"``, ``"context"`` (a ``(left, right)`` pair in which
            the string ``"EMPTYCTXT"`` marks a missing side) and
            ``"conll_m"`` (with ``"sent_id"``, ``"start"``, ``"end"``).
            The sentence list is taken from the first mention of a document
            (``content[0]["conll_doc"]["sentences"]``) and is indexed as a
            list of token lists.

    Returns:
        list of the (mutated) mention dicts that have a gold Wikidata id.

    Fields written on every mention:
        ``mention_orig``, ``mention`` (escaped), ``left_context_orig``,
        ``right_context_orig``, ``query_context``, ``query_context_orig``,
        ``query_truncated_25_context``, ``query_truncated_10_context``,
        ``dataset_name``, ``doc_name``, ``left_query_sent_context_orig``,
        ``right_query_sent_context_orig``, ``sent_context`` and
        ``sent_context_orig`` (both ``(prev, current, next)`` triples, with
        ``None`` for a neighbour that does not exist).

    Escaping is delegated to ``solr_escape``, which is defined elsewhere in
    this module.
    """
    mentions = []

    for ds_name, dataset in dev_datasets:
        invalid = 0
        valid = 0
        print("Processing dataset:", ds_name)

        for doc_name, content in dataset.items():
            if not content:
                # A document without mentions has no content[0] to read the
                # sentences from and contributes nothing to the result.
                continue

            sentences = content[0]["conll_doc"]["sentences"]

            for m in content:
                gold_wikidata_id = m["gold_wikidata_id"]
                left_context, right_context = m["context"]

                m["mention_orig"] = m["mention"]
                m["mention"] = solr_escape(m["mention"])

                # "EMPTYCTXT" is the sentinel for a missing context side.
                # Both the escaped and the original variants must be reset
                # to "" here; otherwise the *_orig names are either unbound
                # (first mention) or still hold the previous mention's text.
                if left_context != "EMPTYCTXT":
                    left_context_orig = left_context
                    left_context = solr_escape(left_context)
                else:
                    left_context_orig = ""
                    left_context = ""

                if right_context != "EMPTYCTXT":
                    right_context_orig = right_context
                    right_context = solr_escape(right_context)
                else:
                    right_context_orig = ""
                    right_context = ""

                m["left_context_orig"] = left_context_orig
                m["right_context_orig"] = right_context_orig

                m["query_context"] = "{} {} {}".format(
                    left_context, m["mention"], right_context
                ).strip()
                m["query_context_orig"] = "{} {} {}".format(
                    left_context_orig, m["mention_orig"], right_context_orig
                ).strip()

                # Windowed queries: keep the last N space-separated tokens on
                # the left and the first N on the right of the mention.
                left_tokens = left_context.split(" ")
                right_tokens = right_context.split(" ")
                for window in (25, 10):
                    m["query_truncated_{}_context".format(window)] = (
                        "{} {} {}".format(
                            " ".join(left_tokens[-window:]),
                            m["mention"],
                            " ".join(right_tokens[:window]),
                        ).strip()
                    )

                m["dataset_name"] = ds_name
                m["doc_name"] = doc_name

                conll_mention = m["conll_m"]
                sent_id = conll_mention["sent_id"]
                start = conll_mention["start"]
                end = conll_mention["end"]
                prev_sent_id = sent_id - 1
                next_sent_id = sent_id + 1

                current_tokens = sentences[sent_id]
                sent_orig = " ".join(current_tokens).strip()
                m["left_query_sent_context_orig"] = " ".join(
                    current_tokens[:start]
                )
                m["right_query_sent_context_orig"] = " ".join(
                    current_tokens[end:]
                )
                sent = solr_escape(sent_orig)

                # Sentence 0 is a valid predecessor (for sent_id == 1), so
                # the bound is >= 0. A negative index must be rejected
                # explicitly because it would wrap to the end of the list.
                if prev_sent_id >= 0:
                    prev_sent_orig = " ".join(sentences[prev_sent_id])
                    prev_sent = solr_escape(prev_sent_orig)
                else:
                    prev_sent_orig = None
                    prev_sent = None

                if next_sent_id < len(sentences):
                    next_sent_orig = " ".join(sentences[next_sent_id])
                    next_sent = solr_escape(next_sent_orig)
                else:
                    next_sent_orig = None
                    next_sent = None

                m["sent_context"] = (prev_sent, sent, next_sent)
                m["sent_context_orig"] = (
                    prev_sent_orig,
                    sent_orig,
                    next_sent_orig,
                )

                # The check is done last on purpose: mentions without a gold
                # id are still enriched in place, they are just not returned.
                if gold_wikidata_id is None:
                    invalid += 1
                    continue

                mentions.append(m)
                valid += 1

        print("Invalid: ", invalid)
        print("Valid: ", valid)

    return mentions