in tensorflow_estimator/python/estimator/tpu/tpu_estimator.py [0:0]
def _invoke_input_fn_and_record_structure(self):
  """Deploys the input pipeline and records the input structure.

  Builds the infeed enqueue ops according to the input deployment mode:

  * Per-core sharding (`self._sharded_per_core`): the per-core generator is
    invoked once for every host.
  * Broadcast with iterators: the broadcast generator is invoked once, using
    the device of host 0.
  * Otherwise: a per-host generator (the v2 variant when
    `is_input_per_host_with_iterators()` is true) is invoked once per host,
    or once per replica when replicas span hosts.

  As a side effect, the first captured infeed queue is stored on
  `self._infeed_queue`.

  Returns:
    A tuple `(enqueue_ops, hooks, run_infeed_loop_on_coordinator)`:
      enqueue_ops: list with one entry per input pipeline invocation, in
        invocation order.
      hooks: single-element list holding a `MultiHostDatasetInitializerHook`
        built from the collected dataset initializers.
      run_infeed_loop_on_coordinator: `False` once any enqueue op has been
        wrapped in a while loop, `True` otherwise.
  """
  enqueue_ops = []
  infeed_queues = []
  all_dataset_initializers = []
  num_hosts = self._ctx.num_hosts
  tpu_host_placement_fn = self._ctx.tpu_host_placement_function

  run_infeed_loop_on_coordinator = True

  def _append_enqueue_op(host_device, enqueue_ops_fn, dataset_initializer):
    """Appends one enqueue op; returns True iff it was wrapped in a loop."""
    # NOTE(xiejw): We dispatch here based on the return type of the
    # users `input_fn`.
    #
    # 1. If input_fn returns a Dataset instance, we initialize the
    # iterator outside of tf.while_loop, and call the iterator.get_next
    # inside tf.while_loop. This should be always safe.
    #
    # 2. If input_fn returns (features, labels), it is too late to wrap
    # them inside tf.while_loop, as resource initialization cannot be
    # handled in TF control flow properly. In this case, we will use
    # python loop to enqueue the data into TPU system. This may be
    # slow compared to the previous case.
    if not dataset_initializer:
      enqueue_ops.append(enqueue_ops_fn())
      return False
    all_dataset_initializers.append(dataset_initializer)
    wrap_fn = (
        _wrap_computation_in_while_loop
        if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
        _wrap_computation_in_while_loop_with_stopping_signals)
    enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
    return True

  if self._sharded_per_core:
    # Per-Core input pipeline deployment.
    # Invoke input pipeline for each core and placed on the corresponding
    # host.
    for host_id in range(num_hosts):
      host_device = tpu_host_placement_fn(host_id=host_id)
      with tf.compat.v1.device(host_device):
        with ops.name_scope('input_pipeline_task%d' % (host_id)):
          enqueue_ops_fn, captured_infeed_queue = (
              generate_per_core_enqueue_ops_fn_for_host(
                  self._ctx, self._input_fn, self._inputs_structure_recorder,
                  host_device, host_id))

          if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
            run_infeed_loop_on_coordinator = False
            enqueue_ops.append(
                _wrap_computation_in_while_loop(
                    device=host_device, op_fn=enqueue_ops_fn))
          else:
            enqueue_ops.append(enqueue_ops_fn())
          # Infeed_queue_getter must be called after enqueue_ops_fn is called.
          infeed_queues.append(captured_infeed_queue.get())

  elif self._ctx.is_input_broadcast_with_iterators():
    # Only calls input_fn in host 0.
    host_device = tpu_host_placement_fn(host_id=0)
    enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
        generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn,
                                          self._inputs_structure_recorder,
                                          num_hosts))
    if _append_enqueue_op(host_device, enqueue_ops_fn, dataset_initializer):
      run_infeed_loop_on_coordinator = False
    # The queue getter is only valid once the enqueue op has been built.
    infeed_queues.append(captured_infeed_queue.get())

  else:
    # This branch handles two scenarios:
    #   num_cores_per_replica > num_cores_per_host
    #   and num_cores_per_replica <= num_cores_per_host
    # First, get the set of host_ids, by iterating replicas.
    # We only want and will get the set of *unique* host_ids
    # *that will call input_fn*. For each replica, we only call the input_fn
    # from the CPU host that contains logical core 0.

    # Use a list here to ensure deterministic order.
    if not self._ctx.is_replica_across_hosts():
      # One invocation per host; the invocation index is the host id itself.
      host_id_with_invocation_id_pair = [
          (host_id, host_id) for host_id in range(num_hosts)
      ]
    else:
      # One invocation per replica; the invocation index is the replica id.
      host_id_with_invocation_id_pair = []
      for replica_id in range(self._ctx.num_replicas):
        host_device, _ = self._ctx.device_for_replica(replica_id)
        # TODO(lehou): Get host_id in a better way.
        # The host id is parsed out of the '/task:<id>/device:' segment of
        # the device string.
        host_id = int(host_device.split('/task:')[1].split('/device:')[0])
        host_id_with_invocation_id_pair.append((host_id, replica_id))

    for (host_id, invocation_index) in host_id_with_invocation_id_pair:
      host_device = tpu_host_placement_fn(host_id=host_id)
      with tf.compat.v1.device(host_device):
        with ops.name_scope('input_pipeline_task%d' % (host_id)):
          if self._ctx.is_input_per_host_with_iterators():
            enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
                generate_per_host_v2_enqueue_ops_fn_for_host(
                    self._ctx, self._input_fn,
                    self._inputs_structure_recorder, host_device, host_id,
                    invocation_index))
          else:
            enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
                generate_per_host_enqueue_ops_fn_for_host(
                    self._ctx, self._input_fn,
                    self._inputs_structure_recorder, self._batch_axis,
                    host_device, host_id))

          if _append_enqueue_op(host_device, enqueue_ops_fn,
                                dataset_initializer):
            run_infeed_loop_on_coordinator = False
          infeed_queues.append(captured_infeed_queue.get())

  # infeed_queue is used to generate dequeue ops. The only thing it uses for
  # dequeue is dtypes and types. So, any one can be used. Here, grab the
  # first one.
  self._infeed_queue = infeed_queues[0]
  return enqueue_ops, [
      util_lib.MultiHostDatasetInitializerHook(all_dataset_initializers)
  ], run_infeed_loop_on_coordinator