ax/modelbridge/transforms/centered_unit_x.py (59 lines of code) (raw):

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.models.types import TConfig
from ax.utils.common.docutils import copy_doc


if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    from ax import modelbridge as modelbridge_module  # noqa: F401  # pragma: no cover


class CenteredUnitX(Transform):
    """Map X to [-1, 1]^d for RangeParameter of type float and not log scale.

    Currently does not support linear constraints, but could in the future be
    adjusted to transform them too, since this is a linear operation.

    Transform is done in-place.
    """

    def __init__(
        self,
        search_space: SearchSpace,
        observation_features: List[ObservationFeatures],
        observation_data: List[ObservationData],
        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
        config: Optional[TConfig] = None,
    ) -> None:
        # Original (lower, upper) bounds, keyed by parameter name, of every
        # parameter this transform rescales: float, non-log-scale
        # RangeParameters of the search space given at construction time.
        self.bounds: Dict[str, Tuple[float, float]] = {}
        for p_name, p in search_space.parameters.items():
            if (
                isinstance(p, RangeParameter)
                and p.parameter_type == ParameterType.FLOAT
                and not p.log_scale
            ):
                self.bounds[p_name] = (p.lower, p.upper)

    @staticmethod
    def _to_centered(value: float, lower: float, upper: float) -> float:
        """Linearly map ``value`` from [lower, upper] onto [-1, 1]."""
        return -1 + 2 * (value - lower) / (upper - lower)

    @staticmethod
    def _from_centered(value: float, lower: float, upper: float) -> float:
        """Inverse of ``_to_centered``: map ``value`` from [-1, 1] onto [lower, upper]."""
        return ((value + 1) / 2) * (upper - lower) + lower

    @copy_doc(Transform.transform_observation_features)
    def transform_observation_features(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationFeatures]:
        for obsf in observation_features:
            for p_name, (lower, upper) in self.bounds.items():
                # Observation features may be partial; only rescale the
                # parameters that are actually present.
                if p_name in obsf.parameters:
                    # pyre: param is declared to have type `float` but is used
                    # pyre-fixme[9]: as type `Optional[typing.Union[bool, float, str]]`.
                    param: float = obsf.parameters[p_name]
                    obsf.parameters[p_name] = self._to_centered(param, lower, upper)
        return observation_features

    @copy_doc(Transform.transform_search_space)
    def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
        # Reject unsupported constraints BEFORE mutating anything, so that a
        # failed call does not leave the caller's search space half-rescaled.
        for c in search_space.parameter_constraints:
            for p_name in c.constraint_dict:
                if p_name in self.bounds:
                    raise ValueError("Does not support parameter constraints")
        for p_name, p in search_space.parameters.items():
            if p_name in self.bounds and isinstance(p, RangeParameter):
                p.update_range(lower=-1.0, upper=1.0)
                if p.target_value is not None:
                    lower, upper = self.bounds[p_name]
                    # The target value lives in the original space, so it is
                    # rescaled with the same mapping as the parameter values.
                    # pyre-fixme[6]: `target_value` is typed as a generic
                    # parameter value, not `float`.
                    new_tval = self._to_centered(p.target_value, lower, upper)
                    p._target_value = new_tval
        return search_space

    @copy_doc(Transform.untransform_observation_features)
    def untransform_observation_features(
        self, observation_features: List[ObservationFeatures]
    ) -> List[ObservationFeatures]:
        for obsf in observation_features:
            for p_name, (lower, upper) in self.bounds.items():
                # Mirror transform_observation_features: skip parameters
                # missing from partial features instead of raising KeyError.
                if p_name in obsf.parameters:
                    # pyre: param is declared to have type `float` but is used as
                    # pyre-fixme[9]: type `Optional[typing.Union[bool, float, str]]`.
                    param: float = obsf.parameters[p_name]
                    obsf.parameters[p_name] = self._from_centered(
                        param, lower, upper
                    )
        return observation_features