in syne_tune/optimizer/schedulers/hyperband.py [0:0]
def on_trial_result(self, trial: Trial, result: Dict) -> str:
    """Process a metric report from a running trial and decide its fate.

    Validates the report, forwards it to the Hyperband terminator and the
    searcher, maintains per-trial bookkeeping in ``self._active_trials``
    (time stamp, last reported result, ``keep_case``,
    ``largest_update_resource``) and, when cost-aware promotion is active,
    the per-trial cost offsets in ``self._cost_offset``.

    :param trial: Trial which sent the report
    :param result: Dict reported by the trial. An empty dict is skipped.
        Otherwise it must contain ``self.metric``, ``self._resource_attr``
        and (optionally) ``self._cost_attr``.
    :return: One of ``SchedulerDecision.CONTINUE``, ``STOP`` or ``PAUSE``.
    """
    self._check_result(result)
    trial_id = str(trial.trial_id)
    debug_log = self.searcher.debug_log
    trial_decision = SchedulerDecision.CONTINUE
    if len(result) == 0:
        # An empty dict should just be skipped
        if debug_log is not None:
            logger.info(
                f"trial_id {trial_id}: Skipping empty dict received "
                "from reporter")
    else:
        # Time since start of experiment
        time_since_start = self._elapsed_time()
        do_update = False
        config = self._preprocess_config(trial.config)
        cost_and_promotion = (self._cost_attr in result) and \
            self._does_pause_resume()
        if cost_and_promotion:
            # Trial may have paused/resumed before, so need to add cost
            # offset from these
            cost_offset = self._cost_offset.get(trial_id, 0)
            result[self._total_cost_attr()] = \
                result[self._cost_attr] + cost_offset
        # We may receive a report from a trial which has been stopped or
        # paused before. In such a case, we override trial_decision to be
        # STOP or PAUSE as before, so the report is not taken into account
        # by the scheduler. The report is sent to the searcher, but with
        # update=False. This means that the report is registered, but cannot
        # influence any decisions.
        if trial_id not in self._active_trials:
            # Trial not in self._active_trials anymore, so must have been
            # stopped
            trial_decision = SchedulerDecision.STOP
            logger.warning(
                f"trial_id {trial_id}: Was STOPPED, but receives another "
                f"report {result}\nThis report is ignored")
        elif not self._active_trials[trial_id]['running']:
            # Trial must have been paused before
            trial_decision = SchedulerDecision.PAUSE
            logger.warning(
                f"trial_id {trial_id}: Was PAUSED, but receives another "
                f"report {result}\nThis report is ignored")
        else:
            # Normal case: trial is active. Ask the terminator whether the
            # trial continues and whether a rung (milestone) was reached.
            task_info = self.terminator.on_task_report(trial_id, result)
            task_continues = task_info['task_continues']
            milestone_reached = task_info['milestone_reached']
            if cost_and_promotion:
                if milestone_reached:
                    # Trial reached milestone and will pause there: Update
                    # cost offset
                    self._cost_offset[trial_id] = \
                        result[self._total_cost_attr()]
                elif task_info.get('ignore_data', False):
                    # For a resumed trial, the report is for resource <=
                    # resume_from, where resume_from < milestone. This
                    # happens if checkpointing is not implemented and a
                    # resumed trial has to start from scratch, publishing
                    # results all the way up to resume_from. In this case,
                    # we can erase the `_cost_offset` entry, since the
                    # instanteneous cost reported by the trial does not
                    # have any offset.
                    # Use `.get` (as above): the entry may not exist yet,
                    # and a plain subscript would raise KeyError here.
                    if self._cost_offset.get(trial_id, 0) > 0:
                        logger.info(
                            f"trial_id {trial_id}: Resumed trial seems to have been " +
                            "started from scratch (no checkpointing?), so we erase " +
                            "the cost offset.")
                    self._cost_offset[trial_id] = 0
            # Update searcher and register pending
            do_update = self._update_searcher(
                trial_id, config, result, task_info)
            # Change snapshot entry for task
            # Note: This must not be done above, because what _update_searcher
            # is doing, depends on the entry *before* its update here.
            # Note: result may contain all sorts of extra info.
            # All we need to maintain in the snapshot are metric and
            # resource level.
            # 'keep_case' entry (only used if searcher_data ==
            # 'rungs_and_last'): The result is kept in the dataset iff
            # milestone_reached == True (i.e., we are at a rung level).
            # Otherwise, it is removed once _update_searcher is called for
            # the next recent result.
            resource = int(result[self._resource_attr])
            self._active_trials[trial_id].update({
                'time_stamp': time_since_start,
                'reported_result': {
                    self.metric: result[self.metric],
                    self._resource_attr: resource},
                'keep_case': milestone_reached})
            if do_update:
                # Guard against updating the searcher twice for the same
                # resource level of the same trial.
                largest_update_resource = self._active_trials[trial_id][
                    'largest_update_resource']
                if largest_update_resource is None:
                    largest_update_resource = resource - 1
                assert largest_update_resource <= resource, \
                    f"Internal error (trial_id {trial_id}): " +\
                    f"on_trial_result called with resource = {resource}, " +\
                    f"but largest_update_resource = {largest_update_resource}"
                if resource == largest_update_resource:
                    do_update = False  # Do not update again
                else:
                    self._active_trials[trial_id][
                        'largest_update_resource'] = resource
            if not task_continues:
                # Stop outright if pause/resume is off or max_t is reached;
                # otherwise pause at the rung for possible later promotion.
                if (not self._does_pause_resume()) or resource >= self.max_t:
                    trial_decision = SchedulerDecision.STOP
                    act_str = 'Terminating'
                else:
                    trial_decision = SchedulerDecision.PAUSE
                    act_str = 'Pausing'
                self._cleanup_trial(trial_id)
            if debug_log is not None:
                if not task_continues:
                    logger.info(
                        f"trial_id {trial_id}: {act_str} evaluation "
                        f"at {resource}")
                elif milestone_reached:
                    msg = f"trial_id {trial_id}: Reaches {resource}, continues"
                    next_milestone = task_info.get('next_milestone')
                    if next_milestone is not None:
                        msg += f" to {next_milestone}"
                    logger.info(msg)
        # Report goes to the searcher in every non-empty case; update=False
        # means it is registered but cannot influence decisions.
        self.searcher.on_trial_result(
            trial_id, config, result=result, update=do_update)
    # Extra info in debug mode
    log_msg = f"trial_id {trial_id} (metric = {result[self.metric]:.3f}"
    for k, is_float in (
            (self._resource_attr, False), ('elapsed_time', True)):
        if k in result:
            if is_float:
                log_msg += f", {k} = {result[k]:.2f}"
            else:
                log_msg += f", {k} = {result[k]}"
    log_msg += f"): decision = {trial_decision}"
    logger.debug(log_msg)
    return trial_decision