in src/python/tensorflow_cloud/tuner/tuner.py [0:0]
def __init__(
    self,
    hypermodel: Union[hypermodel_module.HyperModel,
                      Callable[[hp_module.HyperParameters],
                               tf.keras.Model]],
    project_id: Text,
    region: Text,
    directory: Text,
    objective: Optional[Union[Text, oracle_module.Objective]] = None,
    hyperparameters: Optional[hp_module.HyperParameters] = None,
    study_config: Optional[Dict[Text, Any]] = None,
    max_trials: Optional[int] = None,
    study_id: Optional[Text] = None,
    container_uri: Optional[Text] = None,
    replica_config: Optional[machine_config.MachineConfig] = None,
    replica_count: Optional[int] = 1,
    **kwargs):
    """Constructor.

    Args:
        hypermodel: Instance of HyperModel class (or callable that takes
            hyperparameters and returns a Model instance).
        project_id: A GCP project id.
        region: A GCP region. e.g. 'us-central1'.
        directory: The Google Cloud Storage path for logs and checkpoints.
            Must start with "gs://".
        objective: Name of model metric to minimize or maximize, e.g.
            "val_accuracy".
        hyperparameters: Can be used to override (or register in advance)
            hyperparameters in the search space.
        study_config: Study configuration for Vizier service.
        max_trials: Total number of trials (model configurations) to test at
            most. Note that the oracle may interrupt the search before
            `max_trials` models have been tested if the search space has
            been exhausted.
        study_id: An identifier of the study. The full study name will be
            projects/{project_id}/locations/{region}/studies/{study_id}.
            If not provided, the id generated by the underlying
            `CloudOracle` is used.
        container_uri: Base image to use for AI Platform Training. This
            image must follow cloud_fit image with a cloud_fit.remote() as
            entry point. Refer to cloud_fit documentation for more details
            at tensorflow_cloud/tuner/cloud_fit_readme.md.
        replica_config: Optional `MachineConfig` that represents the
            configuration for the general workers in a distribution cluster.
            Defaults to None, which is mapped to a standard CPU config
            `tensorflow_cloud.core.COMMON_MACHINE_CONFIGS.CPU`.
        replica_count: Optional integer that represents the total number of
            workers in a distribution cluster including a chief worker. Has
            to be one or more.
        **kwargs: Keyword arguments relevant to all `Tuner` subclasses.
            Please see the docstring for `Tuner`.

    Raises:
        ValueError: If directory is not a valid Google Cloud Storage path.
    """
    # Fail fast: reject an invalid directory before any instance state is
    # assigned or the oracle is constructed.
    if not directory.startswith("gs://"):
        raise ValueError(
            "Directory must be a valid Google Cloud Storage path.")

    self._project_id = project_id
    self._region = region

    # Replica count and config are validated when the job_spec is created.
    # The job_spec changes for each trial, hence it cannot be defined here.
    self._replica_count = replica_count
    # A falsy/absent replica_config falls back to the standard CPU config.
    self._replica_config = (
        replica_config or machine_config.COMMON_MACHINE_CONFIGS["CPU"])

    # Setting AI Platform Training runtime configurations. User can create
    # a new tuner using the same study id if they need to change any of the
    # parameters below, however since this is not a common use case, we are
    # adding them to the constructor instead of search parameters.
    self._container_uri = container_uri

    oracle = CloudOracle(
        project_id=project_id,
        region=region,
        objective=objective,
        hyperparameters=hyperparameters,
        study_config=study_config,
        max_trials=max_trials,
        study_id=study_id,
    )
    super(DistributingCloudTuner, self).__init__(
        oracle=oracle, hypermodel=hypermodel, **kwargs
    )

    # If study_id is not provided, CloudOracle creates one. Use the id
    # CloudOracle generated so the tuner and oracle refer to the same study.
    self._study_id = study_id if study_id else oracle.study_id

    # Assigned after the base constructor so this GCS path is the value kept
    # on the instance (presumably overriding any default the base `Tuner`
    # sets — confirm against the base class).
    self.directory = directory