in adanet/core/estimator.py
def __init__(self,
head,
subnetwork_generator,
max_iteration_steps,
ensemblers=None,
ensemble_strategies=None,
evaluator=None,
report_materializer=None,
metric_fn=None,
force_grow=False,
replicate_ensemble_in_training=False,
adanet_loss_decay=.9,
delay_secs_per_worker=5,
max_worker_delay_secs=60,
worker_wait_secs=5,
worker_wait_timeout_secs=7200,
model_dir=None,
report_dir=None,
config=None,
debug=False,
enable_ensemble_summaries=True,
enable_subnetwork_summaries=True,
global_step_combiner_fn=tf.math.reduce_mean,
max_iterations=None,
export_subnetwork_logits=False,
export_subnetwork_last_layer=True,
replay_config=None,
**kwargs):
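"""Initializes an `Estimator` that learns an AdaNet ensemble across iterations."""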
if subnetwork_generator is None:
raise ValueError("subnetwork_generator can't be None.")
if max_iteration_steps is not None and max_iteration_steps <= 0.:
raise ValueError("max_iteration_steps must be > 0 or None.")
if max_iterations is not None and max_iterations <= 0.:
raise ValueError("max_iterations must be > 0 or None.")
is_distributed_training = config and config.num_worker_replicas > 1
is_model_dir_specified = model_dir or (config and config.model_dir)
if is_distributed_training and not is_model_dir_specified:
# A common model dir for the chief and workers is required for
# coordination during distributed training.
raise ValueError(
"For distributed training, a model_dir must be specified.")
self._subnetwork_generator = subnetwork_generator
# Disable the superclass's check that members are not overridden, so that
# public methods can be overridden here. Note that this is not explicitly
# supported by the Estimator API and may break in the future.
tf.estimator.Estimator._assert_members_are_not_overridden = staticmethod( # pylint: disable=protected-access
lambda _: None)
self._enable_v2_checkpoint = kwargs.pop("enable_v2_checkpoint", False)
self._evaluator = evaluator
self._report_materializer = report_materializer
self._force_grow = force_grow
self._delay_secs_per_worker = delay_secs_per_worker
self._max_worker_delay_secs = max_worker_delay_secs
self._worker_wait_secs = worker_wait_secs
self._worker_wait_timeout_secs = worker_wait_timeout_secs
self._max_iterations = max_iterations
self._replay_config = replay_config
# For backwards compatibility: collect deprecated arguments that now belong
# to the default `adanet.ensemble.ComplexityRegularizedEnsembler`.
default_ensembler_args = [
"mixture_weight_type", "mixture_weight_initializer",
"warm_start_mixture_weights", "adanet_lambda", "adanet_beta", "use_bias"
]
default_ensembler_kwargs = {
k: v for k, v in kwargs.items() if k in default_ensembler_args
}
if default_ensembler_kwargs:
logging.warning(
"The following arguments have been moved to "
"`adanet.ensemble.ComplexityRegularizedEnsembler` which can be "
"specified in the `ensemblers` argument: %s",
sorted(default_ensembler_kwargs.keys()))
for key in default_ensembler_kwargs:
del kwargs[key]
# Experimental feature.
placement_strategy_arg = "experimental_placement_strategy"
placement_strategy = kwargs.pop(placement_strategy_arg, None)
if placement_strategy:
logging.warning(
"%s is an experimental feature. Its behavior is not guaranteed "
"to be backwards compatible.", placement_strategy_arg)
self._warm_start_settings = kwargs.get("warm_start_from")
# Monkey patch the default variable placement strategy that Estimator uses
# since it does not support workers having different graphs from the chief.
# TODO: Consider using `RunConfig.replace` with the new device_fn,
# but this can cause issues since RunConfig automatically parses the
# TF_CONFIG environment variable.
with monkey_patch_default_variable_placement_strategy():
# This `Estimator` is responsible for bookkeeping across iterations, and
# for training the subnetworks in both a local and distributed setting.
# Subclassing improves future-proofing against new private methods being
# added to `tf.estimator.Estimator` that are expected to be callable by
# external functions, such as in b/110435640.
super(Estimator, self).__init__(
model_fn=self._create_model_fn(),
params={},
config=config,
model_dir=model_dir,
**kwargs)
if default_ensembler_kwargs and ensemblers:
raise ValueError("When specifying the `ensemblers` argument, "
"the following arguments must not be given: {}".format(
sorted(default_ensembler_kwargs.keys())))
if not ensemblers:
default_ensembler_kwargs["model_dir"] = self.model_dir
ensemblers = [
ensemble_lib.ComplexityRegularizedEnsembler(
**default_ensembler_kwargs)
]
# These are defined after base Estimator's init so that they can
# use the same temporary model_dir as the underlying Estimator even if
# model_dir is not provided.
self._use_tpu = kwargs.get("use_tpu", False)
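# Builds candidate ensembles, along with their training ops and metrics, from
# subnetworks using the given head.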
ensemble_builder = _EnsembleBuilder(
head=head,
metric_fn=metric_fn,
use_tpu=self._use_tpu,
export_subnetwork_logits=export_subnetwork_logits,
export_subnetwork_last_layer=export_subnetwork_last_layer)
# TODO: Merge CandidateBuilder into SubnetworkManager.
candidate_builder = _CandidateBuilder(adanet_loss_decay=adanet_loss_decay)
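# Builds and tracks the training of individual candidate subnetworks.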
subnetwork_manager = _SubnetworkManager(
head=head, metric_fn=metric_fn, use_tpu=self._use_tpu)
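# Without an experimental placement strategy, replicate the full AdaNet graph
# on every worker.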
if not placement_strategy:
placement_strategy = distributed_lib.ReplicationStrategy()
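# Assembles the graph for each AdaNet iteration: candidate subnetworks,
# ensembles, and summaries for up to `max_iteration_steps` train steps.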
self._iteration_builder = _IterationBuilder(
candidate_builder,
subnetwork_manager,
ensemble_builder,
ensemblers,
max_iteration_steps,
self._summary_maker,
global_step_combiner_fn,
placement_strategy,
replicate_ensemble_in_training,
use_tpu=self._use_tpu,
debug=debug,
enable_ensemble_summaries=enable_ensemble_summaries,
enable_subnetwork_summaries=enable_subnetwork_summaries,
enable_subnetwork_reports=self._report_materializer is not None)
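# Default to greedily growing the previous ensemble one subnetwork at a time.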
self._ensemble_strategies = ensemble_strategies or [
ensemble_lib.GrowStrategy()
]
report_dir = report_dir or os.path.join(self._model_dir, "report")
self._report_accessor = _ReportAccessor(report_dir)
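# Illustrative usage (a sketch only; `SimpleGenerator`, `train_input_fn`, and
# the hyperparameter values below are placeholders for user-provided
# implementations, not part of this module):
#
#   estimator = Estimator(
#       head=tf.estimator.BinaryClassHead(),
#       subnetwork_generator=SimpleGenerator(),
#       max_iteration_steps=1000,
#       max_iterations=3,
#       model_dir="/tmp/adanet_model")
#   estimator.train(input_fn=train_input_fn, max_steps=3000)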