def _create_common_objects()

in syne_tune/optimizer/schedulers/searchers/gp_searcher_factory.py [0:0]


def _create_common_objects(model=None, **kwargs):
    """
    Creates objects shared by the GP-based searcher factories: the
    hyperparameter ranges, the reward-to-metric mapping, the predicate for
    skipping surrogate model optimization, the extended configuration space
    (Hyperband schedulers only) and the surrogate model factory.

    :param model: Surrogate model type. Defaults to 'gp_multitask'. Any
        other value is only allowed together with a scheduler whose name
        starts with 'hyperband', and is passed on to
        ``_create_gp_additive_model``
    :param kwargs: Searcher arguments. Keys read directly here:
        'scheduler' (required), 'opt_skip_num_max_resource',
        'opt_skip_init_length' (required if skipping is enabled),
        'opt_skip_period' (default 1), 'max_epochs' and 'resource_attr'
        (required for Hyperband schedulers), 'scheduler_mode' ('min' or
        'max', default 'min'), 'map_reward'. The ``_kwargs`` returned by
        ``extract_random_seed`` (presumably ``kwargs`` minus the seed --
        confirm) are forwarded to the model factory
    :return: Dictionary with entries 'hp_ranges', 'map_reward',
        'skip_optimization', 'configspace_ext' (Hyperband schedulers only),
        updated with the entries returned by the model factory
    """
    scheduler = kwargs['scheduler']
    is_hyperband = scheduler.startswith('hyperband')
    if model is None:
        model = 'gp_multitask'
    assert model == 'gp_multitask' or is_hyperband, \
        f"model = {model} only together with hyperband_* scheduler"
    hp_ranges = create_hp_ranges_for_warmstarting(**kwargs)
    random_seed, _kwargs = extract_random_seed(kwargs)
    # Skip optimization predicate for GP surrogate model
    if kwargs.get('opt_skip_num_max_resource', False) and is_hyperband:
        skip_optimization = SkipNoMaxResourcePredicate(
            init_length=kwargs['opt_skip_init_length'],
            max_resource=kwargs['max_epochs'])
    elif kwargs.get('opt_skip_period', 1) > 1:
        skip_optimization = SkipPeriodicallyPredicate(
            init_length=kwargs['opt_skip_init_length'],
            period=kwargs['opt_skip_period'])
    else:
        skip_optimization = None
    # Conversion from reward to metric (strictly decreasing) and back.
    # This is done only if the scheduler mode is 'max'.
    scheduler_mode = kwargs.get('scheduler_mode', 'min')
    if scheduler_mode == 'max':
        _map_reward = kwargs.get('map_reward')
        if _map_reward is None:
            # An explicit ``map_reward=None`` means "not given", just as in
            # the 'min' branch below
            _map_reward = '1_minus_x'
        if isinstance(_map_reward, str):
            _map_reward_name = _map_reward
            suffix = '_minus_x'
            # Only 'minus_x' or '<const>_minus_x' are valid. Checking just
            # for 'minus_x' at the end would accept e.g. '12minus_x' and
            # then mis-parse the constant.
            assert _map_reward_name == 'minus_x' or \
                _map_reward_name.endswith(suffix), \
                f"map_reward = {_map_reward_name} is not supported (use " + \
                "'minus_x' or '*_minus_x')"
            if _map_reward_name == 'minus_x':
                const = 0.0
            else:
                # Parse const from '<const>_minus_x'
                # Example: '1_minus_x' => const = 1
                const_str = _map_reward_name[:-len(suffix)]
                try:
                    const = float(const_str)
                except ValueError as err:
                    raise ValueError(
                        f"map_reward = {_map_reward_name} is not supported: "
                        f"cannot parse '{const_str}' as a number (use "
                        "'minus_x' or '*_minus_x', where * is a number)"
                    ) from err
            _map_reward = map_reward_const_minus_x(const=const)
        else:
            assert isinstance(_map_reward, MapReward), \
                "map_reward must either be string or of MapReward type"
    else:
        assert scheduler_mode == 'min', \
            f"scheduler_mode = {scheduler_mode}, must be in ('max', 'min')"
        _map_reward = kwargs.get('map_reward')
        if _map_reward is not None:
            logger.warning(
                f"Since scheduler_mode == 'min', map_reward = {_map_reward} is ignored")
            _map_reward = None
    result = {
        'hp_ranges': hp_ranges,
        'map_reward': _map_reward,
        'skip_optimization': skip_optimization,
    }
    if is_hyperband:
        epoch_range = (1, kwargs['max_epochs'])
        result['configspace_ext'] = ExtendedConfiguration(
            hp_ranges,
            resource_attr_key=kwargs['resource_attr'],
            resource_attr_range=epoch_range)

    # Create model factory
    if model == 'gp_multitask':
        result.update(_create_gp_standard_model(
            hp_ranges=hp_ranges,
            active_metric=INTERNAL_METRIC_NAME,
            random_seed=random_seed,
            is_hyperband=is_hyperband,
            **_kwargs))
    else:
        # The assertion at the top guarantees is_hyperband here, so
        # 'configspace_ext' has been set above
        result.update(_create_gp_additive_model(
            model=model,
            hp_ranges=hp_ranges,
            active_metric=INTERNAL_METRIC_NAME,
            random_seed=random_seed,
            configspace_ext=result['configspace_ext'],
            **_kwargs))

    return result