in tensorflow_estimator/python/estimator/tpu/tpu_estimator.py [0:0]
def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
                                      num_hosts):
  """Generates infeed enqueue ops for one input_fn on all the hosts.

  `input_fn` is invoked exactly once, under the placement of host 0. The
  single batch obtained from it is reused for every replica on every host.
  By default the batch is broadcast unchanged. When the eval/training input
  configuration is `InputPipelineConfig.SLICED`, each flattened tensor is
  instead split into `ctx.num_replicas` pieces, and each replica gets its own
  piece.

  Args:
    ctx: The internal TPU context. Provides host/replica placement functions,
      the mode, replica counts, the optional device assignment and the config.
    input_fn: The user input function. Called once with a
      `tpu_context.TPUContext` bound to host 0.
    inputs_structure_recorder: Validates/records the structure of the features
      and labels and flattens them (together with any signals) for the infeed.
    num_hosts: Number of hosts whose replicas receive enqueue inputs.

  Returns:
    A tuple `(enqueue_ops_fn, captured_infeed_queue, dataset_initializer)`:
      enqueue_ops_fn: A callable that builds and returns the enqueue ops (or a
        dict with keys 'ops' and 'signals' when signals are present).
      captured_infeed_queue: Receives the `InfeedQueue` created by
        `enqueue_ops_fn` once that callable runs.
      dataset_initializer: The dataset initializer if `input_fn` returned a
        `Dataset`, otherwise `None`.

  Raises:
    TypeError: In PREDICT mode, if `input_fn` does not return a `Dataset`.
  """
  captured_infeed_queue = _CapturedObject()
  dataset_initializer = None

  device_0 = ctx.tpu_host_placement_function(host_id=0)
  with tf.compat.v1.device(device_0):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device_0, invocation_index=0, host_id=0)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      if not is_dataset:
        raise TypeError(
            'For mode PREDICT, `input_fn` must return `Dataset` instead of '
            '`features` and `labels`.')
      # Wrap the dataset so that `inputs.signals()` (read in enqueue_ops_fn
      # below) is backed by `_InputsWithStoppingSignals`, with padding on.
      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True)

    if is_dataset:
      dataset_initializer = inputs.dataset_initializer()
    num_replicas_per_host = ctx.num_of_replicas_per_host

  def tpu_ordinal_function_impl(shard_id):
    """Maps an infeed shard index to a TPU ordinal."""
    if ctx.device_assignment:
      return ctx.device_assignment.tpu_ordinal(replica=shard_id)
    # Without an explicit device assignment, the ordinal wraps modulo the
    # number of replicas on one host.
    return shard_id % num_replicas_per_host

  def device_function_impl(shard_id):
    """Maps an infeed shard index to the host device that enqueues it."""
    # shard_id indexes the per-replica entries built by enqueue_ops_fn: one
    # entry per replica on every host, i.e. it ranges over
    # [0, num_hosts * num_of_replicas_per_host). (Presumably
    # InfeedQueue.generate_enqueue_ops passes that list index through.)
    # Because every replica gets an entry, shard_id equals the replica_id.
    return ctx.tpu_host_placement_function(replica_id=shard_id)

  def enqueue_ops_fn():
    """Generates enqueue ops for all the hosts.

    Returns:
      The enqueue ops produced by the infeed queue. If the inputs produced
      signals, a dict `{'ops': enqueue_ops, 'signals': signals}` is returned
      instead.
    """
    broadcasted_inputs = []
    flattened_inputs = None  # Cache result from input_fn.
    input_slices = None  # Per-tensor replica slices; only set when sliced.
    signals = None
    num_replicas = ctx.num_replicas
    # Compare with `==` rather than `is`: identity checks on value constants
    # (e.g. ints) are only reliable by accident of interpreter caching. The
    # check is loop-invariant, so evaluate it once.
    is_sliced = (
        ctx.config.tpu_config.eval_training_input_configuration ==
        tpu_config.InputPipelineConfig.SLICED)
    core_id = 0
    for host_id in xrange(num_hosts):
      with tf.compat.v1.device(
          ctx.tpu_host_placement_function(host_id=host_id)):
        for _ in xrange(ctx.num_of_replicas_per_host):
          # Note: input_fn is only called once at host 0 for the first replica.
          # The features and labels returned from that invocation are
          # broadcasted to other replicas (including the replicas on other
          # hosts).
          if flattened_inputs is None:
            features, labels = inputs.features_and_labels()  # Calls get_next()
            signals = inputs.signals()

            inputs_structure_recorder.validate_and_record_structure(
                features, labels)
            flattened_inputs = (
                inputs_structure_recorder.flatten_features_and_labels(
                    features, labels, signals))
            if is_sliced:
              # Split every flattened tensor into one piece per replica.
              input_slices = [
                  tf.split(x, num_replicas) for x in flattened_inputs
              ]
          if is_sliced:
            # For each core, slice out its piece of the flattened inputs.
            broadcasted_inputs.append([x[core_id] for x in input_slices])
            core_id += 1
          else:
            broadcasted_inputs.append(flattened_inputs)

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(broadcasted_inputs[0]))
    captured_infeed_queue.capture(infeed_queue)
    enqueue_ops = infeed_queue.generate_enqueue_ops(
        broadcasted_inputs,
        tpu_ordinal_function=tpu_ordinal_function_impl,
        placement_function=device_function_impl)

    if signals is None:
      return enqueue_ops
    return {
        'ops': enqueue_ops,
        'signals': signals,
    }

  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer