in lingvo/core/base_input_generator.py [0:0]
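  # Relies on the module's top-level imports (functools, tf, cluster,
  # py_utils, tpu_feed), which are not shown in this excerpt.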
  def CreateTpuEnqueueOps(self,
                          job_name=None,
                          skip_enqueue=False,
                          benchmark_only=False):
    """Create the host-side enqueue ops.

    This should be called in an outer non-TPU context.

    Args:
      job_name: the name of the job on which the enqueue operations run.
      skip_enqueue: if True, only create the TPU queues, but skip the enqueue
        call. To be used in eager mode to set up the TPU queues.
      benchmark_only: if True, don't wire it up to the TPU infeed.
    """
    if not py_utils.IsEagerMode():
      assert not self._tpu_queues, (
          'CreateTpuEnqueueOps should only be called once.')
    self._tpu_queues = []
    self._per_host_batches = []
    self._per_host_emb_batches = []
    # A list of lists, where the [i][j] element is the j-th passthrough batch
    # of the i-th task. Each task will have more than one passthrough batch
    # iff sharded infeed is used.
    self._per_host_passthrough_batches = []
    p = self.params
    num_tpu_hosts = self.cluster.num_tpu_hosts
    num_cores_per_host = self.cluster.total_worker_devices // num_tpu_hosts
    tf.logging.info(
        'CreateTpuEnqueueOps num_splits_per_client={} '
        'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'
        .format(self.cluster.num_splits_per_client,
                self.cluster.num_devices_per_split, num_tpu_hosts,
                p.use_per_host_infeed))

    assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
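
    # use_per_core_infeed requires use_per_host_infeed, is incompatible with
    # use_partitioned_infeed_queue, and only makes sense with more than one
    # partition; validate that configuration up front.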
    if p.use_per_core_infeed:
      if (not p.use_per_host_infeed) or p.use_partitioned_infeed_queue:
        raise ValueError('use_per_core_infeed requires use_per_host_infeed '
                         'and cannot be combined with '
                         'use_partitioned_infeed_queue.')
      if p.num_partitions is None or p.num_partitions <= 1:
        raise ValueError('use_per_core_infeed requires num_partitions > 1.')
    if (self.cluster.num_devices_per_split > num_cores_per_host and
        (p.use_per_host_infeed and not p.use_per_core_infeed)):
      tf.logging.fatal(
          'Per-host infeed mode is not supported when '
          'num_devices_per_split({}) > num_cores_per_host({}). '
          'Each host must be able to accommodate >= 1 split when using '
          'per_host_infeed.'.format(self.cluster.num_devices_per_split,
                                    num_cores_per_host))
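
    # `shards` is the number of shards each enqueued batch is split into; for
    # non-partitioned queues it is applied via q.set_number_of_shards() below.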
    shards = self.tpu_number_of_shards
    tf.logging.info('shards {}'.format(shards))

    input_ops_list = []
    cpu_passthrough_keys = self.GetCpuPassthroughKeys()

    num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1
    tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)
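
    # With per-host infeed, every TPU host runs its own enqueue; otherwise a
    # single host generates the input for all replicas.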
    host_devices = self.cluster.ListDevices(self.cluster.job_spec).flatten()
    if p.use_per_host_infeed and num_infeed_hosts != len(host_devices):
      raise ValueError(
          f'Configuration mismatch, number of infeed hosts {num_infeed_hosts} '
          f'does not match available devices {host_devices}.')
    if p.use_per_host_infeed:
      task_ids = list(range(num_infeed_hosts))
    elif self.do_eval:
      # Run eval input generation on the last device.
      task_ids = [len(host_devices) - 1]
    else:
      # Run train input generation on the first device.
      task_ids = [0]
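
    # Build the input pipeline, infeed queue, and enqueue ops for each
    # selected infeed host.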
    for task_id in task_ids:
      host_device = host_devices[task_id]
      if cpu_passthrough_keys and (
          '/task:{}/device:CPU:0'.format(task_id) not in host_device):
        raise ValueError(
            f'CPU passthrough configuration mismatch, device {host_device} '
            f'does not match task id {task_id}.')
      with tf.device(host_device), cluster.InfeedContextScope(
          infeed_host_index=task_id, num_infeed_hosts=num_infeed_hosts):
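        # InfeedContextScope exposes this host's index and the total number
        # of infeed hosts so the input pipeline can shard its data per host.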
        batch = self.GetPreprocessedInputBatch()
        if not isinstance(batch, (list, tuple)):
          batch = [batch]

        cur_passthrough_batches = []
        for i in range(len(batch)):
          b = batch[i]
          assert isinstance(b, py_utils.NestedMap)
          # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
          # Note that when MultiTaskData is used, bucket_keys will be at the
          # second level of the dictionary.
          b = b.FilterKeyVal(lambda k, _: not k.endswith('bucket_keys'))
          # Split out any keys that are meant for CPU passthrough only.
          cur_passthrough_batches.append(
              b.FilterKeyVal(lambda k, _: k in cpu_passthrough_keys))
          b = b.FilterKeyVal(lambda k, _: k not in cpu_passthrough_keys)
          batch[i] = b
          if i > 0:
            # If the input batch is already sharded, check that the shards are
            # compatible with each other.
            assert py_utils.IsCompatible(b, batch[0])
        self._per_host_passthrough_batches.append(cur_passthrough_batches)
        tf.logging.info('CPU passthrough keys: %s', cpu_passthrough_keys)
        if p.filter_sparse_tensors:
          # Make a copy of this host's input batch, then filter out any
          # SparseTensor features. This way, SparseTensor features are not fed
          # into the TPU InfeedQueue (and only to TPUEmbedding).
          # TODO(jeffreyzhao): Hack, come up with better solution.
          # Ideally we would like users to override
          # CreateTpuEmbeddingEnqueueOps() to modify the input batch
          # and remove fields they don't want to enqueue onto TPU.
          # However, the TPUEmbedding singleton and TPU embedding enqueue ops
          # are currently constructed after CreateTpuEnqueueOps() is called.
          emb_batch = []
          new_batch = []
          for i, b in enumerate(batch):
            emb_batch.append(
                b.Filter(lambda v: isinstance(v, tf.sparse.SparseTensor)))
            new_batch.append(
                b.Filter(lambda v: not isinstance(v, tf.sparse.SparseTensor)))
          self._per_host_emb_batches.append(emb_batch)
          batch = new_batch

        self._batch_nm_types = batch[0]
        tf.logging.info(
            'host_device: %s, batch: %r', host_device,
            py_utils.Transform(lambda x: (x.shape, x.dtype), batch[0]))
        self._per_host_batches.append(batch)

        if benchmark_only:
          continue
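
        # TPU infeed requires statically known shapes; fail early if any
        # feature has an undefined dimension.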
        for b in batch:
          for k, x in b.FlattenItems():
            assert x.shape.is_fully_defined(), (
                'Shape must be fully defined: %s: %s' % (k, x))
            # TODO(cwhipkey): if it's a string (or other type not supported on
            # TPU), drop it from feeding and on the other end add in an op
            # that fails if used.
        shapes = batch[0].Transform(lambda x: x.shape).Flatten()
        dtypes = batch[0].Transform(lambda x: x.dtype).Flatten()
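
        # Construct the infeed queue. The partitioned variant splits each
        # input tensor into p.num_partitions pieces along its leading
        # dimension; the plain variants instead shard whole input tuples
        # across cores via set_number_of_shards().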
        if p.use_partitioned_infeed_queue:
          device_assignment = py_utils.GetTpuDeviceAssignment(job_name)

          host_device = device_assignment.host_device(
              replica=0, job=tf.flags.FLAGS.tf_master)
          host_id = int(host_device.split('/task:')[1].split('/device:')[0])
          tf.logging.info('host_id: {} host_device: {}'.format(
              host_id, host_device))
          q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
              number_of_tuple_elements=len(dtypes),
              device_assignment=device_assignment,
              host_id=host_id,
              input_partition_dims=[
                  [p.num_partitions] + [1] * (len(s) - 1) for s in shapes
              ],
              tuple_types=dtypes,
              tuple_shapes=shapes)
        else:
          if p.use_per_core_infeed:
            q = tpu_feed.InfeedQueue(
                tuple_types=dtypes,
                tuple_shapes=shapes,
                number_of_partitions=p.num_partitions)
          elif len(batch) > 1:
            # When the input batch is sharded, the unsharded dtypes and shapes
            # will be determined later by the generate_enqueue_ops() call.
            q = tpu_feed.InfeedQueue(
                number_of_tuple_elements=len(batch[0].Flatten()))
          else:
            q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
          assert shards is not None
          q.set_number_of_shards(shards)

        self._tpu_queues.append(q)
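
        # Generate the actual enqueue ops unless the caller only wants the
        # queues created (e.g. eager-mode setup with skip_enqueue=True).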
        if not skip_enqueue:
          if p.use_partitioned_infeed_queue:
            assert len(batch) == 1
            input_ops = q.generate_enqueue_ops([batch[0].Flatten()])
          elif p.use_per_host_infeed:

            def HostPlacementFunction(host_device, x):
              del x  # Unused.
              return host_device

            if len(batch) > 1:
              # In this case, the `shard_index_in_host` argument of
              # `_PerHostInfeedTPUOrdinalFunction` is the index of a sharded
              # batch in the `batch` list.
              input_ops = q.generate_enqueue_ops(
                  [b.Flatten() for b in batch],
                  placement_function=functools.partial(HostPlacementFunction,
                                                       host_device),
                  tpu_ordinal_function=functools.partial(
                      _PerHostInfeedTPUOrdinalFunction, p.use_per_core_infeed,
                      task_id))
            else:
              input_ops = q.split_inputs_and_generate_enqueue_ops(
                  batch[0].Flatten(),
                  placement_function=functools.partial(HostPlacementFunction,
                                                       host_device),
                  tpu_ordinal_function=functools.partial(
                      _PerHostInfeedTPUOrdinalFunction, p.use_per_core_infeed,
                      task_id))
          else:
            assert len(batch) == 1
            input_ops = q.split_inputs_and_generate_enqueue_ops(
                batch[0].Flatten(),
                device_assignment=py_utils.GetTpuDeviceAssignment(job_name))
          input_ops_list += input_ops
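
    # Group the per-host enqueue ops into a single infeed op. In benchmark
    # mode the input batches themselves are grouped instead, so the input
    # pipeline can be exercised without wiring up the TPU infeed.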
    if benchmark_only:
      grouped_infeed_op = tf.group(*self._per_host_batches)
    else:
      tf.logging.info('input_ops_list %s', input_ops_list)
      grouped_infeed_op = tf.group(*input_ops_list)
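
    # Replicate the grouped op once per infeed thread; tpu_infeed_parallelism
    # presumably controls how many enqueue loops run it concurrently.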
    self._tpu_infeed_op = []
    for _ in range(p.tpu_infeed_parallelism):
      self._tpu_infeed_op.append(grouped_infeed_op)
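
  # Illustrative usage sketch (assumptions, not part of this file): a trainer
  # would typically call CreateTpuEnqueueOps() once in the outer, non-TPU
  # graph context and then drive the grouped infeed op from one or more
  # enqueue threads, roughly:
  #
  #   input_gen.CreateTpuEnqueueOps()
  #   for op in input_gen._tpu_infeed_op:  # hypothetical direct access
  #     threading.Thread(target=lambda op=op: sess.run(op)).start()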