extraction/project_graph.py
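"""
Project the entities of a wikidata graph onto a set of classes.

Each classifier script passed on the command line is expected to expose a
classify(collection) function returning a mapping from class name to a
boolean membership array over all entities. Results are summarized in an
interactive REPL and can be saved with --export_classification.

Example invocation (paths and script names are illustrative only):

    python3 extraction/project_graph.py /path/to/wikidata my_classifier.py \
        --export_classification /path/to/output
"""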
import argparse
import sys
import json
import time
import traceback
from os import makedirs
from os.path import join, dirname, realpath
from wikidata_linker_utils.repl import (
    enter_or_quit, reload_module,
    ALLOWED_RUNTIME_ERRORS,
    ALLOWED_IMPORT_ERRORS
)
from wikidata_linker_utils.logic import logical_ors
from wikidata_linker_utils.type_collection import TypeCollection
import wikidata_linker_utils.wikidata_properties as wprop
import numpy as np
SCRIPT_DIR = dirname(realpath(__file__))
def parse_args():
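    """Build and evaluate the command line argument parser for this script."""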
    parser = argparse.ArgumentParser()
    parser.add_argument('wikidata', type=str,
                        help="Location of wikidata properties.")
    parser.add_argument('classifiers', type=str, nargs="+",
                        help="Filename(s) for Python script that classifies entities.")
    parser.add_argument('--export_classification', type=str, nargs="+",
                        default=None,
                        help="Location to save the result of the entity classification.")
    parser.add_argument('--num_names_to_load', type=int, default=20000000,
                        help="Number of names to load from disk to accelerate reporting.")
    parser.add_argument('--language_path', type=str, default=None,
                        help="Location of a language-wikipedia specific information set to "
                             "provide language/wikipedia specific metrics.")
    parser.add_argument('--interactive', action="store_true", default=True,
                        help="Operate in a REPL. Reload scripts on errors or on user prompt.")
    parser.add_argument('--nointeractive', action="store_false",
                        dest="interactive", help="Run classification without REPL.")
    parser.add_argument('--use-cache', action="store_true",
                        dest="use_cache", help="Store satisfies in the cache.")
    parser.add_argument('--nouse-cache', action="store_false",
                        dest="use_cache", help="Do not store satisfies in the cache.")
    return parser.parse_args()


def get_other_class(classification):
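    """Return a boolean mask of entities that fall in none of the given classes
    (None when the classification is empty)."""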
    if len(classification) == 0:
        return None
    return np.logical_not(logical_ors(
        list(classification.values())
    ))


def export_classification(classification, path):
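    """Save the class names (classes.txt) and per-entity class ids
    (classification.npy) for a classification under `path`."""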
    classes = sorted(list(classification.keys()))
    if len(classes) == 0:
        return
    makedirs(path, exist_ok=True)
    num_items = classification[classes[0]].shape[0]
    classid = np.zeros(num_items, dtype=np.int32)
    selected = np.zeros(num_items, dtype=bool)  # np.bool is removed in recent numpy
    for index, classname in enumerate(classes):
        truth_table = classification[classname]
        selected = selected | truth_table
        classid = np.maximum(classid, truth_table.astype(np.int32) * index)
    other = np.logical_not(selected)
    if other.sum() > 0:
        # entities matched by no class receive an extra "other" class id:
        classes_with_other = classes + ["other"]
        classid = np.maximum(classid, other.astype(np.int32) * len(classes))
    else:
        classes_with_other = classes
    with open(join(path, "classes.txt"), "wt") as fout:
        for classname in classes_with_other:
            fout.write(classname + "\n")
    np.save(join(path, "classification.npy"), classid)


def main():
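    """Load the type collection, run every classifier, report class membership,
    and optionally export the classification arrays."""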
    args = parse_args()
    should_export = args.export_classification is not None
    if should_export and len(args.export_classification) != len(args.classifiers):
        raise ValueError("Must have as many export filenames as classifiers.")
    collection = TypeCollection(
        args.wikidata,
        num_names_to_load=args.num_names_to_load,
        language_path=args.language_path,
        cache=args.use_cache
    )
    if args.interactive:
        alert_failure = enter_or_quit
    else:
        alert_failure = lambda: sys.exit(1)
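    # REPL loop: reload the blacklist and classifier scripts, run them,
    # report results, and (optionally) export, until the user quits.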
    while True:
        try:
            collection.load_blacklist(join(SCRIPT_DIR, "blacklist.json"))
        except (ValueError,) as e:
            print("Issue reading blacklist, please fix.")
            print(str(e))
            alert_failure()
            continue
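        # Reload and run each classifier script, retrying after fixable
        # errors when in interactive mode.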
        classifications = []
        for class_idx, classifier_fname in enumerate(args.classifiers):
            while True:
                try:
                    classifier = reload_module(classifier_fname)
                except ALLOWED_IMPORT_ERRORS as e:
                    print("issue reading %r, please fix." % (classifier_fname,))
                    print(str(e))
                    traceback.print_exc(file=sys.stdout)
                    alert_failure()
                    continue
                try:
                    t0 = time.time()
                    classification = classifier.classify(collection)
                    classifications.append(classification)
                    if class_idx == len(args.classifiers) - 1:
                        collection.reset_cache()
                    t1 = time.time()
                    print("classification took %.3fs" % (t1 - t0,))
                except ALLOWED_RUNTIME_ERRORS as e:
                    print("issue running %r, please fix." % (classifier_fname,))
                    print(str(e))
                    traceback.print_exc(file=sys.stdout)
                    alert_failure()
                    continue
                break
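        # Report membership counts and sample members for each class, then
        # export the classifications if requested.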
        try:
            # show cardinality for each truth table:
            if args.interactive:
                mega_other_class = None
                for classification in classifications:
                    for classname in sorted(classification.keys()):
                        print("%r: %d members" % (classname, int(classification[classname].sum())))
                    print("")
                    summary = {}
                    for classname, truth_table in classification.items():
                        (members,) = np.where(truth_table)
                        summary[classname] = [collection.get_name(int(member)) for member in members[:20]]
                    print(json.dumps(summary, indent=4))
                    other_class = get_other_class(classification)
                    if other_class.sum() > 0:
                        # there are missing items:
                        to_report = (
                            classifier.class_report if hasattr(classifier, "class_report") else
                            [wprop.SUBCLASS_OF, wprop.INSTANCE_OF, wprop.OCCUPATION, wprop.CATEGORY_LINK]
                        )
                        collection.class_report(to_report, other_class, name="Other")
                    if mega_other_class is None:
                        mega_other_class = other_class
                    else:
                        mega_other_class = np.logical_and(mega_other_class, other_class)
                if len(classifications) > 1:
                    if mega_other_class.sum() > 0:
                        # there are missing items:
                        to_report = [wprop.SUBCLASS_OF, wprop.INSTANCE_OF, wprop.OCCUPATION, wprop.CATEGORY_LINK]
                        collection.class_report(to_report, mega_other_class, name="Other-combined")
            if should_export:
                assert len(classifications) == len(args.export_classification), (
                    "classification outputs missing for export."
                )
                for classification, savename in zip(classifications, args.export_classification):
                    export_classification(classification, savename)
        except KeyboardInterrupt:
            pass
        if args.interactive:
            enter_or_quit()
        else:
            break


if __name__ == "__main__":
    main()