def validate_search_space()

in nni/common/hpo_utils/validation.py [0:0]


def validate_search_space(
        search_space: Any,
        support_types: Optional[List[str]] = None,
        raise_exception: bool = False  # for now, in case false positive
    ) -> bool:
    """
    Check that ``search_space`` is a well-formed ``{name: {"_type": ..., "_value": [...]}}`` dict.

    Rules enforced per parameter:

    * ``"_type"`` must be a string listed in ``support_types``, and ``"_value"`` must be a list.
    * ``choice``: at least one candidate. Each candidate is a number or a string, or a nested
      search-space dict carrying a ``"_name"`` key.
    * Types starting with ``q``: exactly 3 values, and the last one (the step ``q``) must be > 0.
      All other non-choice types take exactly 2 values.
    * ``randint`` values must be ints. Other numeric types accept ints or floats.
    * Types without ``normal`` in the name: ``low < high``. Types with ``log`` also need ``low > 0``.
    * Types with ``normal`` in the name: ``sigma`` (the second value) must be > 0.

    Parameters
    ----------
    search_space
        The search space to validate.
    support_types
        Allowed ``"_type"`` values. Defaults to ``common_search_space_types`` when ``None``.
    raise_exception
        If true, raise ``ValueError`` on the first problem found. If false (default), log
        the problem with this module's logger at ERROR level and return ``False``.

    Returns
    -------
    bool
        ``True`` if valid. ``False`` if invalid and ``raise_exception`` is false.

    Raises
    ------
    ValueError
        If the search space is invalid and ``raise_exception`` is true.
    """
    if not raise_exception:
        # Re-run in raising mode and convert the first error into a log line.
        try:
            validate_search_space(search_space, support_types, True)
            return True
        except ValueError as e:
            logging.getLogger(__name__).error(e.args[0])
        return False

    if support_types is None:
        support_types = common_search_space_types

    if not isinstance(search_space, dict):
        raise ValueError(f'search space is a {type(search_space).__name__}, expect a dict : {repr(search_space)}')

    for name, spec in search_space.items():
        if not isinstance(spec, dict):
            raise ValueError(f'search space "{name}" is a {type(spec).__name__}, expect a dict : {repr(spec)}')
        if '_type' not in spec or '_value' not in spec:
            raise ValueError(f'search space "{name}" does not have "_type" or "_value" : {spec}')
        type_ = spec['_type']
        # Check for str first: an unhashable _type would otherwise make `in` raise TypeError
        # (when support_types is a set), and a non-str would break `.startswith` below.
        # Either would escape the ValueError-only handler above.
        if not isinstance(type_, str) or type_ not in support_types:
            raise ValueError(f'search space "{name}" has unsupported type "{type_}" : {spec}')
        args = spec['_value']
        if not isinstance(args, list):
            raise ValueError(f'search space "{name}"\'s value is not a list : {spec}')

        if type_ == 'choice':
            if not args:
                raise ValueError(f'search space "{name}" (choice) must have at least one candidate : {spec}')
            # FIXME: need further check for each algorithm which types are actually supported
            # for now validation only prints warning so it doesn't harm
            for arg in args:
                is_primitive = isinstance(arg, (float, int, str))
                # nested search space: a candidate is a dict identified by its "_name" key
                is_nested = isinstance(arg, dict) and '_name' in arg
                if not (is_primitive or is_nested):
                    raise ValueError(f'search space "{name}" (choice) should only contain numbers or strings : {spec}')
            continue

        is_quantized = type_.startswith('q')
        if is_quantized:
            if len(args) != 3:
                raise ValueError(f'search space "{name}" ({type_}) must have 3 values : {spec}')
        else:
            if len(args) != 2:
                raise ValueError(f'search space "{name}" ({type_}) must have 2 values : {spec}')

        if type_ == 'randint':
            if not all(isinstance(arg, int) for arg in args):
                raise ValueError(f'search space "{name}" ({type_}) must have int values : {spec}')
        else:
            if not all(isinstance(arg, (float, int)) for arg in args):
                raise ValueError(f'search space "{name}" ({type_}) must have float values : {spec}')

        # Bounds are tested as `not (a < b)` / `not (a > 0)` rather than `a >= b` / `a <= 0`.
        # Comparisons with NaN are always False, so this form rejects NaN instead of
        # silently accepting it.
        if is_quantized and not (args[2] > 0):
            # the third value is the quantization step; it must be strictly positive
            raise ValueError(f'search space "{name}" ({type_}) must have q > 0 : {spec}')

        if 'normal' not in type_:
            if not (args[0] < args[1]):
                raise ValueError(f'search space "{name}" ({type_}) must have high > low : {spec}')
            if 'log' in type_ and not (args[0] > 0):
                raise ValueError(f'search space "{name}" ({type_}) must have low > 0 : {spec}')
        else:
            if not (args[1] > 0):
                raise ValueError(f'search space "{name}" ({type_}) must have sigma > 0 : {spec}')

    return True