in mesh_tensorflow/transformer/utils.py
def eval_model(estimator,
vocabulary,
sequence_length,
batch_size,
dataset_split,
model_dir,
eval_dataset_fn,
eval_summary_dir,
eval_checkpoint_step,
eval_with_score=False,
output_eval_examples=True,
eval_dir_suffix=None,
score_with_estimator_fn=score_with_estimator):
"""Eval a Mesh-TF model.
Args:
estimator: an Estimator object or a callable that returns one.
    vocabulary: a vocabulary.Vocabulary or an (inputs_vocabulary,
      targets_vocabulary) tuple.
    sequence_length: a dict from feature-key to integer giving the (packed)
      sequence length, e.g. {"inputs": 512, "targets": 128}. May also be set to
      `None` to automatically compute the maximum length of the examples, which
      requires `estimator` to be a callable.
batch_size: an integer, global batch size
dataset_split: a string
model_dir: a string, directory with the model.
    eval_dataset_fn: A function returning a list of dataset.EvalDataset tuples
      (see the example sketch at the end of this docstring). Must be provided
      for mode="eval". Should accept the following arguments:
        - sequence_length: an integer or a dict from feature-key to integer
          giving the (packed) sequence length, e.g.
          {"inputs": 512, "targets": 128}.
- vocabulary: Vocabulary instance to use for encoding.
- dataset_split: str, which dataset split to load.
dataset.EvalDataset tuples are namedtuples with the following fields:
- name: string, the task name
- dataset_fn: function which returns a tf.data.Dataset of tokenized and
padded examples. Must not require any arguments and must include the
feature keys 'inputs' and 'targets_pretokenized'.
- postprocess_fn: function which converts original targets to values
that can be processed by a `metric_fn`.
        - metric_fns: list of metric functions with the call signature
          `metric_fn(targets, predictions)` which returns a dict mapping
          submetric names to scalar values. TensorBoard summaries and other tags
          will be written out using the submetric names.
    eval_summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, defaults to model_dir/{dataset_split}_eval.
eval_checkpoint_step: int, list of ints, or None. If an int or list of ints,
evaluation or inference will be run on the checkpoint files in `model_dir`
whose global steps are closest to the global steps provided. If None and
mode="eval", run eval continuously waiting for new checkpoints via
`tf.train.checkpoints_iterator`.
eval_with_score: bool, whether to evaluate using log likelihood scores of
targets instead of decoded predictions.
output_eval_examples: bool, whether to dump inputs, targets and predictions
of the eval examples in plaintext to eval_summary_dir.
    eval_dir_suffix: string, if not None, will be appended to the
      eval_summary_dir.
score_with_estimator_fn: a function to run scoring with the estimator.
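
  Example: a minimal, illustrative sketch of an `eval_dataset_fn` and a metric
  function. The names `my_task_dataset`, `my_eval_dataset_fn` and `accuracy`
  below are hypothetical placeholders, not part of this module:

    def accuracy(targets, predictions):
      # Returns a dict mapping submetric names to scalar values.
      correct = sum(t == p for t, p in zip(targets, predictions))
      return {"accuracy": correct / max(len(targets), 1)}

    def my_eval_dataset_fn(sequence_length, vocabulary, dataset_split):
      def dataset_fn():
        # Must take no arguments and return a tf.data.Dataset of tokenized,
        # padded examples with "inputs" and "targets_pretokenized" features.
        return my_task_dataset(sequence_length, vocabulary, dataset_split)
      return [
          transformer_dataset.EvalDataset(
              name="my_task",
              dataset_fn=dataset_fn,
              postprocess_fn=(
                  lambda decoded, example=None, is_target=False: decoded),
              metric_fns=[accuracy],
          )
      ]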
"""
if eval_dataset_fn is None:
raise ValueError("Must provide eval_dataset_fn through gin for eval.")
if sequence_length is None and not callable(estimator):
raise ValueError(
"A callable must be passed for the estimator when automatically "
"computing the sequence length.")
eval_datasets = eval_dataset_fn(
sequence_length=sequence_length,
vocabulary=vocabulary,
dataset_split=dataset_split,
)
valid_eval_datasets = []
for eval_dataset in eval_datasets:
if not eval_dataset.metric_fns:
tf.logging.info("Skipping %s because metric_fns is empty",
eval_dataset.name)
continue
# Convert to EvalDataset tuple in case eval_dataset_fn returns raw tuples
valid_eval_datasets.append(transformer_dataset.EvalDataset(*eval_dataset))
eval_datasets = valid_eval_datasets
if not eval_datasets:
tf.logging.info(
"All provided EvalDatasets have metric_fns=[]; eval is not possible.")
return
eval_summary_dir = eval_summary_dir or os.path.join(
model_dir, "{}_eval".format(dataset_split))
if eval_dir_suffix is not None:
eval_summary_dir += "_{}".format(eval_dir_suffix)
summary_writer = tf.summary.FileWriter(eval_summary_dir)
  # Pre-load all of the targets once before entering the continuous eval loop.
cached_targets = {}
cached_examples = {}
# Need to create a separate graph for loading in original targets
# or else TF will complain that we modified the graph
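  # Track the longest "inputs" and "targets" seen across all eval datasets so
  # sequence_length can be auto-set below (or a warning emitted if the provided
  # lengths would truncate or waste computation).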
max_sequence_length = {"inputs": 0, "targets": 0}
tf.logging.info("Caching evaluation examples.")
with tf.Graph().as_default():
for eval_dataset in eval_datasets:
if eval_dataset.metric_fns:
ds = eval_dataset.dataset_fn()
# Create list of postprocessed text targets
inputs = []
targets = []
examples = []
for ex in tfds.as_numpy(ds):
max_sequence_length["inputs"] = max(
max_sequence_length["inputs"], len(ex["inputs"]))
max_sequence_length["targets"] = max(
max_sequence_length["targets"], len(ex["targets"]))
examples.append(ex)
if "inputs_pretokenized" in ex:
inputs.append(ex["inputs_pretokenized"])
if "targets_pretokenized" in ex:
targets_pretokenized = ex["targets_pretokenized"]
if isinstance(targets_pretokenized, bytes):
targets_pretokenized = targets_pretokenized.decode("utf-8")
targets.append(
eval_dataset.postprocess_fn(
targets_pretokenized, example=ex, is_target=True)
)
if output_eval_examples:
targets_filename = os.path.join(
eval_summary_dir,
"{}_targets".format(eval_dataset.name),
)
write_lines_to_file(targets, targets_filename)
inputs_filename = os.path.join(eval_summary_dir,
"{}_inputs".format(eval_dataset.name))
write_lines_to_file(inputs, inputs_filename)
cached_targets[eval_dataset.name] = targets
cached_examples[eval_dataset.name] = examples
if sequence_length is None:
tf.logging.info("Setting sequence lengths to %s", max_sequence_length)
sequence_length = max_sequence_length
estimator = functools.partial(estimator, sequence_length=sequence_length)
elif (sequence_length["inputs"] < max_sequence_length["inputs"] or
sequence_length["targets"] < max_sequence_length["targets"]):
tf.logging.warning(
"Given sequence lengths are insufficient for some evaluation inputs or "
"targets. These sequences will be truncated to fit, likely leading to "
"sub-optimal results. Consider passing `None` for sequence_length to "
"have them be automatically computed.\n Got: %s,\n Max Lengths: %s",
sequence_length, max_sequence_length)
elif (sequence_length["inputs"] > max_sequence_length["inputs"] or
sequence_length["targets"] > max_sequence_length["targets"]):
tf.logging.warning(
"Given sequence lengths are longer than necessary for some evaluation "
"inputs or targets, resulting in wasted computation. Consider passing "
"`None` for sequence_length to have them be automatically computed.\n"
" Got: %s,\n Max Lengths: %s",
sequence_length, max_sequence_length)
if callable(estimator):
estimator = estimator()
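  # Build a single input_fn over all eval datasets; the decoded/scored outputs
  # are later split back per dataset using the cached example counts.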
input_fn = _get_combined_dataset_input_fn(
eval_datasets, batch_size, sequence_length, check_for_metrics=True)
checkpoint_paths = get_checkpoint_iterator(eval_checkpoint_step, model_dir)
for checkpoint_path in checkpoint_paths:
tf.logging.info("Checkpoint path %s", checkpoint_path)
global_step = int(get_step_from_checkpoint_path(checkpoint_path))
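    # Either score the target sequences (log-likelihoods) or decode predictions
    # from the model, depending on eval_with_score.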
if eval_with_score:
outputs, _ = score_with_estimator_fn(
estimator, input_fn, global_step, model_dir, vocabulary,
num_examples=sum(len(cex) for cex in cached_examples.values()))
else:
outputs = [
d.decode("utf-8") if isinstance(d, bytes) else d
for d in decode(estimator, input_fn, vocabulary, checkpoint_path)
]
for eval_dataset in eval_datasets:
      # Extract the portion of outputs corresponding to this dataset.
examples = cached_examples[eval_dataset.name]
dataset_size = len(examples)
predictions = [
eval_dataset.postprocess_fn(d, example=ex)
for d, ex in zip(outputs[:dataset_size], examples)
]
      # Remove the used outputs.
del outputs[:dataset_size]
if output_eval_examples:
predictions_filename = os.path.join(
eval_summary_dir,
"{}_{}_predictions".format(eval_dataset.name, global_step),
)
write_lines_to_file(predictions, predictions_filename)
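      # Each metric_fn may return either a precomputed tf.Summary or a dict
      # mapping submetric names to scalar values.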
for metric_fn in eval_dataset.metric_fns:
summary = tf.Summary()
targets = cached_targets[eval_dataset.name]
metric_result = metric_fn(targets, predictions)
if isinstance(metric_result, tf.Summary):
tf.logging.info("Precomputed summary at step %d", global_step)
summary_writer.add_summary(metric_result, global_step)
else:
for metric_name, metric_value in metric_result.items():
tag = "eval/{}/{}".format(eval_dataset.name, metric_name)
tf.logging.info("%s at step %d: %.3f", tag, global_step,
metric_value)
summary.value.add(tag=tag, simple_value=metric_value)
summary_writer.add_summary(summary, global_step)
summary_writer.flush()
# Only padding should remain.
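    # The combined eval dataset is padded up to a multiple of batch_size, so
    # after slicing off each task's outputs only that padding should be left.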
expected_pad = -sum(len(t) for t in cached_targets.values()) % batch_size
if outputs and len(outputs) != expected_pad:
raise ValueError("{} padded outputs, {} expected.".format(
len(outputs), expected_pad))