def __init__()

in adanet/core/estimator.py [0:0]


  def __init__(self,
               head,
               subnetwork_generator,
               max_iteration_steps,
               ensemblers=None,
               ensemble_strategies=None,
               evaluator=None,
               report_materializer=None,
               metric_fn=None,
               force_grow=False,
               replicate_ensemble_in_training=False,
               adanet_loss_decay=.9,
               delay_secs_per_worker=5,
               max_worker_delay_secs=60,
               worker_wait_secs=5,
               worker_wait_timeout_secs=7200,
               model_dir=None,
               report_dir=None,
               config=None,
               debug=False,
               enable_ensemble_summaries=True,
               enable_subnetwork_summaries=True,
               global_step_combiner_fn=tf.math.reduce_mean,
               max_iterations=None,
               export_subnetwork_logits=False,
               export_subnetwork_last_layer=True,
               replay_config=None,
               **kwargs):
    """Validates arguments, builds the base Estimator and iteration machinery.

    All argument validation happens before any side effect (global monkey
    patching, base `Estimator` construction), so a misconfigured call fails
    without touching process-wide state.

    Args:
      head: Forwarded to `_EnsembleBuilder` and `_SubnetworkManager`.
      subnetwork_generator: Stored on the instance. Must not be `None`.
      max_iteration_steps: Must be `> 0` or `None`. Forwarded to
        `_IterationBuilder`.
      ensemblers: Ensemblers forwarded to `_IterationBuilder`. When falsy, a
        single `ensemble_lib.ComplexityRegularizedEnsembler` is built from the
        legacy ensembler kwargs (see `**kwargs`) plus this estimator's
        resolved `model_dir`.
      ensemble_strategies: Stored; when falsy defaults to
        `[ensemble_lib.GrowStrategy()]`.
      evaluator: Stored on the instance.
      report_materializer: Stored on the instance. When not `None`, subnetwork
        reports are enabled in `_IterationBuilder`.
      metric_fn: Forwarded to `_EnsembleBuilder` and `_SubnetworkManager`.
      force_grow: Stored on the instance (not read in this constructor).
      replicate_ensemble_in_training: Forwarded to `_IterationBuilder`.
      adanet_loss_decay: Forwarded to `_CandidateBuilder`.
      delay_secs_per_worker: Stored on the instance (not read here).
      max_worker_delay_secs: Stored on the instance (not read here).
      worker_wait_secs: Stored on the instance (not read here).
      worker_wait_timeout_secs: Stored on the instance (not read here).
      model_dir: Forwarded to the base Estimator. Required (here or through
        `config.model_dir`) when `config.num_worker_replicas > 1`.
      report_dir: Directory handed to `_ReportAccessor`. Defaults to
        `<model_dir>/report`, using the model_dir resolved by the base
        Estimator.
      config: Forwarded to the base Estimator.
      debug: Forwarded to `_IterationBuilder`.
      enable_ensemble_summaries: Forwarded to `_IterationBuilder`.
      enable_subnetwork_summaries: Forwarded to `_IterationBuilder`.
      global_step_combiner_fn: Forwarded to `_IterationBuilder`.
      max_iterations: Must be `> 0` or `None`. Stored on the instance.
      export_subnetwork_logits: Forwarded to `_EnsembleBuilder`.
      export_subnetwork_last_layer: Forwarded to `_EnsembleBuilder`.
      replay_config: Stored on the instance.
      **kwargs: The following keys are consumed here and NOT forwarded:
        `enable_v2_checkpoint` (default `False`),
        `experimental_placement_strategy` (default
        `distributed_lib.ReplicationStrategy()`), and the legacy ensembler
        arguments `mixture_weight_type`, `mixture_weight_initializer`,
        `warm_start_mixture_weights`, `adanet_lambda`, `adanet_beta`,
        `use_bias`. `warm_start_from` and `use_tpu` are read but still
        forwarded. Everything remaining goes to the base Estimator.

    Raises:
      ValueError: If `subnetwork_generator` is `None`; if
        `max_iteration_steps` or `max_iterations` is `<= 0`; if distributed
        training is configured without a model_dir; or if legacy ensembler
        arguments are given together with `ensemblers`.
    """
    if subnetwork_generator is None:
      raise ValueError("subnetwork_generator can't be None.")
    if max_iteration_steps is not None and max_iteration_steps <= 0.:
      raise ValueError("max_iteration_steps must be > 0 or None.")
    if max_iterations is not None and max_iterations <= 0.:
      raise ValueError("max_iterations must be > 0 or None.")
    is_distributed_training = config and config.num_worker_replicas > 1
    is_model_dir_specified = model_dir or (config and config.model_dir)
    if is_distributed_training and not is_model_dir_specified:
      # A common model dir for the chief and workers is required for
      # coordination during distributed training.
      raise ValueError(
          "For distributed training, a model_dir must be specified.")

    # Added for backwards compatibility: these arguments used to configure the
    # default ensembler directly. Pop them so they are never forwarded to the
    # base Estimator, which does not accept them.
    default_ensembler_args = [
        "mixture_weight_type", "mixture_weight_initializer",
        "warm_start_mixture_weights", "adanet_lambda", "adanet_beta", "use_bias"
    ]
    default_ensembler_kwargs = {
        k: kwargs.pop(k) for k in default_ensembler_args if k in kwargs
    }
    if default_ensembler_kwargs:
      logging.warning(
          "The following arguments have been moved to "
          "`adanet.ensemble.ComplexityRegularizedEnsembler` which can be "
          "specified in the `ensemblers` argument: %s",
          sorted(default_ensembler_kwargs.keys()))
      # Fail fast: reject the conflict before the global monkey patch and the
      # base Estimator construction below (which may create a temporary
      # model_dir) have any side effects.
      if ensemblers:
        raise ValueError("When specifying the `ensemblers` argument, "
                         "the following arguments must not be given: {}".format(
                             sorted(default_ensembler_kwargs)))

    self._subnetwork_generator = subnetwork_generator

    # Overwrite superclass's assert that members are not overwritten in order
    # to overwrite public methods. Note that we are doing something that is not
    # explicitly supported by the Estimator API and may break in the future.
    # This assigns a class attribute, so it affects every tf.estimator.Estimator
    # in the process, not just this instance.
    tf.estimator.Estimator._assert_members_are_not_overridden = staticmethod(  # pylint: disable=protected-access
        lambda _: None)

    self._enable_v2_checkpoint = kwargs.pop("enable_v2_checkpoint", False)
    self._evaluator = evaluator
    self._report_materializer = report_materializer

    self._force_grow = force_grow
    self._delay_secs_per_worker = delay_secs_per_worker
    self._max_worker_delay_secs = max_worker_delay_secs
    self._worker_wait_secs = worker_wait_secs
    self._worker_wait_timeout_secs = worker_wait_timeout_secs
    self._max_iterations = max_iterations
    self._replay_config = replay_config

    # Experimental feature.
    placement_strategy_arg = "experimental_placement_strategy"
    placement_strategy = kwargs.pop(placement_strategy_arg, None)
    if placement_strategy:
      logging.warning(
          "%s is an experimental feature. Its behavior is not guaranteed "
          "to be backwards compatible.", placement_strategy_arg)

    # Read (not popped): the base Estimator still receives `warm_start_from`.
    self._warm_start_settings = kwargs.get("warm_start_from")

    # Monkey patch the default variable placement strategy that Estimator uses
    # since it does not support workers having different graphs from the chief.
    # TODO: Consider using `RunConfig.replace` with the new device_fn,
    # but this can cause issues since RunConfig automatically parses TF_CONFIG
    # environment variable.
    with monkey_patch_default_variable_placement_strategy():
      # This `Estimator` is responsible for bookkeeping across iterations, and
      # for training the subnetworks in both a local and distributed setting.
      # Subclassing improves future-proofing against new private methods being
      # added to `tf.estimator.Estimator` that are expected to be callable by
      # external functions, such as in b/110435640.
      super(Estimator, self).__init__(
          model_fn=self._create_model_fn(),
          params={},
          config=config,
          model_dir=model_dir,
          **kwargs)

    # Built after the base Estimator so the default ensembler can use the
    # model_dir resolved by it (possibly a temporary one).
    if not ensemblers:
      default_ensembler_kwargs["model_dir"] = self.model_dir
      ensemblers = [
          ensemble_lib.ComplexityRegularizedEnsembler(
              **default_ensembler_kwargs)
      ]

    # These are defined after base Estimator's init so that they can
    # use the same temporary model_dir as the underlying Estimator even if
    # model_dir is not provided.
    self._use_tpu = kwargs.get("use_tpu", False)
    ensemble_builder = _EnsembleBuilder(
        head=head,
        metric_fn=metric_fn,
        use_tpu=self._use_tpu,
        export_subnetwork_logits=export_subnetwork_logits,
        export_subnetwork_last_layer=export_subnetwork_last_layer)

    # TODO: Merge CandidateBuilder into SubnetworkManager.
    candidate_builder = _CandidateBuilder(adanet_loss_decay=adanet_loss_decay)
    subnetwork_manager = _SubnetworkManager(
        head=head, metric_fn=metric_fn, use_tpu=self._use_tpu)
    if not placement_strategy:
      placement_strategy = distributed_lib.ReplicationStrategy()
    self._iteration_builder = _IterationBuilder(
        candidate_builder,
        subnetwork_manager,
        ensemble_builder,
        ensemblers,
        max_iteration_steps,
        self._summary_maker,
        global_step_combiner_fn,
        placement_strategy,
        replicate_ensemble_in_training,
        use_tpu=self._use_tpu,
        debug=debug,
        enable_ensemble_summaries=enable_ensemble_summaries,
        enable_subnetwork_summaries=enable_subnetwork_summaries,
        enable_subnetwork_reports=self._report_materializer is not None)
    self._ensemble_strategies = ensemble_strategies or [
        ensemble_lib.GrowStrategy()
    ]

    report_dir = report_dir or os.path.join(self._model_dir, "report")
    self._report_accessor = _ReportAccessor(report_dir)