def eval_model()

in mesh_tensorflow/transformer/utils.py [0:0]


def eval_model(estimator,
               vocabulary,
               sequence_length,
               batch_size,
               dataset_split,
               model_dir,
               eval_dataset_fn,
               eval_summary_dir,
               eval_checkpoint_step,
               eval_with_score=False,
               output_eval_examples=True,
               eval_dir_suffix=None,
               score_with_estimator_fn=score_with_estimator):
  """Eval a Mesh-TF model.

  Every evaluation dataset is read once and its examples and postprocessed
  targets are cached. Then, for each checkpoint selected by
  `eval_checkpoint_step`, the model is run (decoding, or scoring when
  `eval_with_score` is set). The outputs are split back per dataset,
  postprocessed, and passed to every metric function. The results are written
  as TensorBoard summaries to the eval summary directory.

  Args:
    estimator: an Estimator object or a callable that returns one.
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple
    sequence_length: a dict from feature-key to integer the (packed)
      sequence length, e.g. {"inputs": 512, "targets": 128}. May also be set to
      `None` to automatically compute the maximum length of the examples, which
      requires `estimator` to be a callable.
    batch_size: an integer, global batch size
    dataset_split: a string
    model_dir: a string, directory with the model.
    eval_dataset_fn: A function returning a list of dataset.EvalDataset tuples.
      Must be provided for mode="eval". Should accept the following arguments:
        - sequence_length: an integer or a dict from feature-key to integer
          the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
        - vocabulary: Vocabulary instance to use for encoding.
        - dataset_split: str, which dataset split to load.
      dataset.EvalDataset tuples (or plain tuples with the same field order)
      have the following fields:
        - name: string, the task name
        - dataset_fn: function which returns a tf.data.Dataset of tokenized and
          padded examples. Must not require any arguments and must include the
          feature keys 'inputs', 'targets' and 'targets_pretokenized'.
        - postprocess_fn: function which converts original targets to values
          that can be processed by a `metric_fn`.
        - metric_fns: list of metric functions with the call signature
          `metric_fn(targets, predictions)`. Each returns either a dict mapping
          submetric names to scalar values or a precomputed `tf.Summary`.
          TensorBoard summaries and other tags will be written out using the
          submetric names. Datasets with an empty list are skipped.
    eval_summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/{split}_eval.
    eval_checkpoint_step: int, list of ints, or None. If an int or list of ints,
      evaluation or inference will be run on the checkpoint files in `model_dir`
      whose global steps are closest to the global steps provided. If None and
      mode="eval", run eval continuously waiting for new checkpoints via
      `tf.train.checkpoints_iterator`.
    eval_with_score: bool, whether to evaluate using log likelihood scores of
      targets instead of decoded predictions.
    output_eval_examples: bool, whether to dump inputs, targets and predictions
      of the eval examples in plaintext to eval_summary_dir.
    eval_dir_suffix: string, if not None then it will be appended to the
      eval_summary_dir.
    score_with_estimator_fn: a function to run scoring with the estimator.

  Raises:
    ValueError: if `eval_dataset_fn` is None; if `sequence_length` is None and
      `estimator` is not callable; if the model produced fewer outputs than
      there are cached examples; or if the number of trailing (padding)
      outputs is unexpected.
  """
  if eval_dataset_fn is None:
    raise ValueError("Must provide eval_dataset_fn through gin for eval.")
  if sequence_length is None and not callable(estimator):
    raise ValueError(
        "A callable must be passed for the estimator when automatically "
        "computing the sequence length.")

  eval_datasets = eval_dataset_fn(
      sequence_length=sequence_length,
      vocabulary=vocabulary,
      dataset_split=dataset_split,
  )

  valid_eval_datasets = []
  for eval_dataset in eval_datasets:
    # Convert to EvalDataset tuple in case eval_dataset_fn returns raw tuples.
    # This must happen before accessing named fields such as `metric_fns`,
    # which plain tuples do not have.
    eval_dataset = transformer_dataset.EvalDataset(*eval_dataset)
    if not eval_dataset.metric_fns:
      tf.logging.info("Skipping %s because metric_fns is empty",
                      eval_dataset.name)
      continue
    valid_eval_datasets.append(eval_dataset)
  eval_datasets = valid_eval_datasets

  if not eval_datasets:
    tf.logging.info(
        "All provided EvalDatasets have metric_fns=[]; eval is not possible.")
    return

  eval_summary_dir = eval_summary_dir or os.path.join(
      model_dir, "{}_eval".format(dataset_split))
  if eval_dir_suffix is not None:
    eval_summary_dir += "_{}".format(eval_dir_suffix)
  summary_writer = tf.summary.FileWriter(eval_summary_dir)
  try:
    # Pre-load all of the targets once before entering the (possibly
    # continuous) checkpoint loop.
    cached_targets = {}
    cached_examples = {}
    max_sequence_length = {"inputs": 0, "targets": 0}

    tf.logging.info("Caching evaluation examples.")
    # Need to create a separate graph for loading in original targets
    # or else TF will complain that we modified the graph.
    with tf.Graph().as_default():
      for eval_dataset in eval_datasets:
        ds = eval_dataset.dataset_fn()
        # Create list of postprocessed text targets
        inputs = []
        targets = []
        examples = []
        for ex in tfds.as_numpy(ds):
          max_sequence_length["inputs"] = max(
              max_sequence_length["inputs"], len(ex["inputs"]))
          max_sequence_length["targets"] = max(
              max_sequence_length["targets"], len(ex["targets"]))
          examples.append(ex)
          if "inputs_pretokenized" in ex:
            inputs.append(ex["inputs_pretokenized"])
          if "targets_pretokenized" in ex:
            targets_pretokenized = ex["targets_pretokenized"]
            if isinstance(targets_pretokenized, bytes):
              targets_pretokenized = targets_pretokenized.decode("utf-8")
            targets.append(
                eval_dataset.postprocess_fn(
                    targets_pretokenized, example=ex, is_target=True)
            )
        if output_eval_examples:
          targets_filename = os.path.join(
              eval_summary_dir,
              "{}_targets".format(eval_dataset.name),
          )
          write_lines_to_file(targets, targets_filename)
          inputs_filename = os.path.join(eval_summary_dir,
                                         "{}_inputs".format(eval_dataset.name))
          write_lines_to_file(inputs, inputs_filename)

        cached_targets[eval_dataset.name] = targets
        cached_examples[eval_dataset.name] = examples

    if sequence_length is None:
      tf.logging.info("Setting sequence lengths to %s", max_sequence_length)
      sequence_length = max_sequence_length
      estimator = functools.partial(estimator, sequence_length=sequence_length)
    elif (sequence_length["inputs"] < max_sequence_length["inputs"] or
          sequence_length["targets"] < max_sequence_length["targets"]):
      tf.logging.warning(
          "Given sequence lengths are insufficient for some evaluation inputs "
          "or targets. These sequences will be truncated to fit, likely "
          "leading to sub-optimal results. Consider passing `None` for "
          "sequence_length to have them be automatically computed.\n Got: %s,"
          "\n Max Lengths: %s",
          sequence_length, max_sequence_length)
    elif (sequence_length["inputs"] > max_sequence_length["inputs"] or
          sequence_length["targets"] > max_sequence_length["targets"]):
      tf.logging.warning(
          "Given sequence lengths are longer than necessary for some evaluation "
          "inputs or targets, resulting in wasted computation. Consider passing "
          "`None` for sequence_length to have them be automatically computed.\n"
          " Got: %s,\n Max Lengths: %s",
          sequence_length, max_sequence_length)

    if callable(estimator):
      estimator = estimator()

    input_fn = _get_combined_dataset_input_fn(
        eval_datasets, batch_size, sequence_length, check_for_metrics=True)
    # Total number of real (non-padding) examples across all datasets; the
    # cache does not change between checkpoints, so compute it once.
    num_examples = sum(len(cex) for cex in cached_examples.values())

    checkpoint_paths = get_checkpoint_iterator(eval_checkpoint_step, model_dir)
    for checkpoint_path in checkpoint_paths:
      tf.logging.info("Checkpoint path %s", checkpoint_path)
      global_step = int(get_step_from_checkpoint_path(checkpoint_path))
      if eval_with_score:
        scores, _ = score_with_estimator_fn(
            estimator, input_fn, global_step, model_dir, vocabulary,
            num_examples=num_examples)
        # Copy into a fresh list: consumed entries are deleted in place below,
        # which needs a mutable list and must not alter the scorer's result.
        outputs = list(scores)
      else:
        outputs = [
            d.decode("utf-8") if isinstance(d, bytes) else d
            for d in decode(estimator, input_fn, vocabulary, checkpoint_path)
        ]

      for eval_dataset in eval_datasets:
        # Extract the portion of outputs corresponding to this dataset; the
        # combined input_fn yields datasets in the order of `eval_datasets`.
        examples = cached_examples[eval_dataset.name]
        dataset_size = len(examples)
        if len(outputs) < dataset_size:
          # zip() below would otherwise silently truncate the predictions and
          # the metrics would compare lists of different lengths.
          raise ValueError(
              "{} outputs remaining but {} examples expected for {}.".format(
                  len(outputs), dataset_size, eval_dataset.name))
        predictions = [
            eval_dataset.postprocess_fn(d, example=ex)
            for d, ex in zip(outputs[:dataset_size], examples)
        ]
        # Remove the used outputs.
        del outputs[:dataset_size]

        if output_eval_examples:
          predictions_filename = os.path.join(
              eval_summary_dir,
              "{}_{}_predictions".format(eval_dataset.name, global_step),
          )
          write_lines_to_file(predictions, predictions_filename)

        targets = cached_targets[eval_dataset.name]
        for metric_fn in eval_dataset.metric_fns:
          metric_result = metric_fn(targets, predictions)
          if isinstance(metric_result, tf.Summary):
            tf.logging.info("Precomputed summary at step %d", global_step)
            summary_writer.add_summary(metric_result, global_step)
          else:
            summary = tf.Summary()
            for metric_name, metric_value in metric_result.items():
              tag = "eval/{}/{}".format(eval_dataset.name, metric_name)
              tf.logging.info("%s at step %d: %.3f", tag, global_step,
                              metric_value)
              summary.value.add(tag=tag, simple_value=metric_value)
            summary_writer.add_summary(summary, global_step)
        summary_writer.flush()

      # Only padding should remain: the combined dataset is presumably padded
      # up to a multiple of batch_size (count computed from examples, not
      # targets, since targets are only cached when pretokenized text exists).
      expected_pad = -num_examples % batch_size
      if outputs and len(outputs) != expected_pad:
        raise ValueError("{} padded outputs, {} expected.".format(
            len(outputs), expected_pad))
  finally:
    # Flush pending events and release the events file on every exit path.
    summary_writer.close()