in plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py [0:0]
def _categorical_distribution_from_sweep(override: Override) -> Any:
    """Build a ``CategoricalDistribution`` from every value a sweep yields.

    Args:
        override: A sweep override; its values are enumerated through
            ``override.sweep_iterator(transformer=Transformer.encode)``.

    Returns:
        A ``CategoricalDistribution`` over the enumerated values, kept in
        iteration order.

    Raises:
        AssertionError: If an enumerated value is not a ``str``, ``int``,
            ``float``, ``bool`` or ``None``.
    """
    choices: List[CategoricalChoiceType] = []
    for x in override.sweep_iterator(transformer=Transformer.encode):
        assert isinstance(
            x, (str, int, float, bool, type(None))
        ), f"A choice sweep expects str, int, float, bool, or None type. Got {type(x)}."
        choices.append(x)
    return CategoricalDistribution(choices)


def create_optuna_distribution_from_override(override: Override) -> Any:
    """Translate a parsed override into an Optuna distribution.

    Mapping implemented here:

    * not a sweep override -> the result of
      ``override.get_value_element_as_str()`` (no distribution is built);
    * choice sweep -> ``CategoricalDistribution`` over the choices;
    * range sweep with ``shuffle`` set -> ``CategoricalDistribution`` over
      the enumerated range values;
    * range sweep where ``start``, ``stop`` or ``step`` is a ``float`` ->
      ``DiscreteUniformDistribution(start, stop, step)``;
    * any other range sweep -> ``IntUniformDistribution(start, stop, step)``;
    * interval sweep -> ``IntLogUniformDistribution`` /
      ``LogUniformDistribution`` when tagged ``"log"``, otherwise
      ``IntUniformDistribution`` / ``UniformDistribution``; the integer
      variant is chosen only when both bounds are ``int``.

    Args:
        override: The override to convert.

    Returns:
        An Optuna distribution for sweep overrides, or the value returned by
        ``override.get_value_element_as_str()`` for non-sweep overrides.

    Raises:
        NotImplementedError: If the override is a sweep that is neither a
            choice, range nor interval sweep.
        AssertionError: If the sweep value has an unexpected type, a required
            bound is ``None``, or a categorical value is of an unsupported
            type.
    """
    if not override.is_sweep_override():
        return override.get_value_element_as_str()

    value = override.value()

    if override.is_choice_sweep():
        assert isinstance(value, ChoiceSweep)
        return _categorical_distribution_from_sweep(override)

    if override.is_range_sweep():
        assert isinstance(value, RangeSweep)
        assert value.start is not None
        assert value.stop is not None
        if value.shuffle:
            return _categorical_distribution_from_sweep(override)
        if any(isinstance(v, float) for v in (value.start, value.stop, value.step)):
            # Casting a float range to int would silently truncate the bounds
            # and can turn a fractional step into 0, so keep the float grid.
            # Imported locally so this block does not rely on the module's
            # import list already containing the name.
            from optuna.distributions import DiscreteUniformDistribution

            return DiscreteUniformDistribution(value.start, value.stop, value.step)
        return IntUniformDistribution(
            int(value.start), int(value.stop), step=int(value.step)
        )

    if override.is_interval_sweep():
        assert isinstance(value, IntervalSweep)
        assert value.start is not None
        assert value.end is not None
        # Integer distributions are used only when *both* bounds are ints.
        int_bounds = isinstance(value.start, int) and isinstance(value.end, int)
        if "log" in value.tags:
            if int_bounds:
                return IntLogUniformDistribution(int(value.start), int(value.end))
            return LogUniformDistribution(value.start, value.end)
        if int_bounds:
            return IntUniformDistribution(value.start, value.end)
        return UniformDistribution(value.start, value.end)

    raise NotImplementedError(f"{override} is not supported by Optuna sweeper.")