def generate_broadcast_enqueue_ops_fn()

in tensorflow_estimator/python/estimator/tpu/tpu_estimator.py
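This builder implements broadcast input mode: input_fn is invoked once, on host 0, and the tensors it produces are either replicated to every replica or, when eval_training_input_configuration is InputPipelineConfig.SLICED, split so that each core receives its own slice. It returns the enqueue-op builder, the captured infeed queue, and (for dataset inputs) the dataset initializer.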


def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
                                      num_hosts):
  """Generates infeed enqueue ops for one input_fn on all the hosts."""
  captured_infeed_queue = _CapturedObject()
  dataset_initializer = None
  device_0 = ctx.tpu_host_placement_function(host_id=0)
  with tf.compat.v1.device(device_0):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device_0, invocation_index=0, host_id=0)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      if not is_dataset:
        raise TypeError(
            'For mode PREDICT, `input_fn` must return `Dataset` instead of '
            '`features` and `labels`.')

      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True)
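      # The stopping signals and padded final batch added here let PREDICT
      # drain the pipeline cleanly once the dataset is exhausted.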

    if is_dataset:
      dataset_initializer = inputs.dataset_initializer()
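      # The initializer is returned to the caller, which must run it before
      # any enqueue op is executed.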
    num_replicas_per_host = ctx.num_of_replicas_per_host

  def tpu_ordinal_function_impl(shard_id):
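    # Maps a shard (replica) id to the TPU core ordinal that should be fed.
    # Without an explicit device assignment, replicas are assumed to occupy
    # each host's cores in order, hence the modulo below.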
    if ctx.device_assignment:
      return ctx.device_assignment.tpu_ordinal(replica=shard_id)
    else:
      return shard_id % num_replicas_per_host

  def device_function_impl(shard_id):
    # shard_id ranges from 0 to num_of_replicas_per_host - 1.
    # A shard is a replica inside a host.
    # In broadcast mode (generate_broadcast_enqueue_ops_fn), the enqueue ops
    # are always executed on the first host. Thus shard_id equals replica_id.
    return ctx.tpu_host_placement_function(replica_id=shard_id)

  def enqueue_ops_fn():
    """Generates enqueue ops for all the hosts."""
    broadcasted_inputs = []
    flattened_inputs = None  # Cache result from input_fn.
    signals = None
    num_replicas = ctx.num_replicas
    core_id = 0
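    # core_id indexes into input_slices and only advances in SLICED mode.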
    for host_id in xrange(num_hosts):
      with tf.compat.v1.device(
          ctx.tpu_host_placement_function(host_id=host_id)):
        for _ in xrange(ctx.num_of_replicas_per_host):
          # Note: input_fn is invoked only once, on host 0, and get_next() is
          # only called for the first replica. The features and labels from
          # that single call are broadcast to all other replicas (including
          # the replicas on other hosts).
          if flattened_inputs is None:
            features, labels = inputs.features_and_labels()  # Calls get_next()
            signals = inputs.signals()

            inputs_structure_recorder.validate_and_record_structure(
                features, labels)
            flattened_inputs = (
                inputs_structure_recorder.flatten_features_and_labels(
                    features, labels, signals))
            if (ctx.config.tpu_config.eval_training_input_configuration is
                tpu_config.InputPipelineConfig.SLICED):
              input_slices = [
                  tf.split(x, num_replicas) for x in flattened_inputs
              ]
          if (ctx.config.tpu_config.eval_training_input_configuration is
              tpu_config.InputPipelineConfig.SLICED):
            # In SLICED mode, give each core its own slice of the flattened
            # inputs.
            broadcasted_inputs.append([x[core_id] for x in input_slices])
            core_id += 1
          else:
            broadcasted_inputs.append(flattened_inputs)

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(broadcasted_inputs[0]))
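    # number_of_tuple_elements matches one replica's flattened
    # (features, labels, signals) tuple; every replica's entry must agree.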
    captured_infeed_queue.capture(infeed_queue)
    enqueue_ops = infeed_queue.generate_enqueue_ops(
        broadcasted_inputs,
        tpu_ordinal_function=tpu_ordinal_function_impl,
        placement_function=device_function_impl)

    if signals is None:
      return enqueue_ops
    else:
      return {
          'ops': enqueue_ops,
          'signals': signals,
      }

  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer
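
A minimal sketch (not part of tpu_estimator.py) of how the returned triple could be driven by hand in a TF1 graph. Inside the estimator this wiring is done by its internal input pipeline and infeed hooks; ctx, my_input_fn, recorder, and num_hosts below are placeholders assumed to exist already.

import tensorflow as tf

# Hypothetical driver, not part of tpu_estimator.py.
enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
    generate_broadcast_enqueue_ops_fn(ctx, my_input_fn, recorder, num_hosts))

enqueue_result = enqueue_ops_fn()  # Builds one enqueue op set per replica.
# In PREDICT mode the result is a dict that also carries stopping signals.
enqueue_ops = (enqueue_result['ops']
               if isinstance(enqueue_result, dict) else enqueue_result)

with tf.compat.v1.Session() as sess:
  if dataset_initializer is not None:
    sess.run(dataset_initializer)  # Must run before the first enqueue.
  sess.run(enqueue_ops)            # Feeds one batch to every replica.
  # captured_infeed_queue now holds the InfeedQueue needed to build the
  # matching dequeue ops on the TPU side.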