in tensorflow_estimator/python/estimator/tpu/tpu_estimator.py [0:0]
def generate_per_host_v2_enqueue_ops_fn_for_host(ctx, input_fn,
                                                 inputs_structure_recorder,
                                                 device, host_id,
                                                 invocation_index):
  """Generates infeed enqueue ops for per-host input_fn on a single host.

  Calls `input_fn` once on `device` and builds a closure that, when invoked,
  creates one `get_next()` worth of features/labels per replica on this host
  and the corresponding TPU infeed enqueue ops.

  Args:
    ctx: Internal TPU context object; provides mode, batch size, replica
      counts, device assignment, ordinal/placement functions, and embedding
      config.
    input_fn: The user's input function. Must return a `tf.data.Dataset`
      (required by the PER_HOST_V2 input pipeline configuration).
    inputs_structure_recorder: Records and flattens the feature/label
      structure so it can be reconstructed on the TPU side.
    device: Device string of the host CPU on which input ops are placed.
    host_id: Integer task id of this host.
    invocation_index: Index of this `input_fn` invocation; exposed to the
      user via `tpu_context.TPUContext`.

  Returns:
    A tuple `(enqueue_ops_fn, captured_infeed_queue, dataset_initializer)`:
      enqueue_ops_fn: Callable creating the per-host enqueue ops; returns
        either a list of ops, or a dict with 'ops' and 'signals' when
        stopping signals are present.
      captured_infeed_queue: `_CapturedObject` that holds the infeed queue
        once `enqueue_ops_fn` has been called.
      dataset_initializer: Initializer op for this host's dataset.

  Raises:
    TypeError: If `input_fn` does not return a `tf.data.Dataset`.
  """
  captured_infeed_queue = _CapturedObject()
  dataset_initializer = None

  with tf.compat.v1.device(device):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx,
        input_device=device,
        invocation_index=invocation_index,
        host_id=host_id)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if not is_dataset:
      raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
                      'input pipeline configuration.')

    # Be aware that when num_cores_per_replica > num_cores_per_host,
    # ctx.num_of_replicas_per_host is 0.
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      # In PREDICT mode, wrap the dataset so every batch carries explicit
      # (padded) stopping signals used to terminate the prediction loop.
      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True,
          num_invocations_per_step=max(1, ctx.num_of_replicas_per_host))

    dataset_initializer = inputs.dataset_initializer()

    tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def device_function_impl(shard_id):
    # Placement function for the enqueue ops: maps a shard id to the host
    # device of that replica, or returns None (keep default placement) when
    # there is no explicit device assignment.
    if ctx.device_assignment is not None:
      # Find the replica_id of the host's logical core 0.
      # The current host_id is guaranteed to contain the logical core 0,
      # even when num_cores_per_replica > num_cores_per_host -- the function
      # caller makes sure that this host_id receives data (calls input_fn).
      replica_id = ctx.device_assignment.lookup_replicas(
          task_id=host_id, logical_core=0)[shard_id]
      return ctx.tpu_host_placement_function(replica_id=replica_id)
    else:
      return None

  def enqueue_ops_fn():
    """Generates the per_host enqueue ops."""
    control_deps = []
    per_host_sharded_inputs = []
    enqueue_datas_list = []
    # Be aware that when num_cores_per_replica > num_cores_per_host,
    # ctx.num_of_replicas_per_host is 0.
    num_replicas_per_host = max(1, ctx.num_of_replicas_per_host)
    cached_signals = None
    with tf.compat.v1.device(device):
      if not inputs.is_dataset:
        raise TypeError('`input_fn` must return a `Dataset` for this mode.')
      # NOTE: despite its name, the loop variable `host` indexes the
      # replicas served by this one host.
      for host in range(num_replicas_per_host):
        # Use control dependencies to ensure a deterministic ordering.
        if ctx.allow_per_host_v2_parallel_get_next:
          # Issue get_next() outside the control-dependency chain so the
          # per-replica dequeues may run in parallel.
          features, labels = inputs.features_and_labels()  # Calls get_next()
        with tf.control_dependencies(control_deps):
          if not ctx.allow_per_host_v2_parallel_get_next:
            # Serialized path: each get_next() depends on all tensors
            # produced for the previous replica.
            features, labels = inputs.features_and_labels()  # Calls get_next()
          signals = inputs.signals()

          # All the replicas share the replica 0's stopping signal.
          # This avoids inconsistent state among different model replicas.
          if cached_signals:
            signals['stopping'] = cached_signals['stopping']
          else:
            cached_signals = signals

        # Split out TPU-embedding enqueue data from the dense
        # features/labels; `enqueue_data` is expected to hold exactly one
        # entry per invocation.
        features, labels, enqueue_data = (
            _tpu_estimator_embedding.split_inputs(ctx, features, labels))
        if len(enqueue_data) != 1:
          raise RuntimeError(('Missing or extra enqueue_data for host {}. '
                              'len(enqueue_data) = {}.').format(
                                  host, len(enqueue_data)))
        enqueue_datas_list.append(enqueue_data[0])

        inputs_structure_recorder.validate_and_record_structure(
            features, labels)
        flattened_inputs = (
            inputs_structure_recorder.flatten_features_and_labels(
                features, labels, signals))
        # Chain the next replica's get_next() after everything produced for
        # this replica (see control_dependencies above).
        control_deps.extend(flattened_inputs)
        per_host_sharded_inputs.append(flattened_inputs)

      if inputs_structure_recorder.flattened_input_dims:
        # Spatially-partitioned inputs: build a partitioned infeed queue.
        # `signals` here is from the last loop iteration; signal tensors are
        # never partitioned.
        input_partition_dims = inputs_structure_recorder.flattened_input_dims
        if signals:
          input_partition_dims += [None] * len(signals)
        # pylint: disable=protected-access
        infeed_queue = tpu_feed._PartitionedInfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]),
            host_id=host_id,
            input_partition_dims=input_partition_dims,
            device_assignment=ctx.device_assignment)
        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
            per_host_sharded_inputs)
      else:
        infeed_queue = tpu_feed.InfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]))
        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
            per_host_sharded_inputs,
            tpu_ordinal_function=tpu_ordinal_function_impl,
            placement_function=device_function_impl)
      captured_infeed_queue.capture(infeed_queue)

    if ctx.embedding_config:
      per_host_enqueue_ops.extend(
          ctx.embedding_config.tpu_embedding.generate_enqueue_ops(
              enqueue_datas_list))

    # `signals` is presumably None outside PREDICT mode (plain `_Inputs`
    # rather than `_InputsWithStoppingSignals`) -- TODO confirm against
    # `_Inputs.signals()`.
    if signals is None:
      return per_host_enqueue_ops
    else:
      return {
          'ops': per_host_enqueue_ops,
          'signals': signals,
      }

  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer