# extraction/evaluate_type_system.py
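"""
Evaluate how well one or more type classifications ("oracles") disambiguate
Wikipedia anchor links.

For a sample of Wikipedia articles, every anchor is resolved to its candidate
entities, and the script reports how many ambiguous anchors are resolved by
the type oracles alone, by link-count thresholding, and by a greedy
most-linked baseline (see disambiguate below).
"""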
import sys
import argparse
import requests
import marisa_trie
import numpy as np

from os.path import join, dirname, realpath, exists
from collections import Counter

from wikidata_linker_utils.anchor_filtering import clean_up_trie_source, acceptable_anchor
from wikidata_linker_utils.wikipedia import (
    load_wikipedia_docs, induce_wikipedia_prefix, load_redirections, transition_trie_index
)
from wikidata_linker_utils.json import load_config
from wikidata_linker_utils.offset_array import OffsetArray
from wikidata_linker_utils.repl import enter_or_quit
from wikidata_linker_utils.progressbar import get_progress_bar
from wikidata_linker_utils.type_collection import TypeCollection, get_name as web_get_name
SCRIPT_DIR = dirname(realpath(__file__))
PROJECT_DIR = dirname(SCRIPT_DIR)

# Flipped to False after the first failed request so subsequent name lookups
# skip the network entirely.
INTERNET = True


def maybe_web_get_name(s):
    """Fetch an entity's display name from the web, falling back to the raw id offline."""
    global INTERNET
    if INTERNET:
        try:
            return web_get_name(s)
        except requests.exceptions.ConnectionError:
            INTERNET = False
    return s


class OracleClassification(object):
    """Maps entity indices to coarse type classes loaded from disk."""
    def __init__(self, classes, classification, path):
        self.classes = classes
        self.classification = classification
        self.path = path
        # by convention the last class may be a catch-all "other" bucket:
        self.contains_other = self.classes[-1] == "other"

    def classify(self, index):
        return self.classification[index]


def load_oracle_classification(path):
    with open(join(path, "classes.txt"), "rt") as fin:
        classes = fin.read().splitlines()
    classification = np.load(join(path, "classification.npy"))
    return OracleClassification(classes, classification, path)
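
# Illustrative sketch (not part of this script): an oracle directory simply
# pairs a "classes.txt" file (one class name per line, optionally ending in an
# "other" bucket) with a "classification.npy" vector mapping every entity
# index to a row of classes.txt. Something like the following hypothetical
# snippet would produce a loadable oracle:
#
#     with open("my_oracle/classes.txt", "wt") as fout:
#         fout.write("\n".join(["human", "place", "other"]))
#     np.save("my_oracle/classification.npy",
#             np.zeros(num_entities, dtype=np.int32))  # everything "human"
#     oracle = load_oracle_classification("my_oracle")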


def can_disambiguate(oracles, truth, alternatives,
                     times_pointed, count_threshold,
                     ignore_other=False, keep_other=False):
    ambig = np.ones(len(alternatives), dtype=bool)
    for oracle in oracles:
        truth_pred = oracle.classify(truth)
        alt_preds = oracle.classify(alternatives)
        if keep_other and oracle.contains_other:
            if truth_pred == len(oracle.classes) - 1:
                # the true entity falls in the "other" bucket, so this oracle
                # cannot separate it from anything: skip it.
                continue
            else:
                # alternatives predicted as "other" remain in contention:
                ambig = np.logical_and(
                    ambig,
                    np.logical_or(
                        np.equal(alt_preds, truth_pred),
                        np.equal(alt_preds, len(oracle.classes) - 1)
                    )
                )
        elif ignore_other and oracle.contains_other and np.any(alt_preds == len(oracle.classes) - 1):
            continue
        else:
            ambig = np.logical_and(ambig, np.equal(alt_preds, truth_pred))
    # apply type rules to disambiguate:
    alternatives_matching_type = alternatives[ambig]
    alternatives_matching_type_times_pointed = times_pointed[ambig]
    if len(alternatives_matching_type) <= 1:
        return alternatives_matching_type, alternatives_matching_type_times_pointed, False
    # apply rules for count thresholding: if the most pointed-at candidate
    # beats the runner-up by more than count_threshold and is the truth,
    # declare it disambiguated by counts.
    ordered_times_pointed = np.argsort(alternatives_matching_type_times_pointed)[::-1]
    top1count = alternatives_matching_type_times_pointed[ordered_times_pointed[0]]
    top2count = alternatives_matching_type_times_pointed[ordered_times_pointed[1]]
    if top1count > top2count + count_threshold and alternatives_matching_type[ordered_times_pointed[0]] == truth:
        return (
            alternatives_matching_type[ordered_times_pointed[0]:ordered_times_pointed[0] + 1],
            alternatives_matching_type_times_pointed[ordered_times_pointed[0]:ordered_times_pointed[0] + 1],
            True
        )
    return alternatives_matching_type, alternatives_matching_type_times_pointed, False
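
# Worked example (hypothetical oracle and counts): with a single oracle whose
# classes are ["human", "place"], truth = 0 classified as "human", and
# alternatives = np.array([0, 5, 9]) where only entity 5 is also "human", the
# type rule narrows the candidates to [0, 5]; if entity 0 was linked 120 times
# and entity 5 only 3 times, the count rule (count_threshold=0) then returns
# ([0], [120], True).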


def disambiguate(tags, oracles):
    ambiguous = 0
    obvious = 0
    disambiguated_oracle = 0
    disambiguated_with_counts = 0
    disambiguated_greedy = 0
    count_threshold = 0
    ambiguous_tags = []
    obvious_tags = []
    non_obvious_tags = []
    disambiguated_oracle_ignore_other = 0
    disambiguated_oracle_keep_other = 0
    for text, tag in tags:
        if tag is None:
            continue
        anchor, dest, other_dest, times_pointed = tag
        if len(other_dest) == 1:
            # only one candidate: nothing to disambiguate.
            obvious += 1
            obvious_tags.append((anchor, dest, other_dest, times_pointed))
        else:
            ambiguous += 1
            non_obvious_tags.append((anchor, dest, other_dest, times_pointed))
            # greedy baseline: pick the most pointed-at candidate.
            if other_dest[np.argmax(times_pointed)] == dest:
                disambiguated_greedy += 1
            matching_tags, times_pointed_subset, used_counts = can_disambiguate(
                oracles, dest, other_dest, times_pointed, count_threshold
            )
            if len(matching_tags) <= 1:
                if used_counts:
                    disambiguated_with_counts += 1
                else:
                    disambiguated_oracle += 1
            else:
                ambiguous_tags.append(
                    (anchor, dest, matching_tags, times_pointed_subset)
                )
            matching_tags, times_pointed_subset, used_counts = can_disambiguate(
                oracles, dest, other_dest, times_pointed, count_threshold, ignore_other=True
            )
            if len(matching_tags) <= 1:
                disambiguated_oracle_ignore_other += 1
            matching_tags, times_pointed_subset, used_counts = can_disambiguate(
                oracles, dest, other_dest, times_pointed, count_threshold, keep_other=True
            )
            if len(matching_tags) <= 1:
                disambiguated_oracle_keep_other += 1
    report = {
        "ambiguous": ambiguous,
        "obvious": obvious,
        "disambiguated oracle": disambiguated_oracle,
        "disambiguated greedy": disambiguated_greedy,
        "disambiguated oracle + counts": disambiguated_oracle + disambiguated_with_counts,
        "disambiguated oracle + counts + ignore other": disambiguated_oracle_ignore_other,
        "disambiguated oracle + counts + keep other": disambiguated_oracle_keep_other
    }
    return (report, ambiguous_tags)
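
# A single-article report might look like the following (illustrative numbers
# only):
#
#     {"ambiguous": 40, "obvious": 110,
#      "disambiguated oracle": 28, "disambiguated greedy": 33,
#      "disambiguated oracle + counts": 31, ...}
#
# i.e. of the 40 anchors with more than one candidate, the type oracles alone
# resolved 28, while always picking the most-linked candidate resolved 33.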


def disambiguate_batch(test_tags, train_tags, oracles):
    # train_tags is currently unused: the oracles are evaluated as-is.
    total_report = Counter()
    ambiguous_tags = []
    for tags in get_progress_bar("disambiguating", item="articles")(test_tags):
        report, remainder = disambiguate(tags, oracles)
        ambiguous_tags.extend(remainder)
        total_report.update(report)
    return total_report, ambiguous_tags


def obtain_tags(doc,
                wiki_trie,
                anchor_trie,
                trie_index2indices,
                trie_index2indices_counts,
                trie_index2indices_transitions,
                redirections,
                prefix,
                collection,
                first_names,
                min_count,
                min_percent):
    out_doc = []
    for anchor, dest_index in doc.links(wiki_trie, redirections, prefix):
        if dest_index is None:
            out_doc.append((anchor, None))
            continue
        anchor_stripped = anchor.strip()
        keep = False
        if len(anchor_stripped) > 0:
            anchor_stripped = clean_up_trie_source(anchor_stripped)
            if acceptable_anchor(anchor_stripped, anchor_trie, first_names):
                anchor_idx = anchor_trie[anchor_stripped]
                all_options = trie_index2indices[anchor_idx]
                all_counts = trie_index2indices_counts[anchor_idx]
                if len(all_options) > 0:
                    if trie_index2indices_transitions is not None:
                        # remap the link destination through the transition table:
                        dest_index = transition_trie_index(
                            anchor_idx, dest_index,
                            trie_index2indices_transitions,
                            all_options
                        )
                    if dest_index != -1:
                        new_dest_index = dest_index
                        keep = True
        if keep and (min_count > 0 or min_percent > 0):
            # drop links whose destination is too rare for this anchor
            # (.sum() collapses the matching entries to a scalar count):
            dest_count = all_counts[all_options == new_dest_index].sum()
            if dest_count < min_count or (dest_count / all_counts.sum()) < min_percent:
                keep = False
        if keep:
            out_doc.append(
                (
                    anchor,
                    (anchor_stripped, new_dest_index, all_options, all_counts)
                )
            )
        else:
            out_doc.append((anchor, None))
    return out_doc
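
# The returned document is a list aligned with the original links, e.g. a
# hypothetical entry is ("New York", ("new york", 128, all_options,
# all_counts)) for a resolvable anchor, or ("New York", None) when the link
# was filtered out.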


def add_boolean(parser, name, default):
    parser.add_argument("--%s" % (name,), action="store_true", default=default)
    parser.add_argument("--no%s" % (name,), action="store_false", dest=name)


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("config")
    parser.add_argument("--relative_to", type=str, default=None)
    parser.add_argument("--log", type=str, default=None)
    add_boolean(parser, "verbose", True)
    add_boolean(parser, "interactive", True)
    return parser


def parse_args(args=None):
    return get_parser().parse_args(args=args)


def summarize_disambiguation(total_report, file=None):
    if file is None:
        file = sys.stdout
    if total_report.get("ambiguous", 0) > 0:
        for key, value in sorted(total_report.items(), key=lambda x: x[1]):
            if "disambiguated" in key:
                print("%.3f%% disambiguated by %s (%d / %d)" % (
                    100.0 * value / total_report["ambiguous"],
                    key[len("disambiguated"):].strip(),
                    value, total_report["ambiguous"]
                ), file=file)
        print("", file=file)
        for key, value in sorted(total_report.items(), key=lambda x: x[1]):
            if "disambiguated" in key:
                print("%.3f%% disambiguated by %s [including single choice] (%d / %d)" % (
                    100.0 * (
                        (value + total_report["obvious"]) /
                        (total_report["ambiguous"] + total_report["obvious"])
                    ),
                    key[len("disambiguated"):].strip(),
                    value + total_report["obvious"],
                    total_report["ambiguous"] + total_report["obvious"]
                ), file=file)
        print("", file=file)


def summarize_ambiguities(ambiguous_tags,
                          oracles,
                          get_name):
    class_ambiguities = {}
    for anchor, dest, other_dest, times_pointed in ambiguous_tags:
        # name the ambiguity class by the truth's predicted type under each oracle:
        class_ambig_name = " and ".join(
            oracle.classes[oracle.classify(dest)] for oracle in oracles
        )
        if class_ambig_name not in class_ambiguities:
            class_ambiguities[class_ambig_name] = {
                "count": 1,
                "examples": [(anchor, dest, other_dest, times_pointed)]
            }
        else:
            class_ambiguities[class_ambig_name]["count"] += 1
            class_ambiguities[class_ambig_name]["examples"].append(
                (anchor, dest, other_dest, times_pointed)
            )
    print("Ambiguity Report:")
    for classname, ambiguity in sorted(class_ambiguities.items(), key=lambda x: x[0]):
        print(" %s" % (classname,))
        print(" %d ambiguities" % (ambiguity["count"],))
        common_bad_anchors = Counter(anc for anc, _, _, _ in ambiguity["examples"]).most_common(6)
        anchor2example = {
            anc: (dest, other_dest, times_pointed)
            for anc, dest, other_dest, times_pointed in ambiguity["examples"]
        }
        for bad_anchor, count in common_bad_anchors:
            dest, other_dest, times_pointed = anchor2example[bad_anchor]
            # .sum() collapses the (normally single) matching entry to a scalar:
            truth_times_pointed = int(times_pointed[np.equal(other_dest, dest)].sum())
            only_alt = [(el, int(times_pointed[k])) for k, el in enumerate(other_dest) if el != dest]
            only_alt = sorted(only_alt, key=lambda x: x[1], reverse=True)
            print(" %r (%d time%s)" % (bad_anchor, count, 's' if count != 1 else ''))
            print(" Actual: %r" % ((get_name(dest), truth_times_pointed),))
            print(" Others: %r" % ([(get_name(el), c) for (el, c) in only_alt[:5]]))
        print("")
    print("")


def get_prefix(config):
    return config.prefix or induce_wikipedia_prefix(config.wiki)


def fix_and_parse_tags(config, collection, size):
    trie_index2indices = OffsetArray.load(
        join(config.language_path, "trie_index2indices"),
        compress=True
    )
    trie_index2indices_counts = OffsetArray(
        np.load(join(config.language_path, "trie_index2indices_counts.npy")),
        trie_index2indices.offsets
    )
    if exists(join(config.language_path, "trie_index2indices_transition_values.npy")):
        trie_index2indices_transitions = OffsetArray(
            np.load(join(config.language_path, "trie_index2indices_transition_values.npy")),
            np.load(join(config.language_path, "trie_index2indices_transition_offsets.npy")),
        )
    else:
        trie_index2indices_transitions = None
    anchor_trie = marisa_trie.Trie().load(join(config.language_path, "trie.marisa"))
    wiki_trie = marisa_trie.RecordTrie('i').load(
        join(config.wikidata, "wikititle2wikidata.marisa")
    )
    prefix = get_prefix(config)
    redirections = load_redirections(config.redirections)
    docs = load_wikipedia_docs(config.wiki, size)
    # let the user edit the blacklist file and retry until it parses:
    while True:
        try:
            collection.load_blacklist(join(SCRIPT_DIR, "blacklist.json"))
        except ValueError as e:
            print("issue reading blacklist, please fix.")
            print(str(e))
            enter_or_quit()
            continue
        break
    print("Load first_names")
    with open(join(PROJECT_DIR, "data", "first_names.txt"), "rt") as fin:
        first_names = set(fin.read().splitlines())
    all_tags = []
    for doc in get_progress_bar('fixing links', item='article')(docs):
        tags = obtain_tags(
            doc,
            wiki_trie=wiki_trie,
            anchor_trie=anchor_trie,
            trie_index2indices=trie_index2indices,
            trie_index2indices_counts=trie_index2indices_counts,
            trie_index2indices_transitions=trie_index2indices_transitions,
            redirections=redirections,
            prefix=prefix,
            first_names=first_names,
            collection=collection,
            min_count=config.min_count,
            min_percent=config.min_percent)
        if any(x is not None for _, x in tags):
            all_tags.append(tags)
    collection.reset_cache()
    return all_tags
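
# A config for this script might look like the following sketch (illustrative
# paths; the required keys are the ones passed to load_config in main below):
#
#     {
#         "wiki": "enwiki-latest-pages-articles.xml",
#         "language_path": "data/en_trie",
#         "wikidata": "data/wikidata",
#         "redirections": "data/en_redirections.tsv",
#         "classification": ["data/type_classification"],
#         "path": "data/output",
#         "sample_size": 1000
#     }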


def main():
    args = parse_args()
    config = load_config(args.config,
                         ["wiki",
                          "language_path",
                          "wikidata",
                          "redirections",
                          "classification",
                          "path"],
                         defaults={"num_names_to_load": 0,
                                   "prefix": None,
                                   "sample_size": 100,
                                   "wiki": None,
                                   "min_count": 0,
                                   "min_percent": 0.0},
                         relative_to=args.relative_to)
    if config.wiki is None:
        raise ValueError("must provide path to 'wiki' in config.")
    prefix = get_prefix(config)
    print("Load type_collection")
    collection = TypeCollection(
        config.wikidata,
        num_names_to_load=config.num_names_to_load,
        prefix=prefix,
        verbose=True)
    all_tags = fix_and_parse_tags(config, collection, config.sample_size)
    test_tags = all_tags[:config.sample_size]
    train_tags = all_tags[config.sample_size:]
    oracles = [load_oracle_classification(classification)
               for classification in config.classification]

    def get_name(idx):
        if idx < config.num_names_to_load:
            if idx in collection.known_names:
                return collection.known_names[idx] + " (%s)" % (collection.ids[idx],)
            else:
                return collection.ids[idx]
        else:
            return maybe_web_get_name(collection.ids[idx]) + " (%s)" % (collection.ids[idx],)

    while True:
        total_report, ambiguous_tags = disambiguate_batch(
            test_tags, train_tags, oracles)
        summarize_disambiguation(total_report)
        if args.log is not None:
            with open(args.log, "at") as fout:
                summarize_disambiguation(total_report, file=fout)
        if args.verbose:
            try:
                summarize_ambiguities(
                    ambiguous_tags,
                    oracles,
                    get_name
                )
            except KeyboardInterrupt:
                pass
        if args.interactive:
            # pause so the user can tweak oracles/blacklist and re-run the evaluation:
            enter_or_quit()
        else:
            break


if __name__ == "__main__":
    main()