def __init__()

in tensorflow_estimator/python/estimator/tpu/tpu_estimator.py


  def __init__(self,
               model_fn=None,
               model_dir=None,
               config=None,
               params=None,
               use_tpu=True,
               train_batch_size=None,
               eval_batch_size=None,
               predict_batch_size=None,
               batch_axis=None,
               eval_on_tpu=True,
               export_to_tpu=True,
               export_to_cpu=True,
               warm_start_from=None,
               embedding_config_spec=None,
               export_saved_model_api_version=ExportSavedModelApiVersion.V1):
    """Constructs an `TPUEstimator` instance.

    Args:
      model_fn: Model function as required by `Estimator` which returns
        EstimatorSpec or TPUEstimatorSpec. `training_hooks`,
        `evaluation_hooks`, and `prediction_hooks` must not capture any TPU
        Tensor inside the model_fn.
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model. If `None`, the
        model_dir in `config` will be used if set. If both are set, they must
        be the same. If both are `None`, a temporary directory will be used.
      config: A `tpu_config.RunConfig` configuration object. Cannot be `None`.
      params: An optional `dict` of hyper parameters that will be passed into
        `input_fn` and `model_fn`.  Keys are names of parameters, values are
        basic python types. There are reserved keys for `TPUEstimator`,
        including 'batch_size'.
      use_tpu: A bool indicating whether TPU support is enabled. Currently,
        TPU training and evaluation respect this bit, but `eval_on_tpu` can
        override execution of eval. See below.
      train_batch_size: An int representing the global training batch size.
        TPUEstimator transforms this global batch size to a per-shard batch
        size, as params['batch_size'], when calling `input_fn` and `model_fn`.
        Cannot be `None` if `use_tpu` is `True`. Must be divisible by the
        total number of replicas.
      eval_batch_size: An int representing the evaluation batch size. Must be
        divisible by the total number of replicas.
      predict_batch_size: An int representing the prediction batch size. Must
        be divisible by the total number of replicas.
      batch_axis: A Python tuple of int values describing how each tensor
        produced by the Estimator `input_fn` should be split across the TPU
        compute shards. For example, if your input_fn produces (images, labels)
        where the images tensor is in `HWCN` format, your shard dimensions would
        be [3, 0], where 3 corresponds to the `N` dimension of your images
        Tensor, and 0 corresponds to the dimension along which to split the
        labels to match up with the corresponding images. If None is supplied,
        and per_host_input_for_training is True, batches will be sharded based
        on the major dimension. If tpu_config.per_host_input_for_training is
        False or `PER_HOST_V2`, batch_axis is ignored.
      eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the
        model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`.
      export_to_tpu: If True, `export_saved_model()` exports a metagraph for
        serving on TPU. Note that unsupported export modes such as EVAL will be
        ignored. For those modes, only a CPU model will be exported. Currently,
        export_to_tpu only supports PREDICT.
      export_to_cpu: If True, `export_saved_model()` exports a metagraph for
        serving on CPU.
      warm_start_from: Optional string filepath to a checkpoint or SavedModel to
        warm-start from, or a `tf.estimator.WarmStartSettings` object to fully
        configure warm-starting.  If the string filepath is provided instead of
        a `WarmStartSettings`, then all variables are warm-started, and it is
        assumed that vocabularies and Tensor names are unchanged.
      embedding_config_spec: Optional EmbeddingConfigSpec instance to support
        using TPU embedding.
      export_saved_model_api_version: An integer (1 or 2) or an
        `ExportSavedModelApiVersion` enum value (`V1` or `V2`); defaults to
        V1. With V1, `export_saved_model()` adds `rewrite()` and
        `TPUPartitionedCallOp()` for the user; with V2, the user is expected
        to add `rewrite()`, `TPUPartitionedCallOp()`, etc. in their model_fn.
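
    Example:

    A minimal construction sketch; `my_model_fn`, `tpu_grpc_url`, and the
    literal values below are illustrative placeholders, not defaults.

    ```
    run_config = tpu_config.RunConfig(
        master=tpu_grpc_url,
        model_dir='/tmp/model',
        tpu_config=tpu_config.TPUConfig(iterations_per_loop=100))
    estimator = TPUEstimator(
        model_fn=my_model_fn,
        config=run_config,
        use_tpu=True,
        train_batch_size=1024,  # global batch size, split across replicas
        eval_batch_size=1024)
    ```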

    Raises:
      ValueError: `params` contains reserved keys.
    """
    if config is None or not isinstance(config, tpu_config.RunConfig):
      raise ValueError(
          '`config` must be provided with type `tpu_config.RunConfig`')

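    # For example, constructing a TPUEstimator with params={'batch_size': 64}
    # fails the check below: 'batch_size' is reserved because TPUEstimator
    # injects the per-shard batch size itself.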
    if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
      raise ValueError('{} are reserved keys but exist in params {}.'.format(
          _RESERVED_PARAMS_KEYS, params))

    if use_tpu:
      # Perform some very basic validations. More validations will be found in
      # _InternalTPUContext.
      if train_batch_size is None:
        raise ValueError('`train_batch_size` cannot be `None`')
      util_lib.check_positive_integer(train_batch_size, 'train_batch_size')

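      # Model parallelism is requested by setting TPUConfig.num_cores_per_replica;
      # it requires a per-host input pipeline, so PER_SHARD_V1 is rejected.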
      if (config.tpu_config.per_host_input_for_training is
          tpu_config.InputPipelineConfig.PER_SHARD_V1 and
          config.tpu_config.num_cores_per_replica):
        raise ValueError(
            'Model parallelism only supports per-host input for training. '
            'Please adjust TPUConfig.per_host_input_for_training.')

      if eval_batch_size is not None:
        util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size')

      if predict_batch_size is not None:
        util_lib.check_positive_integer(predict_batch_size,
                                        'predict_batch_size')

      if embedding_config_spec:
        if (config.tpu_config.per_host_input_for_training not in (
            tpu_config.InputPipelineConfig.PER_HOST_V1,
            tpu_config.InputPipelineConfig.PER_HOST_V2)):
          raise ValueError('Only PER_HOST_V1 and PER_HOST_V2 are supported '
                           'when using TPU Embedding; got {}.'.format(
                               config.tpu_config.per_host_input_for_training))
        self._embedding_from_feature_columns = (
            embedding_config_spec.feature_columns is not None)

    if (not (use_tpu and eval_on_tpu) and embedding_config_spec and
        embedding_config_spec.partition_strategy == 'mod'):
      raise ValueError('Mod sharding of embedding tables not supported on '
                       'CPU.')
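    # Record TPUEstimator usage in a monitoring gauge (telemetry only; no
    # effect on training behavior).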
    _tpu_estimator_gauge.get_cell().set(True)
    # Verifies the model_fn signature according to Estimator framework.
    estimator_lib._verify_model_fn_args(model_fn, params)  # pylint: disable=protected-access
    # We cannot store config and params in this constructor as parent
    # constructor might change them, such as assigning a temp dir for
    # config.model_dir.
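    # _augment_model_fn wraps the user model_fn so that train/eval run inside
    # the TPU loop and equivalent logging hooks are attached (see the comment
    # below).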
    model_function = self._augment_model_fn(model_fn, batch_axis)

    # Overwrite log_step_count_steps to prevent LoggingTensorHook and
    # StepCounterHook from being created in Estimator. TPUEstimator already
    # added equivalent hooks in _augment_model_fn above.
    self._log_every_n_steps = config.log_step_count_steps
    config = config.replace(log_step_count_steps=None)

    # Pass non-None params, since the wrapped model_fn expects it.
    params = params or {}
    super(TPUEstimator, self).__init__(
        model_fn=model_function,
        model_dir=model_dir,
        config=config,
        params=params,
        warm_start_from=warm_start_from)
    self._iterations_per_training_loop = util_lib.parse_iterations_per_loop(
        self._config.tpu_config.iterations_per_loop)
    # In the absence of an explicit `log_every_n_secs` config, if the
    # `iterations_per_loop` value is specified as time in seconds, enable
    # logging every n secs based on the `iterations_per_loop` value. This is a
    # trade-off that avoids an API change in the current release.
    # TODO(henrytan): add `log_every_n_secs` to RunConfig.
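    # For example, TPUConfig(iterations_per_loop='300s') parses to
    # unit='seconds', value=300 (assuming the seconds-suffix form is used).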
    if self._iterations_per_training_loop.unit == 'seconds':
      self._log_every_n_secs = self._iterations_per_training_loop.value
      self._log_every_n_steps = None
    elif self._iterations_per_training_loop.unit == 'count':
      if self._log_every_n_steps is not None:
        # Each session.run() lasts for iterations_per_loop. We can't log
        # in-between a session.run(), and we can only log after the
        # `iterations_per_loop` steps, so we can only approximate. If a user
        # requests to log every N steps, we actually want to roughly log every
        # N / `iterations_per_loop` steps to match the original intention.
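        # For example (hypothetical values): log_step_count_steps=100 with
        # iterations_per_loop=32 gives ceil(100 / 32) = 4, i.e. one log line
        # roughly every 4 * 32 = 128 global steps.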
        self._log_every_n_steps = (
            int(
                math.ceil(
                    float(self._log_every_n_steps) /
                    self._iterations_per_training_loop.value)))
      self._log_every_n_secs = None
    else:
      assert False, ('Invalid TPUConfig `iterations_per_loop` value. '
                     'Indicates a bug in `iterations_per_loop` '
                     'parsing.')

    # All properties passed to _InternalTPUContext are immutable.
    # pylint: disable=protected-access
    self._ctx = tpu_context._get_tpu_context(self._config, train_batch_size,
                                             eval_batch_size,
                                             predict_batch_size, use_tpu,
                                             eval_on_tpu, embedding_config_spec)

    self._export_to_cpu = export_to_cpu
    self._export_to_tpu = export_to_tpu

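    # Both the `ExportSavedModelApiVersion` enum and the plain integers 1 and
    # 2 are accepted here.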
    if not (isinstance(export_saved_model_api_version,
                       ExportSavedModelApiVersion)
            or export_saved_model_api_version == 1
            or export_saved_model_api_version == 2):
      raise ValueError('export_saved_model_api_version should be 1 or 2; '
                       'got {}.'.format(
                           export_saved_model_api_version))
    self._export_saved_model_api_version = export_saved_model_api_version
    self._is_input_fn_invoked = None

    self._rendezvous = {}