# tensorflow/tensorflow/python/tpu/tpu.py
def split_compile_and_replicate(computation,
                                inputs=None,
                                infeed_queue=None,
                                device_assignment=None,
                                name=None,
                                use_tpu=True,
                                maximum_shapes=None):
"""Builds graph operators that runs compilation and replicated computation.
This is a lower level interface than replicate that returns a separate compile
and execute output tensor. In the generated graph the compile op feeds into
the execute op and no additional compilation is incurred when running the
compile op before the execute op. The compile op returns additional
information about the compilation but does not return the compiled program.
Args:
computation: A Python function that builds the computation to replicate.
inputs: A list of lists of input tensors or `None` (equivalent to
`[[]]`), indexed by `[replica_num][input_num]`. All replicas must
have the same number of inputs. Each input can be a nested structure
containing values that are convertible to tensors. Note that passing an
N-dimension list of compatible values will result in a N-dimension list of
scalar tensors rather than a single Rank-N tensors. If you need different
behavior, convert part of inputs to tensors with `tf.convert_to_tensor`.
infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
of arguments as inputs to computation.
device_assignment: If not `None`, a `DeviceAssignment` describing the
mapping between logical cores in the computation with physical cores in
the TPU topology. Uses a default device assignment if `None`. The
`DeviceAssignment` may be omitted if each replica of the computation uses
only one core, and there is either only one replica, or the number of
replicas is equal to the number of cores in the TPU system.
name: (Deprecated) Does nothing.
use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
backends. Currently, only supports a default placement (computation is
placed on GPU if one is available, and on CPU if not).
maximum_shapes: A nested structure of tf.TensorShape representing the shape
to which the respective component of each input element in each replica
should be padded. Any unknown dimensions (e.g.
tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
object) will be padded to the maximum size of that dimension over all
replicas. The structure of `maximum_shapes` needs to be the same as
`inputs[0]`.
Returns:
A list of lists with the first list corresponding to the compile op and the
second a list of output tensors, indexed by `[replica_num][output_num]`.
Raises:
ValueError: If all replicas do not have equal numbers of input tensors.
ValueError: If the number of inputs per replica does not match
the number of formal parameters to `computation`.
ValueError: If the static `inputs` dimensions don't match with the values
given in `maximum_shapes`.
ValueError: If the structure of inputs per replica does not match
the structure of `maximum_shapes`.
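
  Example:

    A minimal sketch of how this function might be called; the two-replica
    constant inputs below are illustrative assumptions, not part of this API:

      def computation(x):
        return x + 1.0

      compile_op, replicated_outputs = split_compile_and_replicate(
          computation,
          inputs=[[tf.constant(1.0)], [tf.constant(2.0)]])

    `replicated_outputs[replica_num][output_num]` indexes the per-replica
    outputs; running `compile_op` before the outputs avoids recompilation.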
"""
  del name
  inputs = [[]] if inputs is None else inputs

  metadata_kwargs = {}
  if device_assignment is not None:
    # Turn the Numpy array into a flattened list so we can pass it as an
    # operator attribute.
    metadata_kwargs = {
        "topology":
            device_assignment.topology.serialized(),
        "device_assignment":
            device_assignment.core_assignment.flatten().tolist()
    }
    # TODO(phawkins): remove this case after the forward compatibility window
    # expires on 2018-10-5.
    if api_compat.forward_compatible(2018, 10, 5):
      metadata_kwargs["num_cores_per_replica"] = (
          device_assignment.num_cores_per_replica)
    else:
      metadata_kwargs["computation_shape"] = [
          device_assignment.num_cores_per_replica
      ]

  # This entry is used for enabling automatic outside compilation.
  metadata_kwargs["allow_soft_placement"] = config.get_soft_device_placement()

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Checks all replicas have the same structure.
  for i in xrange(1, num_replicas):
    nest.assert_same_structure(inputs[0], inputs[i])

  # Flatten inputs.
  flat_inputs = [
      nest.flatten(per_replica_input) for per_replica_input in inputs
  ]
  # Converts inputs to Tensors.
  flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs]

  # Verifies that all replicas have matching numbers and types of inputs.
  flat_input_types = [x.dtype for x in flat_inputs[0]]
  input_arity = len(inputs[0])
  flat_input_arity = len(flat_input_types)
  for i in range(num_replicas):
    if len(inputs[i]) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(inputs[i])))

    types = [x.dtype for x in flat_inputs[i]]
    if types != flat_input_types:
      raise ValueError("Replicas must have matching input types. Replica 0 "
                       "had input types {}, replica {} had input types "
                       "{}".format(flat_input_types, i, types))
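
  # Verify that `computation`'s signature can accept the supplied per-replica
  # inputs plus any values dequeued from the infeed queue.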
  arg_error = xla.check_function_argument_count(
      computation, input_arity, infeed_queue)
  if arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" %
          (input_arity, str([i.name for i in inputs[0]]), arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" %
          (input_arity, str([i.name for i in inputs[0]]),
           infeed_queue.number_of_tuple_elements, arg_error))
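
  # With dynamic input shapes, pad every flattened input up to the
  # corresponding entry in `maximum_shapes` and record how each input was
  # padded.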
  if maximum_shapes:
    if infeed_queue:
      raise ValueError(
          "Dynamic input shapes are not supported with infeed queues")

    # Make sure maximum_shapes has the same structure as inputs.
    nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False)

    # Flatten padded shapes.
    flat_maximum_shapes = nest.flatten(maximum_shapes)
    flat_maximum_shapes = [
        tensor_shape.TensorShape(s) for s in flat_maximum_shapes
    ]

    flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes)
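
    # Serialize each padding map so it can be passed to the metadata op as a
    # string attribute.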
    serialized_padding_maps = []
    for padding_map in padding_maps:
      serialized_padding_maps.append(padding_map.SerializeToString())
    metadata_kwargs["padding_map"] = serialized_padding_maps
metadata_kwargs["step_marker_location"] = getattr(
computation, "step_marker_location", "STEP_MARK_AT_ENTRY")
graph = ops.get_default_graph()
# Fan-in: Builds a TPUReplicatedInput node for each input.
flat_replicated_inputs = []
for i in range(0, len(flat_inputs[0])):
replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)]
flat_replicated_inputs.append(
tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
if isinstance(graph, func_graph.FuncGraph):
# When we are in Tensorflow 2.0 function, 'graph' will be a FuncGraph
# object. If both outside graph and this function have a TPU cluster,
# they will have the same cluster name and it will cause problems (because
# we lower functional ops in Tensorflow 2.0). Append function name to
# 'cluster_name' to avoid cluster name collision.
cluster_name = graph.unique_name("cluster_" + graph.name)
else:
cluster_name = graph.unique_name("cluster")
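
  # A per-cluster pivot op; the TPUReplicateContext uses it as an anchor for
  # control dependencies of the ops it creates.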
  pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
  context = TPUReplicateContext(
      name=cluster_name, num_replicas=num_replicas, pivot=pivot)
  try:
    context.Enter()

    metadata = tpu_ops.tpu_replicate_metadata(
        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)
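
    # Build the replicated computation under the shard context, with a
    # control dependency on the metadata op.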
    with tpu_function.tpu_shard_context(
        num_replicas), ops.control_dependencies([metadata]):

      # Add identity ops so even unused inputs are "consumed" by the
      # computation. This is to avoid orphaned TPUReplicatedInput nodes.
      # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
      # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
      flat_replicated_inputs = [
          array_ops.identity(x, name="replicated_input_{}".format(i))
          for i, x in enumerate(flat_replicated_inputs)
      ]
      for i in flat_replicated_inputs:
        # pylint: disable=protected-access
        # Add an attribute to the identity node so that it can be removed in
        # the encapsulate-TPU-computation pass if unused. However, we don't
        # remove inputs when dynamic padding is enabled.
        # TODO(rxsang): Use other ways except argument index in padding_map so
        # outside compilation can work with dynamic padding correctly.
        if maximum_shapes is None:
          i.op._set_attr("_tpu_input_identity",
                         attr_value_pb2.AttrValue(b=True))
        # pylint: enable=protected-access

      # Unflatten the computation inputs to match the original input structure.
      computation_inputs = nest.pack_sequence_as(
          structure=inputs[0],
          flat_sequence=flat_replicated_inputs[:flat_input_arity])

      # If there is an infeed queue, adds the dequeued values to the
      # computation's inputs.
      if infeed_queue is not None:
        infeed_queue.set_number_of_shards(num_replicas)
        for t in infeed_queue.generate_dequeue_op():
          computation_inputs.append(t)

      # Only resource variables work inside a TPU computation, so turn on
      # resource variables for the computation.
      # TODO(phawkins): consider removing this code. It will
      # be less confusing to clients if they knowingly choose to use resource
      # variables.
      # Partitioned variables are not supported (b/112311320).
      vscope = variable_scope.get_variable_scope()
      saved_use_resource = vscope.use_resource
      saved_custom_getter = vscope.custom_getter

      def custom_getter(getter, name, *args, **kwargs):
        """Variables on TPU have a few restrictions."""
        partitioner = kwargs["partitioner"]
        if partitioner is not None:
          kwargs["partitioner"] = None
          logging.warning(
              "Partitioned variables are not supported on TPU. Got "
              "`partitioner` that is {} for variable {}. "
              "Setting `partitioner` to `None`."
              .format(partitioner, name))
        if saved_custom_getter is None:
          return getter(name, *args, **kwargs)
        else:
          return saved_custom_getter(getter, name, *args, **kwargs)

      vscope.set_use_resource(True)
      vscope.set_custom_getter(custom_getter)

      outputs = computation(*computation_inputs)
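
      # Restore the variable-scope settings saved above now that the user
      # computation has been traced.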
      vscope.set_use_resource(saved_use_resource)
      vscope.set_custom_getter(saved_custom_getter)

      outputs_is_flat = xla.is_flat(outputs)
      if outputs_is_flat:
        output_tensors, control_deps = _postprocess_flat_outputs(outputs)
      else:
        output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)

      # tensor_tracer imports tpu.py, so import it locally here to avoid an
      # import cycle.
      # pylint: disable=g-import-not-at-top
      from tensorflow.python.tpu import tensor_tracer
      # pylint: enable=g-import-not-at-top
      if tensor_tracer.TensorTracer.is_enabled():
        tt = tensor_tracer.TensorTracer()
        output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                      output_tensors, control_deps,
                                      num_replicas)

    context.ExitResult(output_tensors)
  finally:
    context.report_unsupported_operations()
    context.Exit()
    host_compute_core = context.HostComputeCore()
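
  # Attach any host compute cores reported by the context (used by outside
  # compilation) to the metadata op.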
  if host_compute_core:
    attr_value = attr_value_pb2.AttrValue()
    attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access
  with ops.control_dependencies([metadata]):
    if use_tpu:
      compile_status = tpu_ops.tpu_compilation_result()
      op = compile_status.op
      attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
      op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
    else:
      compile_status = control_flow_ops.no_op(name="compilation_status")

  if not output_tensors:
    # Returns a list of NoOps dependent on the replication Op, indexed by
    # [replica_num].
    return [
        compile_status,
        [
            control_flow_ops.group(control_deps, name="shard_%d" % i)
            for i in range(num_replicas)
        ]
    ]

  # Fan-out: Builds a TPUReplicatedOutput node for each output.
  replicated_outputs = [[] for i in xrange(num_replicas)]
  for i, t in enumerate(output_tensors):
    ys = tpu_ops.tpu_replicated_output(
        t, num_replicas, name="output{}".format(i))

    # Wraps the outputs in identity operators so the names of any possible
    # `fetch` nodes are preserved by the replication rewrite.
    with ops.control_dependencies(control_deps):
      for replica in xrange(num_replicas):
        replicated_outputs[replica].append(
            array_ops.identity(
                ys[replica], name="output_%d_shard_%d" % (i, replica)))

  if not outputs_is_flat:
    replicated_outputs = [
        nest.pack_sequence_as(outputs, replica_outs)
        for replica_outs in replicated_outputs
    ]

  return [compile_status, replicated_outputs]
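

# A minimal end-to-end sketch (illustrative only: the lambda computation,
# constant inputs, and `sess` are assumptions, not part of this module):
#
#   compile_op, per_replica_outputs = split_compile_and_replicate(
#       lambda x: x + 1.0,
#       inputs=[[tf.constant(41.0)], [tf.constant(42.0)]])
#   with tf.compat.v1.Session() as sess:
#     sess.run(compile_op)                  # Compile the cluster once...
#     print(sess.run(per_replica_outputs))  # ...then execute the replicas.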