def __init__()

in adanet/core/tpu_estimator.py


  def __init__(self,
               head,
               subnetwork_generator,
               max_iteration_steps,
               ensemblers=None,
               ensemble_strategies=None,
               evaluator=None,
               report_materializer=None,
               metric_fn=None,
               force_grow=False,
               replicate_ensemble_in_training=False,
               adanet_loss_decay=.9,
               model_dir=None,
               report_dir=None,
               config=None,
               use_tpu=True,
               eval_on_tpu=True,
               export_to_tpu=True,
               train_batch_size=None,
               eval_batch_size=None,
               predict_batch_size=None,
               embedding_config_spec=None,
               debug=False,
               enable_ensemble_summaries=True,
               enable_subnetwork_summaries=True,
               export_subnetwork_logits=False,
               export_subnetwork_last_layer=True,
               global_step_combiner_fn=tf.math.reduce_mean,
               max_iterations=None,
               replay_config=None,
               add_predict_batch_config=True,
               **kwargs):
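    """Initializes an `adanet.TPUEstimator`.

    Accepts the same arguments as `adanet.Estimator`, plus TPU-specific
    options (`use_tpu`, `eval_on_tpu`, `export_to_tpu`, per-mode batch
    sizes, and `embedding_config_spec`) that are forwarded to the
    underlying `TPUEstimator`.
    """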
    self._use_tpu = use_tpu
    if not self._use_tpu:
      logging.warning(
          "adanet.TPUEstimator is intended to run on TPU. To run on CPU/GPU, "
          "use adanet.Estimator instead.")
    # TPUEstimator modifies config under the hood. We keep track of it here so
    # we can use it from _create_temp_run_config.
    self._original_config = config or tf_compat.v1.estimator.tpu.RunConfig()
    self._eval_on_tpu = eval_on_tpu if self._use_tpu else False
    self._export_to_tpu = export_to_tpu
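    # Batch sizes cascade: eval falls back to the train batch size, and
    # predict falls back to eval, then train, defaulting to 0 when unset.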
    self._train_batch_size = train_batch_size or 0
    self._eval_batch_size = eval_batch_size or train_batch_size or 0
    self._predict_batch_size = (
        predict_batch_size or eval_batch_size or train_batch_size or 0)
    self._embedding_config_spec = embedding_config_spec
    self._add_predict_batch_config = add_predict_batch_config
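    # Serving on TPU is incompatible with TPUEmbedding, so TPU SavedModel
    # export must be disabled whenever an embedding config is provided.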
    if self._embedding_config_spec:
      logging.warning(
          "TPU does not support inference with TPUEmbedding. Forcing "
          "`export_to_tpu=False` so that no TPU SavedModel is exported.")
      self._export_to_tpu = False

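    # Deferred import (note the pylint suppressions): the TPU estimator
    # module is only pulled in when this class is actually constructed.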
    from tensorflow_estimator.python.estimator.tpu import tpu_estimator  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
    super(TPUEstimator, self).__init__(
        head=head,
        subnetwork_generator=subnetwork_generator,
        max_iteration_steps=max_iteration_steps,
        ensemblers=ensemblers,
        ensemble_strategies=ensemble_strategies,
        evaluator=evaluator,
        report_materializer=report_materializer,
        metric_fn=metric_fn,
        force_grow=force_grow,
        replicate_ensemble_in_training=replicate_ensemble_in_training,
        adanet_loss_decay=adanet_loss_decay,
        model_dir=model_dir,
        report_dir=report_dir,
        config=self._original_config,
        use_tpu=self._use_tpu,
        eval_on_tpu=self._eval_on_tpu,
        export_to_tpu=self._export_to_tpu,
        export_saved_model_api_version=(
            tpu_estimator.ExportSavedModelApiVersion.V2),
        train_batch_size=self._train_batch_size,
        eval_batch_size=self._eval_batch_size,
        predict_batch_size=self._predict_batch_size,
        embedding_config_spec=self._embedding_config_spec,
        debug=debug,
        enable_ensemble_summaries=enable_ensemble_summaries,
        enable_subnetwork_summaries=enable_subnetwork_summaries,
        export_subnetwork_logits=export_subnetwork_logits,
        export_subnetwork_last_layer=export_subnetwork_last_layer,
        global_step_combiner_fn=global_step_combiner_fn,
        max_iterations=max_iterations,
        replay_config=replay_config,
        **kwargs)
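
For reference, a minimal construction sketch (not taken from the source);
`SimpleGenerator` is a hypothetical `adanet.subnetwork.Generator`, and the
head, batch size, and RunConfig are illustrative choices:

    import adanet
    import tensorflow as tf

    estimator = adanet.TPUEstimator(
        head=tf.estimator.RegressionHead(),      # task-specific head
        subnetwork_generator=SimpleGenerator(),  # hypothetical generator
        max_iteration_steps=1000,
        use_tpu=True,
        train_batch_size=64,  # eval/predict batch sizes fall back to this
        config=tf.compat.v1.estimator.tpu.RunConfig(model_dir="/tmp/model"))

Because of the fallback chain above, setting only `train_batch_size` covers
evaluation and prediction as well.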