in tensorflow_federated/python/core/impl/executors/eager_tf_executor.py [0:0]
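
# Module-level imports assumed by this excerpt; the exact module paths follow
# the TFF source tree around the time this file was written and may differ in
# other versions of the library. The private helpers referenced below
# (`_ensure_comp_runtime_compatible`, `_get_wrapped_function_from_comp`,
# `_call_embedded_tf`) are defined elsewhere in this module.
import tensorflow as tf

from tensorflow_federated.proto.v0 import computation_pb2 as pb
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.compiler import building_blocks
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import type_analysis
from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.impl.types import type_serialization
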
def embed_tensorflow_computation(comp, type_spec=None, device=None):
"""Embeds a TensorFlow computation for use in the eager context.
Args:
comp: An instance of `pb.Computation`.
type_spec: An optional `tff.Type` instance or something convertible to it.
device: An optional `tf.config.LogicalDevice`.
Returns:
Either a one-argument or a zero-argument callable that executes the
computation in eager mode.
Raises:
TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
TensorFlow computation.
"""
  # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
  # since it deals exclusively with eager mode. Incubate here, and potentially
  # move there, once stable.
  py_typecheck.check_type(comp, pb.Computation)
  comp = _ensure_comp_runtime_compatible(comp)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_spec.is_equivalent_to(comp_type):
      raise TypeError('Expected a computation of type {}, got {}.'.format(
          type_spec, comp_type))
  else:
    type_spec = comp_type
  # TODO(b/156302055): Currently, TF will raise on any function returning a
  # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
  # this gating when we can.
  must_pin_function_to_cpu = type_analysis.contains(type_spec.result,
                                                    lambda t: t.is_sequence())
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    unexpected_building_block = building_blocks.ComputationBuildingBlock.from_proto(
        comp)
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        unexpected_building_block))
  if type_spec.is_function():
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec
  wrapped_fn = _get_wrapped_function_from_comp(comp, must_pin_function_to_cpu,
                                               param_type, device)
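  # For each leaf of the parameter type, record a conversion to apply before
  # the call: tensors pass through unchanged, while sequences are converted
  # with `tf.data.experimental.to_variant`, since the serialized graph
  # represents datasets as variant-typed tensors.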
  param_fns = []
  if param_type is not None:
    for spec in structure.flatten(type_spec.parameter):
      if spec.is_tensor():
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        param_fns.append(tf.data.experimental.to_variant)
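  # Mirror the parameter handling for results: sequence leaves come back from
  # the graph as variant tensors and are rehydrated into `tf.data.Dataset`s
  # using the element structure recorded in the type signature.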
  result_fns = []
  for spec in structure.flatten(result_type):
    if spec.is_tensor():
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      tf_structure = type_conversions.type_to_tf_structure(spec.element)

      def fn(x, tf_structure=tf_structure):
        return tf.data.experimental.from_variant(x, tf_structure)

      result_fns.append(fn)
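  # The traced graph may create stateful resources (hash tables, variables)
  # that eager mode will not clean up on its own; scan its operations so the
  # returned callable can destroy them explicitly. Hash table resources are
  # materialized here by pruning the graph down to just those outputs, and
  # are destroyed before each invocation.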
  ops = wrapped_fn.graph.get_operations()
  eager_cleanup_ops = []
  destroy_before_invocation = []
  for op in ops:
    if op.type == 'HashTableV2':
      eager_cleanup_ops += op.outputs
  if eager_cleanup_ops:
    for resource in wrapped_fn.prune(feeds={}, fetches=eager_cleanup_ops)():
      destroy_before_invocation.append(resource)
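  # Variable handles, by contrast, are destroyed only after an invocation
  # completes, since the outputs being returned may still reference them.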
  lazy_cleanup_ops = []
  destroy_after_invocation = []
  for op in ops:
    if op.type == 'VarHandleOp':
      lazy_cleanup_ops += op.outputs
  if lazy_cleanup_ops:
    for resource in wrapped_fn.prune(feeds={}, fetches=lazy_cleanup_ops)():
      destroy_after_invocation.append(resource)
  def fn_to_return(arg,
                   param_fns=tuple(param_fns),
                   result_fns=tuple(result_fns),
                   result_type=result_type,
                   wrapped_fn=wrapped_fn,
                   destroy_before=tuple(destroy_before_invocation),
                   destroy_after=tuple(destroy_after_invocation)):
    # This double-function pattern works around Python's late binding, forcing
    # the variables to bind eagerly via the default arguments above.
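    # Without those defaults, a closure referencing e.g. `param_fns` would
    # look the name up when called rather than when defined, and so would
    # observe any later rebinding of it in the enclosing scope.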
    return _call_embedded_tf(
        arg=arg,
        param_fns=param_fns,
        result_fns=result_fns,
        result_type=result_type,
        wrapped_fn=wrapped_fn,
        destroy_before_invocation=destroy_before,
        destroy_after_invocation=destroy_after)
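  # If required, re-wrap the callable so that every invocation runs under an
  # explicit `tf.device` scope: pinned to CPU when the result contains a
  # sequence (see the TODO above), otherwise placed on the caller-provided
  # logical device.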
  # pylint: disable=function-redefined
  if must_pin_function_to_cpu:
    old_fn_to_return = fn_to_return

    def fn_to_return(x):
      with tf.device('cpu'):
        return old_fn_to_return(x)
  elif device is not None:
    old_fn_to_return = fn_to_return

    def fn_to_return(x):
      with tf.device(device.name):
        return old_fn_to_return(x)
  # pylint: enable=function-redefined
  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
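
# A minimal usage sketch, not part of the original module. It assumes the
# internal `ComputationImpl.get_proto` accessor for extracting the serialized
# `pb.Computation` from a `tff.tf_computation`; that helper is internal and
# its location varies across TFF versions.
#
#   import tensorflow_federated as tff
#   from tensorflow_federated.python.core.impl.computation import computation_impl
#
#   @tff.tf_computation(tf.int32)
#   def add_one(x):
#     return x + 1
#
#   comp_proto = computation_impl.ComputationImpl.get_proto(add_one)
#   embedded = embed_tensorflow_computation(comp_proto)
#   print(embedded(tf.constant(41)))  # Executes eagerly; prints a tensor of 42.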