def _invoke_input_fn_and_record_structure()

in tensorflow_estimator/python/estimator/tpu/tpu_estimator.py [0:0]


  def _invoke_input_fn_and_record_structure(self):
    """Deploys the input pipeline and records the input structure.

    Depending on the configuration, the user `input_fn` is invoked once per
    host for per-core sharding, once on host 0 for broadcast input, or once
    per (host, invocation) pair otherwise. Every invocation contributes one
    enqueue op and one infeed queue.

    As a side effect, `self._infeed_queue` is set to the first infeed queue
    that was created.

    Returns:
      A tuple `(enqueue_ops, hooks, run_infeed_loop_on_coordinator)` where
      `enqueue_ops` is the list of enqueue ops (one per invocation), `hooks`
      is a single-element list holding a `MultiHostDatasetInitializerHook`
      built from all collected dataset initializers, and
      `run_infeed_loop_on_coordinator` is True unless at least one enqueue op
      was wrapped in a while loop.

    Raises:
      IndexError: If no infeed queue was created (e.g. there are no hosts to
        iterate over).
    """
    enqueue_ops = []
    infeed_queues = []
    all_dataset_initializers = []
    num_hosts = self._ctx.num_hosts
    tpu_host_placement_fn = self._ctx.tpu_host_placement_function

    run_infeed_loop_on_coordinator = True

    def _select_loop_wrapper():
      """Returns the while-loop wrapper matching the current mode."""
      # Evaluated lazily (only when a dataset initializer exists) so that
      # `self._ctx.mode` is read at the same points as before.
      if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT:
        return _wrap_computation_in_while_loop
      return _wrap_computation_in_while_loop_with_stopping_signals

    if self._sharded_per_core:
      # Per-Core input pipeline deployment.
      # Invoke input pipeline for each core and placed on the corresponding
      # host.
      for host_id in range(num_hosts):
        host_device = tpu_host_placement_fn(host_id=host_id)
        with tf.compat.v1.device(host_device):
          with ops.name_scope('input_pipeline_task%d' % (host_id)):
            enqueue_ops_fn, captured_infeed_queue = (
                generate_per_core_enqueue_ops_fn_for_host(
                    self._ctx, self._input_fn, self._inputs_structure_recorder,
                    host_device, host_id))

            if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
              run_infeed_loop_on_coordinator = False
              enqueue_ops.append(
                  _wrap_computation_in_while_loop(
                      device=host_device, op_fn=enqueue_ops_fn))
            else:
              enqueue_ops.append(enqueue_ops_fn())
            # Infeed_queue_getter must be called after enqueue_ops_fn is called.
            infeed_queues.append(captured_infeed_queue.get())

    elif self._ctx.is_input_broadcast_with_iterators():
      # Only calls input_fn in host 0.
      host_device = tpu_host_placement_fn(host_id=0)
      enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
          generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn,
                                            self._inputs_structure_recorder,
                                            num_hosts))
      if dataset_initializer:
        all_dataset_initializers.append(dataset_initializer)
        run_infeed_loop_on_coordinator = False
        wrap_fn = _select_loop_wrapper()
        enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
      else:
        enqueue_ops.append(enqueue_ops_fn())
      # As above, the queue is fetched only after the enqueue op was built.
      infeed_queues.append(captured_infeed_queue.get())

    else:
      # This branch handles two scenarios:
      #       num_cores_per_replica > num_cores_per_host
      #   and num_cores_per_replica <= num_cores_per_host
      # First, get the set of host_ids, by iterating replicas.
      # We only want and will get the set of *unique* host_ids
      # *that will call input_fn*. For each replica, we only call the input_fn
      # from the CPU host that contains logical core 0.

      # Use a list here to ensure deterministic order.
      host_id_with_invocation_id_pair = []

      if not self._ctx.is_replica_across_hosts():
        # One invocation per host; the invocation index is the host id.
        for host_id in range(num_hosts):
          invocation_index = host_id
          host_id_with_invocation_id_pair.append((host_id, invocation_index))
      else:
        # One invocation per replica; the invocation index is the replica id.
        for replica_id in range(self._ctx.num_replicas):
          invocation_index = replica_id
          host_device, _ = self._ctx.device_for_replica(replica_id)
          # TODO(lehou): Get host_id in a better way.
          # Parses the task number out of a device string of the form
          # '.../task:<N>/device:...'.
          host_id = int(host_device.split('/task:')[1].split('/device:')[0])
          host_id_with_invocation_id_pair.append((host_id, invocation_index))

      for (host_id, invocation_index) in host_id_with_invocation_id_pair:
        host_device = tpu_host_placement_fn(host_id=host_id)
        with tf.compat.v1.device(host_device):
          with ops.name_scope('input_pipeline_task%d' % (host_id)):
            if self._ctx.is_input_per_host_with_iterators():
              enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
                  generate_per_host_v2_enqueue_ops_fn_for_host(
                      self._ctx, self._input_fn,
                      self._inputs_structure_recorder, host_device, host_id,
                      invocation_index))
            else:
              enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
                  generate_per_host_enqueue_ops_fn_for_host(
                      self._ctx, self._input_fn,
                      self._inputs_structure_recorder, self._batch_axis,
                      host_device, host_id))

            # NOTE(xiejw): We dispatch here based on the return type of the
            # users `input_fn`.
            #
            # 1. If input_fn returns a Dataset instance, we initialize the
            # iterator outside of tf.while_loop, and call the iterator.get_next
            # inside tf.while_loop.  This should be always safe.
            #
            # 2. If input_fn returns (features, labels), it is too late to wrap
            # them inside tf.while_loop, as resource initialization cannot be
            # handled in TF control flow properly. In this case, we will use
            # python loop to enqueue the data into TPU system.  This may be
            # slow compared to the previous case.
            if dataset_initializer:
              all_dataset_initializers.append(dataset_initializer)
              run_infeed_loop_on_coordinator = False
              wrap_fn = _select_loop_wrapper()
              enqueue_ops.append(
                  wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
            else:
              enqueue_ops.append(enqueue_ops_fn())
            infeed_queues.append(captured_infeed_queue.get())

    # infeed_queue is used to generate dequeue ops. The only thing it uses for
    # dequeue is dtypes and types. So, any one can be used. Here, grab the
    # first one.
    self._infeed_queue = infeed_queues[0]
    return enqueue_ops, [
        util_lib.MultiHostDatasetInitializerHook(all_dataset_initializers)
    ], run_infeed_loop_on_coordinator