def make_test_function()

in tensorflow_decision_forests/keras/core.py [0:0]


  def make_test_function(self):
    """Predictions for evaluation.

    Builds the function that runs the test step(s) on a dataset iterator.
    Unlike the parent class's `make_test_function`, the created function is
    not stored in `self.test_function`. A model that is not trained yet gets a
    function that only consumes one batch and returns no metrics.

    Returns:
      A callable taking a dataset iterator.
      - Non-trained model: the callable returns an empty dict.
      - Trained model: the callable returns the evaluation outputs of the
        (last) executed step.
      - Trained model with `self._cluster_coordinator` set: the callable
        instead returns whatever `self._cluster_coordinator.schedule(...)`
        returns.
    """

    @tf.function(experimental_relax_shapes=True)
    def test_function_not_trained(iterator):
      """Evaluation of a non-trained model: consumes one batch, no metrics."""

      next(iterator)
      return {}

    if not self._is_trained:
      # NOTE(review): this path ignores `steps_per_execution` and
      # `_cluster_coordinator`; confirm that is intended.
      return test_function_not_trained

    @tf.function(experimental_relax_shapes=True)
    def step_function_trained(model, iterator):
      """Evaluation of a trained model.

      The only difference with "super.make_test_function()" is that
      "self.test_function" is not set.

      Args:
        model: Model object.
        iterator: Iterator over dataset.

      Returns:
        Evaluation metrics.
      """
      # NOTE(review): because of the decorator above, this step is traced into
      # a graph even when `run_eagerly` is True; confirm that is intended.

      def run_step(data):
        outputs = model.test_step(data)
        # Order the test-counter increment after (a minimal set of) the step's
        # outputs.
        with tf.control_dependencies(_minimum_control_deps(outputs)):
          model._test_counter.assign_add(1)  # pylint:disable=protected-access
        return outputs

      data = next(iterator)
      # Use the `model` argument's strategy for both the run and the
      # reduction, rather than mixing it with the closed-over `self`.
      strategy = model.distribute_strategy
      outputs = strategy.run(run_step, args=(data,))
      outputs = _reduce_per_replica(outputs, strategy, reduction="first")
      return outputs

    def maybe_compile(fn):
      """Wraps `fn` in a `tf.function` unless the model runs eagerly."""
      if self.run_eagerly:
        return fn
      return tf.function(fn, experimental_relax_shapes=True)

    # Special case if steps_per_execution is one.
    if (self._steps_per_execution is None or
        self._steps_per_execution.numpy().item() == 1):

      def test_function(iterator):
        """Runs a test execution with a single step."""
        return step_function_trained(self, iterator)

      test_function = maybe_compile(test_function)

      if self._cluster_coordinator:
        return lambda it: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
            test_function, args=(it,))
      return test_function

    # If we're using a coordinator, use the value of self._steps_per_execution
    # at the time the function is called/scheduled, and not when it is
    # actually executed.
    if self._cluster_coordinator:

      def test_function(iterator, steps_per_execution):
        """Runs a test execution with multiple steps."""
        # The outputs of the last executed step are returned.
        for _ in tf.range(steps_per_execution):
          outputs = step_function_trained(self, iterator)
        return outputs

      test_function = maybe_compile(test_function)

      return lambda it: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
          test_function,
          args=(it, self._steps_per_execution.value()))

    def test_function(iterator):
      """Runs a test execution with multiple steps."""
      # The outputs of the last executed step are returned.
      for _ in tf.range(self._steps_per_execution):
        outputs = step_function_trained(self, iterator)
      return outputs

    return maybe_compile(test_function)