def generate_per_host_v2_enqueue_ops_fn_for_host()

in tensorflow_estimator/python/estimator/tpu/tpu_estimator.py


def generate_per_host_v2_enqueue_ops_fn_for_host(ctx, input_fn,
                                                 inputs_structure_recorder,
                                                 device, host_id,
                                                 invocation_index):
  """Generates infeed enqueue ops for per-host input_fn on a single host."""
  captured_infeed_queue = _CapturedObject()
  dataset_initializer = None

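  # Build the user's input pipeline on the given host device.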
  with tf.compat.v1.device(device):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx,
        input_device=device,
        invocation_index=invocation_index,
        host_id=host_id)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if not is_dataset:
      raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
                      'input pipeline configuration.')

    # Be aware that when num_cores_per_replica > num_cores_per_host,
    # ctx.num_of_replicas_per_host is 0.
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True,
          num_invocations_per_step=max(1, ctx.num_of_replicas_per_host))

    dataset_initializer = inputs.dataset_initializer()

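    # Maps a shard (replica) index on this host to the ordinal of the TPU
    # core that should receive its data.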
    tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

    def device_function_impl(shard_id):
      if ctx.device_assignment is not None:
        # Find the replica_id of the host's logical core 0.
        # The current host_id is guaranteed to contain logical core 0, even
        # when num_cores_per_replica > num_cores_per_host -- the caller makes
        # sure that this host_id receives data (i.e. that input_fn is invoked
        # for it).
        replica_id = ctx.device_assignment.lookup_replicas(
            task_id=host_id, logical_core=0)[shard_id]
        return ctx.tpu_host_placement_function(replica_id=replica_id)
      else:
        return None

  def enqueue_ops_fn():
    """Generates the per_host enqueue ops."""
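    # One get_next() call is made per replica assigned to this host; the
    # resulting tensors are flattened and enqueued into a single infeed queue.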
    control_deps = []
    per_host_sharded_inputs = []
    enqueue_datas_list = []
    # Be aware that when num_cores_per_replica > num_cores_per_host,
    # ctx.num_of_replicas_per_host is 0.
    num_replicas_per_host = max(1, ctx.num_of_replicas_per_host)
    cached_signals = None
    with tf.compat.v1.device(device):
      if not inputs.is_dataset:
        raise TypeError('`input_fn` must return a `Dataset` for this mode.')
      for host in range(num_replicas_per_host):
        # When parallel get_next is allowed, call get_next() before entering
        # the control-dependency block so the per-replica reads can overlap;
        # otherwise serialize the reads with control dependencies to ensure a
        # deterministic ordering.
        if ctx.allow_per_host_v2_parallel_get_next:
          features, labels = inputs.features_and_labels()  # Calls get_next()
        with tf.control_dependencies(control_deps):
          if not ctx.allow_per_host_v2_parallel_get_next:
            features, labels = inputs.features_and_labels()  # Calls get_next()
          signals = inputs.signals()

          # All replicas share replica 0's stopping signal.
          # This avoids inconsistent state among different model replicas.
          if cached_signals:
            signals['stopping'] = cached_signals['stopping']
          else:
            cached_signals = signals

        features, labels, enqueue_data = (
            _tpu_estimator_embedding.split_inputs(ctx, features, labels))
        if len(enqueue_data) != 1:
          raise RuntimeError(('Missing or extra enqueue_data for host {}. '
                              'len(enqueue_data) = {}.').format(
                                 host, len(enqueue_data)))
        enqueue_datas_list.append(enqueue_data[0])

        inputs_structure_recorder.validate_and_record_structure(
            features, labels)
        flattened_inputs = (
            inputs_structure_recorder.flatten_features_and_labels(
                features, labels, signals))
        control_deps.extend(flattened_inputs)
        per_host_sharded_inputs.append(flattened_inputs)

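      # flattened_input_dims is set when input partitioning is configured; in
      # that case each input must be split across the logical cores of a
      # replica via a partitioned infeed queue.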
      if inputs_structure_recorder.flattened_input_dims:
        input_partition_dims = inputs_structure_recorder.flattened_input_dims
        if signals:
          input_partition_dims += [None] * len(signals)
        # pylint: disable=protected-access
        infeed_queue = tpu_feed._PartitionedInfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]),
            host_id=host_id,
            input_partition_dims=input_partition_dims,
            device_assignment=ctx.device_assignment)
        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
            per_host_sharded_inputs)
      else:
        infeed_queue = tpu_feed.InfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]))
        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
            per_host_sharded_inputs,
            tpu_ordinal_function=tpu_ordinal_function_impl,
            placement_function=device_function_impl)

      captured_infeed_queue.capture(infeed_queue)

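    # If TPU embeddings are configured, also enqueue the embedding lookup data
    # split out of features/labels above.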
    if ctx.embedding_config:
      per_host_enqueue_ops.extend(
          ctx.embedding_config.tpu_embedding.generate_enqueue_ops(
              enqueue_datas_list))

    if signals is None:
      return per_host_enqueue_ops
    else:
      return {
          'ops': per_host_enqueue_ops,
          'signals': signals,
      }

  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer
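
A minimal sketch, not taken from the library, of how the three returned values
might be consumed in a TF1-style graph. The concrete `ctx`, `input_fn`,
`inputs_structure_recorder`, and device string below are placeholders for the
objects that TPUEstimator builds internally; note that in PREDICT mode
`enqueue_ops_fn()` returns a dict of ops and stopping signals rather than a
plain list of ops.

enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
    generate_per_host_v2_enqueue_ops_fn_for_host(
        ctx, input_fn, inputs_structure_recorder,
        device='/job:worker/task:0/device:CPU:0',  # example host device
        host_id=0, invocation_index=0))

enqueue_ops = enqueue_ops_fn()              # Builds the per-replica enqueue ops.
infeed_queue = captured_infeed_queue.get()  # Available once enqueue_ops_fn ran.

with tf.compat.v1.Session() as sess:
  sess.run(dataset_initializer)  # Initialize the host's tf.data iterator.
  sess.run(enqueue_ops)          # Push one step of input data into the infeed.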