in src/python/tensorflow_cloud/tuner/tuner.py [0:0]
def update_trial(self,
                 trial_id: Text,
                 metrics: Mapping[Text, Union[int, float]],
                 step: int = 0):
    """Used by a worker to report the status of a trial.

    Sends the objective values found in `metrics` to `self.service` as an
    intermediate measurement, mirrors every entry of `metrics` into the
    local Keras Tuner trial, then polls the service for a stop decision.

    Args:
        trial_id: Id of the trial being updated; it is looked up in
            `self.trials`.
        metrics: Mapping from metric name to its latest value.
        step: Step at which the metrics were measured. Defaults to 0.

    Returns:
        The status of the local trial after the update. It is set to
        `TrialStatus.STOPPED` when the service reports the trial should stop.

    Raises:
        ValueError: If `step` or the time elapsed since `self._start_time`
            is negative, or if both are zero.
    """
    # Constructs the measurement.
    elapsed_secs = time.time() - self._start_time
    if elapsed_secs < 0 or step < 0:
        raise ValueError(
            "Both elapsed_secs and step must be non-negative.")
    if elapsed_secs == 0 and step == 0:
        raise ValueError(
            "At least one of {elapsed_secs, step} must be positive")

    # Adds the measurement of the objective functions to a trial.
    val_prefix = "val_"
    metric_list = []
    for ob in self._get_objective():
        if ob.name in metrics:
            reported_name = ob.name
        else:
            # Fall back to the un-prefixed metric (e.g. "val_loss" -> "loss").
            # Only a *leading* "val_" is stripped; str.replace would also
            # remove inner occurrences (e.g. "val_interval_loss").
            fallback = (ob.name[len(val_prefix):]
                        if ob.name.startswith(val_prefix) else None)
            if fallback is None or fallback not in metrics:
                tf.get_logger().info(
                    'Objective "{}" is not found in metrics.'.format(ob.name)
                )
                continue
            tf.get_logger().info(
                'Objective "{}" is not found in metrics; reporting "{}" '
                'instead.'.format(ob.name, fallback)
            )
            reported_name = fallback
        metric_list.append(
            {"metric": reported_name, "value": float(metrics[reported_name])}
        )
    self.service.report_intermediate_objective_value(
        step, elapsed_secs, metric_list, trial_id
    )

    # Ensure metrics of trials are updated locally.
    keras_tuner_trial = self.trials[trial_id]
    for metric_name, metric_value in metrics.items():
        if not keras_tuner_trial.metrics.exists(metric_name):
            # Unseen metrics must be registered (with an inferred direction)
            # before they can be updated.
            direction = metrics_tracking.infer_metric_direction(
                metric_name)
            keras_tuner_trial.metrics.register(
                metric_name, direction=direction)
        keras_tuner_trial.metrics.update(
            metric_name, metric_value, step=step)

    # Checks whether a trial should stop or not.
    tf.get_logger().info("UpdateTrial: polls the stop decision.")
    if self.service.should_trial_stop(trial_id):
        keras_tuner_trial.status = trial_module.TrialStatus.STOPPED
    return keras_tuner_trial.status