in nmt/train.py [0:0]
def run_external_eval(infer_model,
                      infer_sess,
                      model_dir,
                      hparams,
                      summary_writer,
                      save_best_dev=True,
                      use_test_set=True,
                      avg_ckpts=False,
                      dev_infer_iterator_feed_dict=None,
                      test_infer_iterator_feed_dict=None):
"""Compute external evaluation for both dev / test.
Computes development and testing external evaluation (e.g. bleu, rouge) for
given model.
Args:
infer_model: Inference model for which to compute perplexities.
infer_sess: Inference TensorFlow session.
model_dir: Directory from which to load inference model from.
hparams: Model hyper-parameters.
summary_writer: Summary writer for logging metrics to TensorBoard.
use_test_set: Computes testing external evaluation if true; does not
otherwise. Note that the development external evaluation is always
computed regardless of value of this parameter.
dev_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session.
Can be used to pass in additional inputs necessary for running the
development external evaluation.
test_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session.
Can be used to pass in additional inputs necessary for running the
testing external evaluation.
Returns:
Triple containing development scores, testing scores and the TensorFlow
Variable for the global step number, in this order.
"""
  if dev_infer_iterator_feed_dict is None:
    dev_infer_iterator_feed_dict = {}
  if test_infer_iterator_feed_dict is None:
    test_infer_iterator_feed_dict = {}
  with infer_model.graph.as_default():
    # Restore the latest checkpoint (or initialize a fresh model) into the
    # inference session and read off the current global step.
    loaded_infer_model, global_step = model_helper.create_or_load_model(
        infer_model.model, model_dir, infer_sess, "infer")

  # Dev evaluation: feed the raw dev source sentences and the inference batch
  # size through the iterator placeholders, then score the decoded output
  # against the dev target file.
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  dev_infer_iterator_feed_dict[
      infer_model.src_placeholder] = inference.load_data(dev_src_file)
  dev_infer_iterator_feed_dict[
      infer_model.batch_size_placeholder] = hparams.infer_batch_size
  dev_scores = _external_eval(
      loaded_infer_model,
      global_step,
      infer_sess,
      hparams,
      infer_model.iterator,
      dev_infer_iterator_feed_dict,
      dev_tgt_file,
      "dev",
      summary_writer,
      save_on_best=save_best_dev,
      avg_ckpts=avg_ckpts)

  # Test evaluation mirrors the dev pass but never drives best-checkpoint
  # saving (save_on_best=False) and is skipped when no test set is configured.
  test_scores = None
  if use_test_set and hparams.test_prefix:
    test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
    test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
    test_infer_iterator_feed_dict[
        infer_model.src_placeholder] = inference.load_data(test_src_file)
    test_infer_iterator_feed_dict[
        infer_model.batch_size_placeholder] = hparams.infer_batch_size
    test_scores = _external_eval(
        loaded_infer_model,
        global_step,
        infer_sess,
        hparams,
        infer_model.iterator,
        test_infer_iterator_feed_dict,
        test_tgt_file,
        "test",
        summary_writer,
        save_on_best=False,
        avg_ckpts=avg_ckpts)

  return dev_scores, test_scores, global_step
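
For context, a minimal usage sketch, not part of train.py: it assumes hparams is the parsed nmt hyper-parameter object, and the choice of nmt_model.Model as the model_creator and the default-config tf.Session are illustrative assumptions only.

# Usage sketch under the assumptions stated above; attention models would
# swap in a different model class for model_creator.
import tensorflow as tf

from nmt import model as nmt_model
from nmt import model_helper

model_creator = nmt_model.Model  # assumed basic (non-attention) model
infer_model = model_helper.create_infer_model(model_creator, hparams)
infer_sess = tf.Session(graph=infer_model.graph)
summary_writer = tf.summary.FileWriter(hparams.out_dir, infer_model.graph)

# Evaluates the latest checkpoint in out_dir on dev (and on test when
# hparams.test_prefix is set); the best dev checkpoint is tracked because
# save_best_dev defaults to True.
dev_scores, test_scores, global_step = run_external_eval(
    infer_model, infer_sess, hparams.out_dir, hparams, summary_writer)

Building the inference model and session once and reusing them across calls avoids reconstructing the graph on every evaluation; only the checkpoint restore inside run_external_eval repeats.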