in tensorflow_decision_forests/keras/core.py [0:0]
def make_test_function(self):
  """Creates the function used to run one evaluation execution.

  This follows Keras' `Model.make_test_function`. The only difference for a
  trained model is that `self.test_function` is not set; the function is only
  returned. If the model is not trained yet, the returned function consumes
  one element of the iterator and returns an empty dict of metrics.

  Returns:
    A callable taking a dataset iterator. Depending on the configuration it
    runs a single step, runs `self._steps_per_execution` steps, or schedules
    the execution on `self._cluster_coordinator`.
  """

  @tf.function(experimental_relax_shapes=True)
  def test_function_not_trained(iterator):
    """Evaluation of a non-trained model: consumes one batch, no metrics."""
    next(iterator)
    return {}

  # NOTE(review): `step_function_trained` (and `test_function_not_trained`)
  # are always `tf.function`s, so `self.run_eagerly` only controls whether the
  # outer wrapper below is compiled. Confirm this is intended.
  @tf.function(experimental_relax_shapes=True)
  def step_function_trained(model, iterator):
    """Runs a single evaluation step of a trained model.

    Args:
      model: Model object.
      iterator: Iterator over dataset.

    Returns:
      Evaluation metrics.
    """

    def run_step(data):
      outputs = model.test_step(data)
      # Make the counter update depend on the step outputs.
      with tf.control_dependencies(_minimum_control_deps(outputs)):
        model._test_counter.assign_add(1)  # pylint:disable=protected-access
      return outputs

    data = next(iterator)
    outputs = model.distribute_strategy.run(run_step, args=(data,))
    # Reduce with the same strategy that ran the step (`model` is the
    # parameter of this function; do not silently capture `self`).
    outputs = _reduce_per_replica(
        outputs, model.distribute_strategy, reduction="first")
    return outputs

  if not self._is_trained:
    return test_function_not_trained

  def maybe_compile(fn):
    """Wraps `fn` in a `tf.function` unless the model runs eagerly."""
    if self.run_eagerly:
      return fn
    return tf.function(fn, experimental_relax_shapes=True)

  # Special case if steps_per_execution is one.
  if (self._steps_per_execution is None or
      self._steps_per_execution.numpy().item() == 1):

    def single_step_test_function(iterator):
      """Runs a test execution with a single step."""
      return step_function_trained(self, iterator)

    single_step_test_function = maybe_compile(single_step_test_function)
    if self._cluster_coordinator:
      return lambda it: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
          single_step_test_function, args=(it,))
    return single_step_test_function

  if self._cluster_coordinator:
    # With a coordinator, use the value of self._steps_per_execution at the
    # time the function is called/scheduled, and not when it is actually
    # executed.
    def scheduled_test_function(iterator, steps_per_execution):
      """Runs a test execution with multiple steps."""
      for _ in tf.range(steps_per_execution):
        outputs = step_function_trained(self, iterator)
      return outputs

    scheduled_test_function = maybe_compile(scheduled_test_function)
    return lambda it: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
        scheduled_test_function,
        args=(it, self._steps_per_execution.value()))

  def multi_step_test_function(iterator):
    """Runs a test execution with multiple steps."""
    for _ in tf.range(self._steps_per_execution):
      outputs = step_function_trained(self, iterator)
    return outputs

  return maybe_compile(multi_step_test_function)