in syne_tune/optimizer/schedulers/hyperband.py
def __init__(self, config_space, **kwargs):
# Before we can call the superclass constructor, we need to set a few
# members (see also `_extend_search_options`).
# To do this properly, we first check values and impute defaults for
# `kwargs`.
kwargs = check_and_merge_defaults(
kwargs, set(), _DEFAULT_OPTIONS, _CONSTRAINTS,
dict_name='scheduler_options')
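# Example (schematic; the concrete defaults live in _DEFAULT_OPTIONS):
# after imputation, `kwargs` contains entries such as
#   {'type': 'stopping', 'resource_attr': 'epoch', 'grace_period': 1,
#    'reduction_factor': 3, 'rung_system_kwargs': {...}, ...}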
scheduler_type = kwargs['type']
self.scheduler_type = scheduler_type
self._resource_attr = kwargs['resource_attr']
self._rung_system_kwargs = kwargs['rung_system_kwargs']
self._cost_attr = self._rung_system_kwargs['cost_attr']
# Superclass constructor
resume = kwargs['resume']
kwargs['resume'] = False # Cannot be done in superclass
super().__init__(config_space, **filter_by_key(kwargs, _ARGUMENT_KEYS))
assert self.max_t is not None, \
"max_t must be specified, either directly or via " + \
"config_space['epochs'], config_space['max_t'], or " + \
"config_space['max_epochs']"
# If rung_levels is given, grace_period and reduction_factor are ignored
rung_levels = kwargs.get('rung_levels')
if rung_levels is not None:
assert isinstance(rung_levels, list)
if ('grace_period' in kwargs) or ('reduction_factor' in kwargs):
logger.warning(
"Since rung_levels is given, the values grace_period = "
f"{kwargs.get('grace_period')} and reduction_factor = "
f"{kwargs.get('reduction_factor')} are ignored!")
rung_levels = _get_rung_levels(
rung_levels, grace_period=kwargs['grace_period'],
reduction_factor=kwargs['reduction_factor'], max_t=self.max_t)
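# Illustration of the usual geometric spacing: grace_period=1,
# reduction_factor=3, max_t=81 gives rung levels [1, 3, 9, 27, 81],
# each level being reduction_factor times the previous one, capped
# at max_t.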
brackets = kwargs['brackets']
do_snapshots = kwargs['do_snapshots']
assert (not do_snapshots) or (scheduler_type == 'stopping'), \
"Snapshots are supported only for type = 'stopping'"
rung_system_per_bracket = kwargs['rung_system_per_bracket']
self.terminator = HyperbandBracketManager(
scheduler_type, self._resource_attr, self.metric, self.mode,
self.max_t, rung_levels, brackets, rung_system_per_bracket,
cost_attr=self._total_cost_attr(),
random_seed=self.random_seed_generator(),
rung_system_kwargs=self._rung_system_kwargs)
self.do_snapshots = do_snapshots
self.searcher_data = kwargs['searcher_data']
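# `searcher_data` selects which results are passed on to the searcher,
# e.g. 'rungs' (results at rung levels only) or 'rungs_and_last'
# (additionally each trial's most recent result); see the
# _active_trials comments below.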
self._register_pending_myopic = kwargs['register_pending_myopic']
# _active_trials:
# Maintains a snapshot of currently active tasks (running or paused),
# needed by several features (for example, searcher_data ==
# 'rungs_and_last', or for providing a snapshot to the searcher).
# Maps trial_id to dict, with fields:
# - config
# - time_stamp: Time when the task was started, or when its most
# recent result was reported
# - reported_result: Most recent reported result, or None (task was
# started, but has not reported anything yet). Note: Contains only
# the attributes self.metric and self._resource_attr
# - bracket: Bracket number
# - keep_case: Boolean flag. Relevant only if searcher_data ==
# 'rungs_and_last'. See _run_reporter
# - running: Is the trial running? Otherwise, it is paused. Note that
# paused trials are kept in _active_trials, and their entry is
# updated when a trial is resumed. This allows the config of a
# paused trial to be retrieved
# - largest_update_resource: Largest resource level for which the
# searcher was updated, or None
self._active_trials = dict()
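# Schematic example of one entry (field values illustrative):
#   self._active_trials['3'] = {
#       'config': {...}, 'time_stamp': 1630000000.0,
#       'reported_result': {self.metric: 0.91, self._resource_attr: 3},
#       'bracket': 0, 'keep_case': False, 'running': True,
#       'largest_update_resource': 3}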
# _cost_offset:
# Used for promotion-based (pause/resume) scheduling if the evaluation
# function reports cost values. For a trial which has been paused at
# least once, this records the sum of costs for reaching its most
# recent milestone.
self._cost_offset = dict()
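# Schematic example: if a trial reports cumulative cost and is paused
# at milestone 3 with reported cost 12.5, then
# _cost_offset[trial_id] = 12.5; costs reported after the trial is
# resumed are counted on top of this offset.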
if resume:
checkpoint = kwargs.get('checkpoint')
assert checkpoint is not None, \
"Need checkpoint to be set if resume = True"
if os.path.isfile(checkpoint):
raise NotImplementedError()
# TODO! Need load
# self.load_state_dict(load(checkpoint))
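# A full resume would have to restore, at least, self.terminator,
# self._active_trials, and self._cost_offset, along the lines of the
# commented-out load above.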
else:
msg = f'Checkpoint path {checkpoint} does not exist, cannot resume.'
logger.error(msg)
raise FileNotFoundError(msg)