extraction/evolve_type_system.py

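"""
Pick a compact set of wikidata type axes ((qid, relation) pairs) that best
disambiguates wikipedia anchor links: each candidate set is scored by its
accuracy gain over a link-frequency baseline, discounted by the mean AUC of
the corresponding type classifiers and penalized per type picked. The search
procedures themselves (greedy, beam search, cross-entropy method, genetic
algorithm) are imported from wikidata_linker_utils.fast_disambiguate.

Example invocation (config and output paths are illustrative):

    python3 extraction/evolve_type_system.py my_config.json picks.json \
        --method cem --samples 1000 --penalty 0.0005
"""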
import argparse
import json

import numpy as np

from os.path import realpath, dirname, join, exists

from evaluate_type_system import fix_and_parse_tags
from wikidata_linker_utils.json import load_config
from wikidata_linker_utils.type_collection import TypeCollection
from wikidata_linker_utils.progressbar import get_progress_bar
from wikidata_linker_utils.wikipedia import induce_wikipedia_prefix
from wikidata_linker_utils.fast_disambiguate import (
    beam_project, cem_project, ga_project
)

SCRIPT_DIR = dirname(realpath(__file__))


def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=str)
    parser.add_argument("out", type=str)
    parser.add_argument("--relative_to", default=None, type=str)
    parser.add_argument("--penalty", default=0.0005, type=float)
    parser.add_argument("--beam_width", default=8, type=int)
    parser.add_argument("--beam_search_subset", default=2000, type=int)
    parser.add_argument("--log", default=None, type=str)
    parser.add_argument("--samples", type=int, default=1000)
    parser.add_argument("--ngen", type=int, default=40)
    parser.add_argument("--method", type=str,
                        choices=["cem", "greedy", "beam", "ga"],
                        default="greedy")
    return parser.parse_args(args=args)


def load_aucs():
    # AUC reports for each (qid, relation) type classifier, written by
    # separate training runs (paths are hardcoded):
    paths = [
        "/home/jonathanraiman/en_field_auc_w10_e10.json",
        "/home/jonathanraiman/en_field_auc_w10_e10-s1234.json",
        "/home/jonathanraiman/en_field_auc_w5_e5.json",
        "/home/jonathanraiman/en_field_auc_w5_e5-s1234.json"
    ]
    aucs = {}
    for path in paths:
        with open(path, "rt") as fin:
            auc_report = json.load(fin)
        for report in auc_report:
            key = (report["qid"], report["relation"])
            if key in aucs:
                aucs[key].append(report["auc"])
            else:
                aucs[key] = [report["auc"]]
    # average each type's AUC across runs:
    for key in aucs.keys():
        aucs[key] = np.mean(aucs[key])
    return aucs


def greedy_disambiguate(tags):
    # baseline: link every mention to its most frequently pointed-to candidate.
    greedy_correct = 0
    total = 0
    for dest, other_dest, times_pointed in tags:
        total += 1
        if len(other_dest) == 1 and dest == other_dest[0]:
            greedy_correct += 1
        elif other_dest[np.argmax(times_pointed)] == dest:
            greedy_correct += 1
    return greedy_correct, total


def fast_disambiguate(tags, all_classifications):
    # disambiguate using type axes: keep the candidates whose boolean type
    # signature matches the true destination's, then fall back to link
    # frequency among the survivors.
    correct = 0
    total = 0
    for dest, other_dest, times_pointed in tags:
        total += 1
        if len(other_dest) == 1 and dest == other_dest[0]:
            correct += 1
        else:
            identities = np.all(
                all_classifications[other_dest, :] == all_classifications[dest, :],
                axis=1
            )
            matches = other_dest[identities]
            matches_counts = times_pointed[identities]
            if len(matches) == 1 and matches[0] == dest:
                correct += 1
            elif matches[np.argmax(matches_counts)] == dest:
                correct += 1
    return correct, total


def get_prefix(config):
    return config.prefix or induce_wikipedia_prefix(config.wiki)


MAX_PICKS = 400.0


def rollout(cached_satisfy, key2row, tags, aucs, ids, sample, penalty,
            greedy_correct):
    # score a binary vector `sample` of switched-on type axes; `aucs` must be
    # an iterable of ((qid, relation), auc) pairs aligned with `sample`.
    mean_auc = 0.0
    sample_sum = sample.sum()
    if sample_sum == 0:
        # no types picked: the score is the greedy baseline accuracy.
        total = len(tags)
        return (greedy_correct / total, greedy_correct / total)
    if sample_sum > MAX_PICKS:
        # too many types picked, abort the rollout.
        return 0.0, 0.0
    all_classifications = np.zeros((len(ids), int(sample_sum)), dtype=bool)
    col = 0
    for picked, (key, auc) in zip(sample, aucs):
        if picked:
            all_classifications[:, col] = cached_satisfy[key2row[key]]
            col += 1
            mean_auc += auc
    mean_auc = mean_auc / sample_sum
    correct, total = fast_disambiguate(tags, all_classifications)
    # here's the benefit of using types:
    improvement = correct - greedy_correct
    # discount the improvement by the reliability (mean AUC) of the chosen
    # types, and penalize the number of types picked:
    objective = (
        (greedy_correct + improvement * mean_auc) / total -
        sample_sum * penalty
    )
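    # Worked example with illustrative numbers: if greedy_correct=80,
    # total=100, correct=90, mean_auc=0.9 and sample_sum=3, then
    # improvement=10 and
    #   objective = (80 + 10 * 0.9) / 100 - 3 * penalty = 0.89 - 3 * penalty,
    # so adding types only pays off when the reliability-discounted accuracy
    # gain exceeds the per-type penalty.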
    return objective, correct / total


def get_cached_satisfy(collection, aucs, ids, mmap=False):
    # precompute, for each type axis, which candidate `ids` satisfy it;
    # cached on disk because this is expensive to recompute.
    path = join(SCRIPT_DIR, "cached_satisfy.npy")
    if not exists(path):
        cached_satisfy = np.zeros((len(aucs), len(ids)), dtype=bool)
        for row, (qid, relation_name) in get_progress_bar("satisfy", item="types")(
                enumerate(sorted(aucs.keys()))):
            cached_satisfy[row, :] = collection.satisfy(
                [relation_name], [collection.name2index[qid]])[ids]
            # keep memory bounded while iterating over many types:
            collection._satisfy_cache.clear()
        np.save(path, cached_satisfy)
        if mmap:
            del cached_satisfy
            cached_satisfy = np.load(path, mmap_mode="r")
    else:
        cached_satisfy = np.load(path, mmap_mode="r" if mmap else None)
    return cached_satisfy


def main():
    args = parse_args()
    config = load_config(
        args.config,
        ["wiki", "language_path", "wikidata", "redirections", "classification"],
        defaults={
            "num_names_to_load": 0,
            "prefix": None,
            "sample_size": 100,
            "wiki": None,
            "fix_links": False,
            "min_count": 0,
            "min_percent": 0.0
        },
        relative_to=args.relative_to
    )
    if config.wiki is None:
        raise ValueError("must provide path to 'wiki' in config.")
    prefix = get_prefix(config)
    collection = TypeCollection(
        config.wikidata,
        num_names_to_load=config.num_names_to_load,
        prefix=prefix,
        verbose=True
    )
    collection.load_blacklist(join(SCRIPT_DIR, "blacklist.json"))
    test_tags = fix_and_parse_tags(config, collection, config.sample_size)
    aucs = load_aucs()
    # candidate entity ids that appear in ambiguous mentions:
    ids = sorted(set(
        idx
        for doc_tags in test_tags
        for _, tag in doc_tags
        if tag is not None
        for idx in tag[2]
        if len(tag[2]) > 1
    ))
    id2pos = {idx: k for k, idx in enumerate(ids)}
    # use reduced identity system:
    remapped_tags = []
    for doc_tags in test_tags:
        for text, tag in doc_tags:
            if tag is not None:
                remapped_tags.append((
                    id2pos[tag[1]] if len(tag[2]) > 1 else tag[1],
                    np.array([id2pos[idx] for idx in tag[2]])
                    if len(tag[2]) > 1 else tag[2],
                    tag[3]
                ))
    test_tags = remapped_tags
    # only keep types whose classifier does better than chance:
    aucs = {key: value for key, value in aucs.items() if value > 0.5}
    print("%d relations to pick from with %d ids." % (len(aucs), len(ids)),
          flush=True)
    cached_satisfy = get_cached_satisfy(
        collection, aucs, ids, mmap=args.method == "greedy")
    # release the type collection's memory before searching:
    del collection
    key2row = {key: k for k, key in enumerate(sorted(aucs.keys()))}
    if args.method == "greedy":
        # greedy search is beam search with a beam width of 1:
        picks, _ = beam_project(
            cached_satisfy, key2row, remapped_tags, aucs, ids,
            beam_width=1, penalty=args.penalty, log=args.log
        )
    elif args.method == "beam":
        picks, _ = beam_project(
            cached_satisfy, key2row, remapped_tags, aucs, ids,
            beam_width=args.beam_width, penalty=args.penalty, log=args.log
        )
    elif args.method == "cem":
        picks, _ = cem_project(
            cached_satisfy, key2row, remapped_tags, aucs, ids,
            n_samples=args.samples, penalty=args.penalty, log=args.log
        )
    elif args.method == "ga":
        picks, _ = ga_project(
            cached_satisfy, key2row, remapped_tags, aucs, ids,
            ngen=args.ngen, n_samples=args.samples,
            penalty=args.penalty, log=args.log
        )
    else:
        raise ValueError("unknown method %r." % (args.method,))
    with open(args.out, "wt") as fout:
        json.dump(picks, fout)


if __name__ == "__main__":
    main()
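# The saved picks can be inspected afterwards; assuming `beam_project` and
# friends (defined in wikidata_linker_utils.fast_disambiguate, not shown here)
# return a JSON-serializable list of (qid, relation) picks:
#
#     import json
#     with open("picks.json", "rt") as fin:
#         picks = json.load(fin)
#     print("%d type axes selected" % (len(picks),))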