in adanet/core/tpu_estimator.py [0:0]
def __init__(self,
             head,
             subnetwork_generator,
             max_iteration_steps,
             ensemblers=None,
             ensemble_strategies=None,
             evaluator=None,
             report_materializer=None,
             metric_fn=None,
             force_grow=False,
             replicate_ensemble_in_training=False,
             adanet_loss_decay=.9,
             model_dir=None,
             report_dir=None,
             config=None,
             use_tpu=True,
             eval_on_tpu=True,
             export_to_tpu=True,
             train_batch_size=None,
             eval_batch_size=None,
             predict_batch_size=None,
             embedding_config_spec=None,
             debug=False,
             enable_ensemble_summaries=True,
             enable_subnetwork_summaries=True,
             export_subnetwork_logits=False,
             export_subnetwork_last_layer=True,
             global_step_combiner_fn=tf.math.reduce_mean,
             max_iterations=None,
             replay_config=None,
             add_predict_batch_config=True,
             **kwargs):
  """Normalizes TPU-specific settings, then defers to the parent constructor.

  Apart from the arguments described below, every argument is forwarded
  unchanged to `super(TPUEstimator, self).__init__`. The parent is always
  asked for `tpu_estimator.ExportSavedModelApiVersion.V2` SavedModel export.

  Args:
    config: Run configuration. When falsy, a new
      `tf_compat.v1.estimator.tpu.RunConfig()` is created instead. Whichever
      object is chosen is forwarded to the parent. It is also kept as
      `self._original_config`.
    use_tpu: Whether to target TPU. A falsy value logs a warning that points
      CPU/GPU users at `adanet.Estimator`, and forces `eval_on_tpu` to False.
    eval_on_tpu: Whether evaluation runs on TPU. It is honored only when
      `use_tpu` is truthy; otherwise False is used.
    export_to_tpu: Whether to export a TPU SavedModel. It is forced to False
      (with a warning) when `embedding_config_spec` is truthy.
    train_batch_size: Training batch size. A falsy value becomes 0.
    eval_batch_size: Evaluation batch size. A falsy value falls back to the
      resolved training batch size.
    predict_batch_size: Prediction batch size. A falsy value falls back to
      the resolved evaluation batch size.
    embedding_config_spec: TPU embedding configuration, forwarded as-is.
    add_predict_batch_config: Not forwarded. It is only stored as
      `self._add_predict_batch_config`; presumably other methods of this
      class read it (verify).
    **kwargs: Extra keyword arguments for the parent constructor. They must
      not include `export_saved_model_api_version`, because that keyword is
      always passed explicitly here. Supplying it again raises `TypeError`
      for a duplicate keyword.
  """
  self._use_tpu = use_tpu
  if not use_tpu:
    logging.warning(
        "This adanet.TPUEstimator is meant to be used for running on TPU. "
        "If you want to run on CPU/GPU, use adanet.Estimator instead.")

  # The parent TPUEstimator alters the config it is handed, so remember the
  # one we pass in; `_create_temp_run_config` reads it back later.
  self._original_config = config or tf_compat.v1.estimator.tpu.RunConfig()

  # Evaluating on TPU is only allowed when running on TPU at all.
  if self._use_tpu:
    self._eval_on_tpu = eval_on_tpu
  else:
    self._eval_on_tpu = False

  # Unset batch sizes cascade: predict <- eval <- train <- 0.
  resolved_train = train_batch_size or 0
  resolved_eval = eval_batch_size or resolved_train
  resolved_predict = predict_batch_size or resolved_eval
  self._train_batch_size = resolved_train
  self._eval_batch_size = resolved_eval
  self._predict_batch_size = resolved_predict

  self._embedding_config_spec = embedding_config_spec
  self._add_predict_batch_config = add_predict_batch_config

  # TPU inference does not support TPUEmbedding, so TPU export is switched
  # off whenever an embedding spec is supplied.
  if embedding_config_spec:
    logging.warning(
        "TPU does not support inference with TPUEmbedding. Force setting "
        "`export_to_tpu=False` so no TPU SavedModel will be exported.")
    self._export_to_tpu = False
  else:
    self._export_to_tpu = export_to_tpu

  from tensorflow_estimator.python.estimator.tpu import tpu_estimator  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
  # The parent is always asked for the V2 SavedModel export API.
  saved_model_api = tpu_estimator.ExportSavedModelApiVersion.V2

  super(TPUEstimator, self).__init__(
      head=head,
      subnetwork_generator=subnetwork_generator,
      max_iteration_steps=max_iteration_steps,
      ensemblers=ensemblers,
      ensemble_strategies=ensemble_strategies,
      evaluator=evaluator,
      report_materializer=report_materializer,
      metric_fn=metric_fn,
      force_grow=force_grow,
      replicate_ensemble_in_training=replicate_ensemble_in_training,
      adanet_loss_decay=adanet_loss_decay,
      model_dir=model_dir,
      report_dir=report_dir,
      config=self._original_config,
      use_tpu=self._use_tpu,
      eval_on_tpu=self._eval_on_tpu,
      export_to_tpu=self._export_to_tpu,
      export_saved_model_api_version=saved_model_api,
      train_batch_size=resolved_train,
      eval_batch_size=resolved_eval,
      predict_batch_size=resolved_predict,
      embedding_config_spec=self._embedding_config_spec,
      debug=debug,
      enable_ensemble_summaries=enable_ensemble_summaries,
      enable_subnetwork_summaries=enable_subnetwork_summaries,
      export_subnetwork_logits=export_subnetwork_logits,
      export_subnetwork_last_layer=export_subnetwork_last_layer,
      global_step_combiner_fn=global_step_combiner_fn,
      max_iterations=max_iterations,
      replay_config=replay_config,
      **kwargs)