def CreateTpuEnqueueOps()

in lingvo/core/base_input_generator.py


  def CreateTpuEnqueueOps(self,
                          job_name=None,
                          skip_enqueue=False,
                          benchmark_only=False):
    """Create the host-side enqueue ops.

    This should be called in an outer non-TPU context.

    Args:
      job_name: The name of the job on which the enqueue operations run.
      skip_enqueue: If True, only create the TPU queues but skip the enqueue
        call. Used in eager mode to set up the TPU queues.
      benchmark_only: If True, do not wire the queues up to the TPU infeed.
    """
    if not py_utils.IsEagerMode():
      assert not self._tpu_queues, (
          'CreateTpuEnqueueOps should only be called once.')
    self._tpu_queues = []
    self._per_host_batches = []
    self._per_host_emb_batches = []
    # A list of lists, where the [i][j] element is the j-th passthrough batch
    # of the i-th task. Each task has more than one passthrough batch iff
    # sharded infeed is used.
    self._per_host_passthrough_batches = []
    p = self.params
    num_tpu_hosts = self.cluster.num_tpu_hosts
    num_cores_per_host = self.cluster.total_worker_devices // num_tpu_hosts
    tf.logging.info(
        'CreateTpuEnqueueOps num_splits_per_client={} '
        'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'
        .format(self.cluster.num_splits_per_client,
                self.cluster.num_devices_per_split, num_tpu_hosts,
                p.use_per_host_infeed))

    assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
    if p.use_per_core_infeed:
      if (not p.use_per_host_infeed) or p.use_partitioned_infeed_queue:
        raise ValueError(
            'use_per_core_infeed requires use_per_host_infeed and is '
            'incompatible with use_partitioned_infeed_queue.')
      if p.num_partitions is None or p.num_partitions <= 1:
        raise ValueError('use_per_core_infeed needs num_partitions > 1.')
    if (self.cluster.num_devices_per_split > num_cores_per_host and
        (p.use_per_host_infeed and not p.use_per_core_infeed)):
      tf.logging.fatal('Per-host infeed is not supported when '
                       'num_devices_per_split({}) > num_cores_per_host({}). '
                       'Each host must be able to accommodate >= 1 split when '
                       'using per_host_infeed.'.format(
                           self.cluster.num_devices_per_split,
                           num_cores_per_host))

    shards = self.tpu_number_of_shards
    tf.logging.info('shards {}'.format(shards))

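    # Enqueue ops from every infeed host are collected here and grouped into a
    # single op after the loop below.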
    input_ops_list = []
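    # Keys returned by GetCpuPassthroughKeys() stay on the host CPU: they are
    # filtered out of the TPU infeed below and collected in
    # self._per_host_passthrough_batches instead.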
    cpu_passthrough_keys = self.GetCpuPassthroughKeys()

    num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1
    tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)
    host_devices = self.cluster.ListDevices(self.cluster.job_spec).flatten()
    if p.use_per_host_infeed and num_infeed_hosts != len(host_devices):
      raise ValueError(
          f'Configuration mismatch, number of infeed hosts {num_infeed_hosts} '
          f'does not match available devices {host_devices}.')
    if p.use_per_host_infeed:
      task_ids = list(range(num_infeed_hosts))
    elif self.do_eval:
      # Run eval input generation on the last device.
      task_ids = [len(host_devices) - 1]
    else:
      # Run train input generation on the first device.
      task_ids = [0]
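    # Build the input pipeline, infeed queue, and (optionally) enqueue ops for
    # each selected host, pinning the host-side ops to that host's CPU device.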
    for task_id in task_ids:
      host_device = host_devices[task_id]
      if cpu_passthrough_keys and (
          '/task:{}/device:CPU:0'.format(task_id) not in host_device):
        raise ValueError(
            f'CPU passthrough configuration mismatch, device {host_device} '
            f'does not match task id {task_id}.')
      with tf.device(host_device), cluster.InfeedContextScope(
          infeed_host_index=task_id, num_infeed_hosts=num_infeed_hosts):
        batch = self.GetPreprocessedInputBatch()
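        # The preprocessed batch may be a single NestedMap or a list/tuple of
        # per-shard NestedMaps (sharded infeed); normalize to a list so both
        # cases share one code path.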
        if not isinstance(batch, (list, tuple)):
          batch = [batch]

        cur_passthrough_batches = []
        for i in range(len(batch)):
          b = batch[i]
          assert isinstance(b, py_utils.NestedMap)
          # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
          # Note that when MultiTaskData is used, bucket_keys will be at the
          # second level of the dictionary.
          b = b.FilterKeyVal(lambda k, _: not k.endswith('bucket_keys'))

          # Split out any keys that are meant for CPU passthrough only.
          cur_passthrough_batches.append(
              b.FilterKeyVal(lambda k, _: k in cpu_passthrough_keys))
          b = b.FilterKeyVal(lambda k, _: k not in cpu_passthrough_keys)
          batch[i] = b
          if i > 0:
            # If the input batch is already sharded, check that the shards are
            # compatible with each other.
            assert py_utils.IsCompatible(b, batch[0])
        self._per_host_passthrough_batches.append(cur_passthrough_batches)
        tf.logging.info('CPU passthrough keys: %s', cpu_passthrough_keys)

        if p.filter_sparse_tensors:
          # Make a copy of this host's input batch, then filter out any
          # SparseTensor features. This way, SparseTensor features are not fed
          # into the TPU InfeedQueue (and only to TPUEmbedding).
          # TODO(jeffreyzhao): Hack, come up with better solution.
          # Ideally we would like users to override
          # CreateTpuEmbeddingEnqueueOps() to modify the input batch
          # and remove fields they don't want to enqueue onto TPU.
          # However, the TPUEmbedding singleton and TPU embedding enqueue ops
          # are currently constructed after CreateTpuEnqueueOps() is called.
          emb_batch = []
          new_batch = []
          for i, b in enumerate(batch):
            emb_batch.append(
                b.Filter(lambda v: isinstance(v, tf.sparse.SparseTensor)))
            new_batch.append(
                b.Filter(lambda v: not isinstance(v, tf.sparse.SparseTensor)))
          self._per_host_emb_batches.append(emb_batch)
          batch = new_batch

        self._batch_nm_types = batch[0]
        tf.logging.info(
            'host_device: %s, batch: %r', host_device,
            py_utils.Transform(lambda x: (x.shape, x.dtype), batch[0]))
        self._per_host_batches.append(batch)

        if benchmark_only:
          continue

        for b in batch:
          for k, x in b.FlattenItems():
            assert x.shape.is_fully_defined(), (
                'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
        shapes = batch[0].Transform(lambda x: x.shape).Flatten()
        dtypes = batch[0].Transform(lambda x: x.dtype).Flatten()

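        # Build the infeed queue for this host. Three variants: a partitioned
        # infeed queue that splits each tensor's leading dimension across
        # p.num_partitions cores, a per-core-partitioned InfeedQueue, or a
        # plain InfeedQueue sharded across `shards` cores.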
        if p.use_partitioned_infeed_queue:
          device_assignment = py_utils.GetTpuDeviceAssignment(job_name)

          host_device = device_assignment.host_device(
              replica=0, job=tf.flags.FLAGS.tf_master)
          host_id = int(host_device.split('/task:')[1].split('/device:')[0])
          tf.logging.info('host_id: {} host_device: {}'.format(
              host_id, host_device))
          q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
              number_of_tuple_elements=len(dtypes),
              device_assignment=device_assignment,
              host_id=host_id,
              input_partition_dims=[
                  [p.num_partitions] + [1] * (len(s) - 1) for s in shapes
              ],
              tuple_types=dtypes,
              tuple_shapes=shapes)
        else:
          if p.use_per_core_infeed:
            q = tpu_feed.InfeedQueue(
                tuple_types=dtypes,
                tuple_shapes=shapes,
                number_of_partitions=p.num_partitions)
          elif len(batch) > 1:
            # When the input batch is sharded, the unsharded dtypes and shapes
            # will be determined later by the generate_enqueue_ops() call.
            q = tpu_feed.InfeedQueue(
                number_of_tuple_elements=len(batch[0].Flatten()))
          else:
            q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
          assert shards is not None
          q.set_number_of_shards(shards)

        self._tpu_queues.append(q)

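        # Generate the host-side enqueue ops, unless the caller only needs the
        # queues themselves (e.g. eager-mode setup with skip_enqueue=True).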
        if not skip_enqueue:
          if p.use_partitioned_infeed_queue:
            assert len(batch) == 1
            input_ops = q.generate_enqueue_ops([batch[0].Flatten()])
          elif p.use_per_host_infeed:
            def HostPlacementFunction(host_device, x):
              del x  # Unused.
              return host_device

            if len(batch) > 1:
              # In this case, the `shard_index_in_host` argument of
              # `_PerHostInfeedTPUOrdinalFunction` is the index of a sharded
              # batch in the `batch` list.
              input_ops = q.generate_enqueue_ops(
                  [b.Flatten() for b in batch],
                  placement_function=functools.partial(HostPlacementFunction,
                                                       host_device),
                  tpu_ordinal_function=functools.partial(
                      _PerHostInfeedTPUOrdinalFunction, p.use_per_core_infeed,
                      task_id))
            else:
              input_ops = q.split_inputs_and_generate_enqueue_ops(
                  batch[0].Flatten(),
                  placement_function=functools.partial(HostPlacementFunction,
                                                       host_device),
                  tpu_ordinal_function=functools.partial(
                      _PerHostInfeedTPUOrdinalFunction, p.use_per_core_infeed,
                      task_id))
          else:
            assert len(batch) == 1
            input_ops = q.split_inputs_and_generate_enqueue_ops(
                batch[0].Flatten(),
                device_assignment=py_utils.GetTpuDeviceAssignment(job_name))
          input_ops_list += input_ops

    if benchmark_only:
      grouped_infeed_op = tf.group(*self._per_host_batches)
    else:
      tf.logging.info('input_ops_list %s', input_ops_list)
      grouped_infeed_op = tf.group(*input_ops_list)

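    # Hand out p.tpu_infeed_parallelism references to the same grouped op,
    # typically one per infeed thread.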
    self._tpu_infeed_op = []
    for _ in range(p.tpu_infeed_parallelism):
      self._tpu_infeed_op.append(grouped_infeed_op)
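
Usage sketch (not from the source): a rough outline of how a trainer might
drive these enqueue ops under TF1-style graph execution. The names
`input_gen` (an instance of this input generator) and `sess` (a session
already connected to the TPU worker) are illustrative assumptions, as is the
`job_name` value; real Lingvo trainers wrap this in their own infeed threads
rather than calling it directly.

  # Outer, non-TPU context: build the per-host infeed queues and enqueue ops.
  input_gen.CreateTpuEnqueueOps(job_name='trainer')  # job_name is illustrative.

  # One run of the grouped op(s) enqueues one step of input for every infeed
  # host; a real trainer runs this in a loop on dedicated infeed threads.
  sess.run(input_gen._tpu_infeed_op)  # pylint: disable=protected-access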