in tensorflow_estimator/python/estimator/tpu/tpu_estimator.py [0:0]
def __init__(self,
             model_fn=None,
             model_dir=None,
             config=None,
             params=None,
             use_tpu=True,
             train_batch_size=None,
             eval_batch_size=None,
             predict_batch_size=None,
             batch_axis=None,
             eval_on_tpu=True,
             export_to_tpu=True,
             export_to_cpu=True,
             warm_start_from=None,
             embedding_config_spec=None,
             export_saved_model_api_version=ExportSavedModelApiVersion.V1):
  """Constructs a `TPUEstimator` instance.

  Args:
    model_fn: Model function as required by `Estimator` which returns
      EstimatorSpec or TPUEstimatorSpec. `training_hooks`, 'evaluation_hooks',
      and `prediction_hooks` must not capture any TPU Tensor inside the
      model_fn.
    model_dir: Directory to save model parameters, graph and etc. This can
      also be used to load checkpoints from the directory into an estimator
      to continue training a previously saved model. If `None`, the model_dir
      in `config` will be used if set. If both are set, they must be same. If
      both are `None`, a temporary directory will be used.
    config: A `tpu_config.RunConfig` configuration object. Cannot be `None`.
    params: An optional `dict` of hyper parameters that will be passed into
      `input_fn` and `model_fn`. Keys are names of parameters, values are
      basic python types. There are reserved keys for `TPUEstimator`,
      including 'batch_size'.
    use_tpu: A bool indicating whether TPU support is enabled. Currently,
      TPU training and evaluation respect this bit, but eval_on_tpu can
      override execution of eval. See below.
    train_batch_size: An int representing the global training batch size.
      TPUEstimator transforms this global batch size to a per-shard batch
      size, as params['batch_size'], when calling `input_fn` and `model_fn`.
      Cannot be `None` if `use_tpu` is `True`. Must be divisible by total
      number of replicas.
    eval_batch_size: An int representing evaluation batch size. Must be
      divisible by total number of replicas.
    predict_batch_size: An int representing the prediction batch size. Must
      be divisible by total number of replicas.
    batch_axis: A python tuple of int values describing how each tensor
      produced by the Estimator `input_fn` should be split across the TPU
      compute shards. For example, if your input_fn produced (images, labels)
      where the images tensor is in `HWCN` format, your shard dimensions would
      be [3, 0], where 3 corresponds to the `N` dimension of your images
      Tensor, and 0 corresponds to the dimension along which to split the
      labels to match up with the corresponding images. If None is supplied,
      and per_host_input_for_training is True, batches will be sharded based
      on the major dimension. If tpu_config.per_host_input_for_training is
      False or `PER_HOST_V2`, batch_axis is ignored.
    eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the
      model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`.
    export_to_tpu: If True, `export_saved_model()` exports a metagraph for
      serving on TPU. Note that unsupported export modes such as EVAL will be
      ignored. For those modes, only a CPU model will be exported. Currently,
      export_to_tpu only supports PREDICT.
    export_to_cpu: If True, `export_saved_model()` exports a metagraph for
      serving on CPU.
    warm_start_from: Optional string filepath to a checkpoint or SavedModel to
      warm-start from, or a `tf.estimator.WarmStartSettings` object to fully
      configure warm-starting. If the string filepath is provided instead of
      a `WarmStartSettings`, then all variables are warm-started, and it is
      assumed that vocabularies and Tensor names are unchanged.
    embedding_config_spec: Optional EmbeddingConfigSpec instance to support
      using TPU embedding.
    export_saved_model_api_version: An `ExportSavedModelApiVersion` value, or
      the integer 1 or 2. 1 corresponds to V1, 2 corresponds to V2. (Defaults
      to V1). With V1, `export_saved_model()` adds rewrite() and
      TPUPartitionedCallOp() for user; while in v2, user is expected to add
      rewrite(), TPUPartitionedCallOp() etc in their model_fn.

  Raises:
    ValueError: If `config` is not a `tpu_config.RunConfig`; if `params`
      already contains reserved keys; if `use_tpu` is True and
      `train_batch_size` is `None`; if the configured
      `per_host_input_for_training` mode is rejected by the model-parallelism
      or TPU-embedding checks; if `embedding_config_spec` uses the 'mod'
      partition strategy while `use_tpu` and `eval_on_tpu` are not both True;
      or if `export_saved_model_api_version` is not an accepted value.
    AssertionError: If the parsed `iterations_per_loop` carries a unit other
      than 'seconds' or 'count' (indicates a parsing bug).
  """
  # `isinstance` is False for `None`, so this also rejects a missing config.
  if not isinstance(config, tpu_config.RunConfig):
    raise ValueError(
        '`config` must be provided with type `tpu_config.RunConfig`')

  if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
    raise ValueError('{} are reserved keys but existed in params {}.'.format(
        _RESERVED_PARAMS_KEYS, params))

  if use_tpu:
    # Perform some very basic validations. More validations will be found in
    # _InternalTPUContext.
    # NOTE(review): the eval/predict batch-size checks and the embedding
    # input-mode check below are treated as part of this `use_tpu` branch --
    # confirm the nesting against the canonical file.
    if train_batch_size is None:
      raise ValueError('`train_batch_size` cannot be `None`')
    util_lib.check_positive_integer(train_batch_size, 'train_batch_size')

    if (config.tpu_config.per_host_input_for_training is
        tpu_config.InputPipelineConfig.PER_SHARD_V1 and
        config.tpu_config.num_cores_per_replica):
      raise ValueError(
          'Model parallelism only supports per host input for training. '
          'Please adjust TPURunconfig.per_host_input_for_training.')

    if eval_batch_size is not None:
      util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size')

    if predict_batch_size is not None:
      util_lib.check_positive_integer(predict_batch_size,
                                      'predict_batch_size')

    if embedding_config_spec:
      if (config.tpu_config.per_host_input_for_training not in (
          tpu_config.InputPipelineConfig.PER_HOST_V1,
          tpu_config.InputPipelineConfig.PER_HOST_V2)):
        raise ValueError('Only PER_HOST_V1 and PER_HOST_V2 is supported when '
                         'using TPU Embedding; got {}.'.format(
                             config.tpu_config.per_host_input_for_training))
      # Record whether the embedding spec was built from feature columns.
      self._embedding_from_feature_columns = (
          embedding_config_spec.feature_columns is not None)

  if (not (use_tpu and eval_on_tpu) and embedding_config_spec and
      embedding_config_spec.partition_strategy == 'mod'):
    raise ValueError('Mod sharding of embedding tables not supported on '
                     'CPU.')

  _tpu_estimator_gauge.get_cell().set(True)

  # Verifies the model_fn signature according to Estimator framework.
  estimator_lib._verify_model_fn_args(model_fn, params)  # pylint: disable=protected-access

  # We cannot store config and params in this constructor as parent
  # constructor might change them, such as assigning a temp dir for
  # config.model_dir.
  model_function = self._augment_model_fn(model_fn, batch_axis)

  # Overwrite log_step_count_steps to disable TensorLoggingHook and
  # StepCounterHook from being created in Estimator. TPUEstimator already
  # added equivalent hooks in _augment_model_fn above.
  self._log_every_n_steps = config.log_step_count_steps
  config = config.replace(log_step_count_steps=None)

  # Passing non-None params as wrapped model_fn has it.
  params = params or {}
  super(TPUEstimator, self).__init__(
      model_fn=model_function,
      model_dir=model_dir,
      config=config,
      params=params,
      warm_start_from=warm_start_from)

  # Read from `self._config` (set by the parent constructor), not the local
  # `config`, since the parent may have modified it.
  self._iterations_per_training_loop = util_lib.parse_iterations_per_loop(
      self._config.tpu_config.iterations_per_loop)
  loop_spec = self._iterations_per_training_loop

  # In absence of an explicit `log_every_n_secs` config, if the
  # `iterations_per_loop` value is specified as time in seconds, enable
  # logging every n secs based on the `iterations_per_loop` value. A trade-off
  # avoiding API change on the current release.
  # TODO(henrytan): add `log_every_n_secs` to RunConfig.
  if loop_spec.unit == 'seconds':
    self._log_every_n_secs = loop_spec.value
    self._log_every_n_steps = None
  elif loop_spec.unit == 'count':
    if self._log_every_n_steps is not None:
      # Each session.run() lasts for iterations_per_loop. We can't log
      # in-between a session.run(), and we can only log after the
      # `iterations_per_loop` steps, so we can only approximate. If a user
      # requests to log every N steps, we actually want to roughly log every
      # N / `iterations_per_loop` steps to match the original intention.
      self._log_every_n_steps = int(
          math.ceil(float(self._log_every_n_steps) / loop_spec.value))
    self._log_every_n_secs = None
  else:
    # Raised explicitly rather than via `assert False`: an assert statement
    # is stripped under `python -O`, which would leave `_log_every_n_secs`
    # unset and hide the parsing bug.
    raise AssertionError('Invalid TPUConfig `iterations_per_loop` value. '
                         'Indicates a bug in `iterations_per_loop` '
                         'parsing.')

  # All properties passed to _InternalTPUContext are immutable.
  # pylint: disable=protected-access
  self._ctx = tpu_context._get_tpu_context(self._config, train_batch_size,
                                           eval_batch_size,
                                           predict_batch_size, use_tpu,
                                           eval_on_tpu, embedding_config_spec)

  self._export_to_cpu = export_to_cpu
  self._export_to_tpu = export_to_tpu

  # Accept either an enum member or the plain integers 1 / 2.
  if not (isinstance(export_saved_model_api_version,
                     ExportSavedModelApiVersion)
          or export_saved_model_api_version == 1
          or export_saved_model_api_version == 2):
    raise ValueError('export_saved_model_api_version should be 1 or 2; '
                     'got {}.'.format(
                         export_saved_model_api_version))
  self._export_saved_model_api_version = export_saved_model_api_version
  self._is_input_fn_invoked = None

  self._rendezvous = {}