learning/evaluate_learnability.py (251 lines of code) (raw):

""" Obtain a learnability score for each type axis. Trains a binary classifier for each type and gets its AUC. Usage ----- ``` python3 evaluate_learnability.py sample_data.tsv --out report.json --wikidata /path/to/wikidata ``` """ import json import time import argparse from os.path import dirname, realpath, join SCRIPT_DIR = dirname(realpath(__file__)) import numpy as np import tensorflow as tf from sklearn import metrics from collections import Counter from wikidata_linker_utils.type_collection import TypeCollection, offset_values_mask import wikidata_linker_utils.wikidata_properties as wprop from wikidata_linker_utils.progressbar import get_progress_bar from generator import prefetch_generator def learnability(collection, lines, mask, truth_tables, qids, id2pos, epochs=5, batch_size=128, max_dataset_size=-1, max_vocab_size=10000, hidden_sizes=None, lr=0.001, window_size=5, input_size=5, keep_prob=0.5, verbose=True): if hidden_sizes is None: hidden_sizes = [] tf.reset_default_graph() dset = list(get_windows(lines, mask, window_size, truth_tables, lambda x: id2pos[x])) if max_dataset_size > 0: dset = dset[:max_dataset_size] pos_num = np.zeros(len(qids)) for _, labels in dset: pos_num += labels neg_num = np.ones(len(qids)) * len(dset) - pos_num pos_weight = (pos_num / (pos_num + neg_num))[None, :] vocab = ["<UNK>"] + [w for w, _ in Counter(lines[:, 0]).most_common(max_vocab_size)] inv_vocab = {w: k for k, w in enumerate(vocab)} with tf.device("gpu"): W = tf.get_variable( "W", shape=[len(vocab), input_size], dtype=tf.float32, initializer=tf.random_normal_initializer() ) indices = tf.placeholder(tf.int32, [None, window_size*2], name="indices") labels = tf.placeholder(tf.bool, [None, len(qids)], name="label") keep_prob_pholder = tf.placeholder_with_default(keep_prob, []) lookup = tf.reshape(tf.nn.embedding_lookup( W, indices ), [tf.shape(indices)[0], input_size * window_size*2]) lookup = tf.nn.dropout(lookup, keep_prob_pholder) hidden = lookup for layer_idx, hidden_size in enumerate(hidden_sizes): hidden = tf.contrib.layers.fully_connected( hidden, num_outputs=hidden_size, scope="FC%d" % (layer_idx,) ) out = tf.contrib.layers.fully_connected( hidden, num_outputs=len(qids), activation_fn=None) cost = tf.nn.sigmoid_cross_entropy_with_logits(logits=out, labels=tf.cast(labels, tf.float32)) cost = tf.where(tf.is_finite(cost), cost, tf.zeros_like(cost)) cost_mean = tf.reduce_mean( (tf.cast(labels, tf.float32) * 1.0 / (pos_weight)) * cost + (tf.cast(tf.logical_not(labels), tf.float32) * 1.0 / (1.0 - pos_weight)) * cost ) cost_sum = tf.reduce_sum(cost) size = tf.shape(indices)[0] noop = tf.no_op() correct = tf.reduce_sum(tf.cast(tf.equal(tf.greater_equal(out, 0), labels), tf.int32), 0) out_activated = tf.sigmoid(out) train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(cost_mean) session = tf.InteractiveSession() session.run(tf.global_variables_initializer()) def accuracy(dataset, batch_size, train): epoch_correct = np.zeros(len(qids)) epoch_nll = 0.0 epoch_total = np.zeros(len(qids)) op = train_op if train else noop all_labels = [] all_preds = [] for i in get_progress_bar("train" if train else "dev", item="batches")(range(0, len(dataset), batch_size)): batch_labels = [label for _, label in dataset[i:i+batch_size]] csum, corr, num_examples, preds, _ = session.run([cost_sum, correct, size, out_activated, op], feed_dict={ indices: [[inv_vocab.get(w, 0) for w in window] for window, _ in dataset[i:i+batch_size]], labels: batch_labels, keep_prob_pholder: keep_prob if train else 1.0 }) epoch_correct += corr epoch_nll += csum epoch_total += num_examples all_labels.extend(batch_labels) all_preds.append(preds) return (epoch_nll, epoch_correct, epoch_total, np.vstack(all_preds), np.vstack(all_labels)) dataset_indices = np.arange(len(dset)) train_indices = dataset_indices[:int(0.8 * len(dset))] dev_indices = dataset_indices[int(0.8 * len(dset)):] train_dataset = [dset[idx] for idx in train_indices] dev_dataset = [dset[idx] for idx in dev_indices] learnability = [] for epoch in range(epochs): t0 = time.time() train_epoch_nll, train_epoch_correct, train_epoch_total, _, _ = accuracy(train_dataset, batch_size, train=True) t1 = time.time() if verbose: print("epoch %d train: %.3f%% in %.3fs" % ( epoch, 100.0 * train_epoch_correct.sum() / train_epoch_total.sum(), t1 - t0),) t0 = time.time() dev_epoch_nll, dev_epoch_correct, dev_epoch_total, pred, y = accuracy(dev_dataset, batch_size, train=False) t1 = time.time() learnability = [] for qidx in range(len(qids)): try: fpr, tpr, thresholds = metrics.roc_curve(y[:,qidx], pred[:,qidx], pos_label=1) auc = metrics.auc(fpr, tpr) if not np.isnan(auc): average_precision_score = metrics.average_precision_score(y[:,qidx], pred[:,qidx]) learnability.append((qids[qidx], auc, average_precision_score, 100.0 * dev_epoch_correct[qidx] / dev_epoch_total[qidx], int(pos_num[qidx]), int(neg_num[qidx]))) except ValueError: continue if verbose: learnability = sorted(learnability, key=lambda x: x[1], reverse=True) print("epoch %d dev: %.3fs" % (epoch, t1-t0)) for qid, auc, average_precision_score, acc, pos, neg in learnability: print(" %r AUC: %.3f, APS: %.3f, %.3f%% positive: %d, negative: %d" % ( collection.ids[qid], auc, average_precision_score, acc, pos, neg)) print("") return learnability def generate_training_data(collection, path): with open(path, "rt") as fin: lines = [row.split("\t")[:2] for row in fin.read().splitlines()] lines_arr = np.zeros((len(lines), 2), dtype=np.object) mask = np.zeros(len(lines), dtype=np.bool) for i, l in enumerate(lines): lines_arr[i, 0] = l[0] if len(l) > 1: lines_arr[i, 1] = collection.name2index[l[1]] mask[i] = True return lines_arr, mask def get_proposal_sets(collection, article_ids, seed): np.random.seed(seed) relation = collection.relation(wprop.CATEGORY_LINK) relation_mask = offset_values_mask(relation.values, relation.offsets, article_ids) counts = np.bincount(relation.values[relation_mask]) is_fp = collection.relation(wprop.FIXED_POINTS).edges() > 0 is_fp = is_fp[:counts.shape[0]] counts = counts * is_fp topfields_fp = np.argsort(counts)[::-1][:(counts > 0).sum()] relation = collection.relation(wprop.INSTANCE_OF) relation_mask = offset_values_mask(relation.values, relation.offsets, article_ids) counts = np.bincount(relation.values[relation_mask]) topfields_instance_of = np.argsort(counts)[::-1][:(counts > 0).sum()] np.random.shuffle(topfields_instance_of) np.random.shuffle(topfields_fp) return [(topfields_instance_of, wprop.INSTANCE_OF), (topfields_fp, wprop.CATEGORY_LINK)] def build_truth_tables(collection, lines, qids, relation_name): truth_tables = [] all_ids = list(sorted(set(lines[:, 1]))) id2pos = {idx: pos for pos, idx in enumerate(all_ids)} for qid in qids: truth_tables.append(collection.satisfy([relation_name], [qid])[all_ids]) collection.reset_cache() truth_tables = np.stack(truth_tables, axis=1) qid_sums = truth_tables.sum(axis=0) kept_qids = [] kept_dims = [] for i, (qid, qid_sum) in enumerate(zip(qids, qid_sums)): if qid_sum != 0 and qid_sum != truth_tables.shape[0]: kept_qids.append(qid) kept_dims.append(i) truth_tables = truth_tables[:, kept_dims] return truth_tables, kept_qids, id2pos def get_windows(lines, mask, window, truth_table, id_mapper): for i in np.where(mask)[0]: if i >= window and i < len(lines) - window: yield (lines[max(0, i - window):i + window, 0], truth_table[id_mapper(lines[i, 1])]) def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--dataset", type=str, required=True) parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--max_epochs", type=int, default=2) parser.add_argument("--max_vocab_size", type=int, default=10000) parser.add_argument("--simultaneous_fields", type=int, default=512) parser.add_argument("--window_size", type=int, default=5) parser.add_argument("--input_size", type=int, default=5) parser.add_argument("--wikidata", type=str, required=True) parser.add_argument("--out", type=str, required=True) return parser.parse_args() def generate_truth_tables(collection, lines_arr, proposal_sets, simultaneous_fields): for topfields, relation_name in proposal_sets: for i in range(0, len(topfields), simultaneous_fields): truth_tables, qids, id2pos = build_truth_tables( collection, lines_arr, qids=topfields[i:i+simultaneous_fields], relation_name=relation_name) yield (topfields[i:i+simultaneous_fields], relation_name, truth_tables, qids, id2pos) def main(): args = parse_args() collection = TypeCollection(args.wikidata, num_names_to_load=0) collection.load_blacklist(join(dirname(SCRIPT_DIR), "extraction", "blacklist.json")) lines_arr, mask = generate_training_data(collection, args.dataset) article_ids = np.array(list(set(lines_arr[:, 1])), dtype=np.int32) proposal_sets = get_proposal_sets(collection, article_ids, args.seed) report = [] total = sum(len(topfields) for topfields, _ in proposal_sets) seen = 0 t0 = time.time() data_source = generate_truth_tables(collection, lines_arr, proposal_sets, args.simultaneous_fields) for topfields, relation_name, truth_tables, qids, id2pos in prefetch_generator(data_source): # for each of these properties and given relation # construct the truth table for each item and discover # their 'learnability': seen += len(topfields) field_auc_scores = learnability( collection, lines_arr, mask, qids=qids, truth_tables=truth_tables, id2pos=id2pos, batch_size=args.batch_size, epochs=args.max_epochs, input_size=args.input_size, window_size=args.window_size, max_vocab_size=args.max_vocab_size, verbose=True) for qid, auc, average_precision_score, correct, pos, neg in field_auc_scores: report.append( { "qid": collection.ids[qid], "auc": auc, "average_precision_score": average_precision_score, "correct": correct, "relation": relation_name, "positive": pos, "negative": neg } ) with open(args.out, "wt") as fout: json.dump(report, fout) t1 = time.time() speed = seen / (t1 - t0) print("AUC obtained for %d / %d items (%.3f items/s)" % (seen, total, speed), flush=True) if __name__ == "__main__": main()