in research/a2n/train.py [0:0]
def evaluate():
  """Run evaluation on dev or test data.

  Builds the training graph (passed to the data readers and to
  ``create_model``), reads the dev and/or test triples, builds the model and
  streaming MRR / Hits@{1,3,10} metrics, and then does one of three things:

  * ``FLAGS.analyze``: restores ``FLAGS.model_path/model.ckpt-<global_step>``
    and writes, per example, the query, prediction, target and the top
    attended neighbours to ``analyze_correct.txt`` / ``analyze_incorrect.txt``
    in ``FLAGS.output_dir``.
  * ``FLAGS.test_only``: evaluates that single checkpoint once with
    ``slim.evaluation.evaluate_once``.
  * otherwise: keeps evaluating new checkpoints found in ``FLAGS.model_path``
    with ``slim.evaluation.evaluation_loop``.

  When both ``FLAGS.dev_kg_file`` and ``FLAGS.test_kg_file`` are set, metrics
  are computed on the test data; the dev graph is only handed to the test
  reader as ``val_graph``.

  Raises:
    ValueError: if neither ``FLAGS.dev_kg_file`` nor ``FLAGS.test_kg_file`` is
      set, or if analysis attends to a text neighbour that was not fetched.
  """
  # Fail fast, before the (potentially expensive) training graph is built.
  if not FLAGS.dev_kg_file and not FLAGS.test_kg_file:
    raise ValueError("Evaluation without a dev or test file!")

  # These models use explicit inverse edges; all others set add_reverse_graph
  # instead.
  add_inverse_edge = FLAGS.model in ["source_rel_attention",
                                     "source_path_attention"]
  if FLAGS.clueweb_data:
    train_graph = clueweb_text_graph.CWTextGraph(
        text_kg_file=FLAGS.clueweb_data,
        embeddings_file=FLAGS.clueweb_embeddings,
        sentence_vocab_file=FLAGS.clueweb_sentences,
        skip_new=True,
        kg_file=FLAGS.kg_file,
        add_reverse_graph=not add_inverse_edge,
        add_inverse_edge=add_inverse_edge,
        subsample=FLAGS.subsample_text_rels
    )
  elif FLAGS.text_kg_file:
    train_graph = text_graph.TextGraph(
        text_kg_file=FLAGS.text_kg_file,
        skip_new=True,
        max_text_len=FLAGS.max_text_len,
        max_vocab_size=FLAGS.max_vocab_size,
        min_word_freq=FLAGS.min_word_freq,
        kg_file=FLAGS.kg_file,
        add_reverse_graph=not add_inverse_edge,
        add_inverse_edge=add_inverse_edge,
        max_path_length=FLAGS.max_path_length
    )
  else:
    train_graph = graph.Graph(
        kg_file=FLAGS.kg_file,
        add_reverse_graph=not add_inverse_edge,
        add_inverse_edge=add_inverse_edge,
        max_path_length=FLAGS.max_path_length
    )

  val_graph = None
  if FLAGS.dev_kg_file:
    val_graph, eval_data = read_graph_data(
        kg_file=FLAGS.dev_kg_file,
        add_reverse_graph=not add_inverse_edge,
        add_inverse_edge=add_inverse_edge,
        mode="dev", num_epochs=1, batchsize=FLAGS.test_batchsize,
        max_neighbors=FLAGS.max_neighbors,
        max_negatives=FLAGS.max_negatives, train_graph=train_graph,
        text_kg_file=FLAGS.text_kg_file
    )
  if FLAGS.test_kg_file:
    # Overrides the dev eval_data: test takes precedence when both are given.
    _, eval_data = read_graph_data(
        kg_file=FLAGS.test_kg_file,
        add_reverse_graph=not add_inverse_edge,
        add_inverse_edge=add_inverse_edge,
        mode="test", num_epochs=1, batchsize=FLAGS.test_batchsize,
        max_neighbors=FLAGS.max_neighbors,
        max_negatives=None, train_graph=train_graph,
        text_kg_file=FLAGS.text_kg_file,
        val_graph=val_graph
    )

  iterator = eval_data.dataset.make_initializable_iterator()
  candidate_scores, candidates, labels, model, is_train_ph, inputs = \
      create_model(train_graph, iterator)

  # Streaming metrics: per-batch values averaged with tf.metrics.mean.
  batch_rr = metrics.mrr(candidate_scores, candidates, labels)
  mrr, mrr_update = tf.metrics.mean(batch_rr)
  mrr_summary = tf.summary.scalar("MRR", mrr)
  hits_ks = [1, 3, 10]
  all_hits, all_hits_update, all_hits_summaries = [], [], []
  for k in hits_ks:
    batch_hits = metrics.hits_at_k(candidate_scores, candidates, labels, k=k)
    hits_k, hits_k_update = tf.metrics.mean(batch_hits)
    all_hits.append(hits_k)
    all_hits_update.append(hits_k_update)
    all_hits_summaries.append(tf.summary.scalar("Hits_at_%d" % k, hits_k))
  # Note: tf.group returns an op, so fetching `hits` yields None; the
  # individual Hits@k tensors in `all_hits` carry the values.
  hits = tf.group(*all_hits)
  hits_update = tf.group(*all_hits_update)

  global_step = tf.Variable(0, name="global_step", trainable=False)
  # Local counter of evaluated batches: reset when a session is created
  # (DataInitHook below) and incremented by every eval_op run.
  current_step = tf.Variable(0, name="current_step", trainable=False,
                             collections=[tf.GraphKeys.LOCAL_VARIABLES])
  incr_current_step = tf.assign_add(current_step, 1)
  reset_current_step = tf.assign(current_step, 0)
  # NOTE(review): presumably resolves to `global_step` above by name; confirm.
  slim.get_or_create_global_step(graph=tf.get_default_graph())

  nexamples = eval_data.data_graph.tuple_store.shape[0]
  if eval_data.data_graph.add_reverse_graph:
    # Reverse-graph mode doubles the number of evaluated examples.
    nexamples *= 2
  num_batches = int(math.ceil(nexamples / float(FLAGS.test_batchsize)))
  local_init_op = tf.local_variables_initializer()

  if FLAGS.analyze:
    entity_names = utils.read_entity_name_mapping(FLAGS.entity_names_file)

    def display_entity(entity_id):
      """Maps an entity id to its readable name if one is known."""
      name = train_graph.inverse_entity_vocab[entity_id]
      return entity_names[name] if name in entity_names else name

    session = tf.Session()
    try:
      session.run(tf.global_variables_initializer())
      session.run(local_init_op)
      saver = tf.train.Saver(tf.trainable_variables())
      ckpt_path = FLAGS.model_path + "/model.ckpt-%d" % FLAGS.global_step
      attention_probs = model["attention_encoder"].get_from_collection(
          "attention_probs"
      )
      # The analysis uses the candidates from the input pipeline.
      if FLAGS.clueweb_data:
        s, nbrs_s, text_nbrs_s, _, r, input_candidates, _ = inputs
      elif FLAGS.text_kg_file:
        s, nbrs_s, text_nbrs_s, r, input_candidates, _ = inputs
      else:
        s, nbrs_s, r, input_candidates, _ = inputs
      saver.restore(session, ckpt_path)
      session.run(iterator.initializer)

      analyze_outputs = [candidate_scores, s, nbrs_s, r, input_candidates,
                         labels, attention_probs]
      if FLAGS.text_kg_file:
        analyze_outputs.append(text_nbrs_s)

      num_attention = 5  # Number of top-attended neighbours to report.
      nsteps = 0
      ncorrect = 0
      correct_path = FLAGS.output_dir + "/analyze_correct.txt"
      incorrect_path = FLAGS.output_dir + "/analyze_incorrect.txt"
      with open(correct_path, "w+") as outf_correct, \
          open(incorrect_path, "w+") as outf_incorrect:
        while True:
          try:
            analyze_vals = session.run(analyze_outputs, {is_train_ph: False})
          except tf.errors.OutOfRangeError:
            break  # Eval data exhausted.
          (cscores, se, nbrs, qr, cands, te,
           nbr_attention_probs) = analyze_vals[:7]
          text_nbrs = analyze_vals[7] if FLAGS.text_kg_file else None
          pred_ids = cscores.argmax(1)
          for i in range(se.shape[0]):
            sname = display_entity(se[i])
            rname = train_graph.inverse_relation_vocab[qr[i]]
            pred_target = cands[i, pred_ids[i]]
            pred_name = display_entity(pred_target)
            tname = display_entity(te[i][0])
            if te[i][0] == pred_target:
              outf = outf_correct
              ncorrect += 1
            else:
              outf = outf_incorrect
            outf.write("\n(%d) %s, %s, ? \t Pred: %s \t Target: %s" %
                       (nsteps+i+1, sname, rname, pred_name, tname))
            top_nbrs_index = np.argsort(nbr_attention_probs[i, :])[::-1]
            outf.write("\nTop Nbrs:")
            # Slice instead of indexing so fewer than num_attention slots
            # does not raise IndexError.
            for nbr_index in top_nbrs_index[:num_attention]:
              if nbr_index < FLAGS.max_neighbors:
                # KG neighbour: alternating (relation id, entity id) pairs.
                nbr_id = nbrs[i, nbr_index, :]
                nbr_name = "".join(
                    "(%s, %s)" % (train_graph.inverse_relation_vocab[nbr_id[k]],
                                  display_entity(nbr_id[k + 1]))
                    for k in range(0, nbrs.shape[-1], 2))
              else:
                # Indices past max_neighbors address text neighbours.
                if text_nbrs is None:
                  raise ValueError(
                      "Attention on a text neighbour, but text neighbours "
                      "are only fetched when --text_kg_file is set.")
                text_nbr_ids = text_nbrs[i, nbr_index - FLAGS.max_neighbors, :]
                ent_name = display_entity(text_nbr_ids[0])
                rel_name = train_graph.get_relation_text(text_nbr_ids[1:])
                nbr_name = "(%s, %s)" % (rel_name, ent_name)
              outf.write("\n\t\t %s Prob: %.4f" %
                         (nbr_name, nbr_attention_probs[i, nbr_index]))
          nsteps += se.shape[0]
          tf.logging.info("Current hits@1: %.3f", ncorrect * 1.0 / (nsteps))
    finally:
      session.close()
    return

  class DataInitHook(tf.train.SessionRunHook):
    """Re-initializes the input iterator and batch counter per session."""

    def after_create_session(self, sess, coord):
      del coord  # Unused.
      sess.run(iterator.initializer)
      sess.run(reset_current_step)

  # Log each Hits@k tensor; the grouped `hits` op would always log None.
  logged_tensors = {"mrr": mrr, "step": current_step}
  for k, hits_k in zip(hits_ks, all_hits):
    logged_tensors["hits_at_%d" % k] = hits_k

  # Arguments shared by the single-checkpoint and continuous evaluation modes.
  common_eval_kwargs = dict(
      master=FLAGS.master,
      logdir=FLAGS.output_dir,
      variables_to_restore=tf.trainable_variables() + [global_step],
      initial_op=tf.group(local_init_op, iterator.initializer),
      num_evals=num_batches,
      eval_op=tf.group(mrr_update, hits_update, incr_current_step),
      eval_op_feed_dict={is_train_ph: False},
      final_op=tf.group(mrr, hits),
      final_op_feed_dict={is_train_ph: False},
      summary_op=tf.summary.merge([mrr_summary] + all_hits_summaries),
      hooks=[DataInitHook(),
             tf.train.LoggingTensorHook(logged_tensors, every_n_iter=1)],
  )
  if FLAGS.test_only:
    ckpt_path = FLAGS.model_path + "/model.ckpt-%d" % FLAGS.global_step
    slim.evaluation.evaluate_once(
        checkpoint_path=ckpt_path,
        **common_eval_kwargs
    )
  else:
    slim.evaluation.evaluation_loop(
        checkpoint_dir=FLAGS.model_path,
        max_number_of_evaluations=None,
        eval_interval_secs=60,
        **common_eval_kwargs
    )