in tensorflow_federated/python/core/impl/utils/tensorflow_utils.py [0:0]
def assemble_result_from_graph(type_spec, binding, output_map):
  """Assembles a result stamped into a `tf.Graph` given type signature/binding.

  This method does roughly the opposite of `capture_result_from_graph`, in that
  whereas `capture_result_from_graph` starts with a single structured object
  made up of tensors and computes its type and bindings, this method starts
  with the type/bindings and constructs a structured object made up of tensors.

  Args:
    type_spec: The type signature of the result to assemble, an instance of
      `types.Type` or something convertible to it.
    binding: The binding that relates the type signature to names of tensors in
      the graph, an instance of `pb.TensorFlow.Binding`.
    output_map: The mapping from tensor names that appear in the binding to
      actual stamped tensors (possibly renamed during import).

  Returns:
    The assembled result, a Python object that is composed of tensors, possibly
    nested within Python structures such as anonymous tuples.

  Raises:
    TypeError: If the argument or any of its parts are of an unexpected type.
    ValueError: If the arguments are invalid or inconsistent with each other,
      e.g., the type and binding don't match, or the tensor is not found in the
      map.
  """
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)
  py_typecheck.check_type(binding, pb.TensorFlow.Binding)
  py_typecheck.check_type(output_map, dict)
  for k, v in output_map.items():
    py_typecheck.check_type(k, str)
    if not tf.is_tensor(v):
      raise TypeError(
          'Element with key {} in the output map is {}, not a tensor.'.format(
              k, py_typecheck.type_string(type(v))))
  # The output map does not change while recursing into nested structures, so
  # it is validated once here rather than at every level of nesting.
  return _assemble_result_from_validated_output_map(type_spec, binding,
                                                    output_map)


def _assemble_result_from_validated_output_map(type_spec, binding, output_map):
  """Recursive worker for `assemble_result_from_graph`.

  The caller must already have checked that `output_map` is a `dict` mapping
  `str` keys to tensors; this function does not re-validate it.

  Args:
    type_spec: An instance of `computation_types.Type` describing the result.
    binding: The `pb.TensorFlow.Binding` that corresponds to `type_spec`.
    output_map: A validated mapping from tensor names to stamped tensors.

  Returns:
    The assembled result, as described in `assemble_result_from_graph`.

  Raises:
    TypeError: If `type_spec` or `binding` is of an unexpected type.
    ValueError: If the type and binding do not match, a referenced tensor is
      missing from `output_map`, or the type/binding kind is unsupported.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  py_typecheck.check_type(binding, pb.TensorFlow.Binding)
  binding_oneof = binding.WhichOneof('binding')

  if type_spec.is_tensor():
    if binding_oneof != 'tensor':
      raise ValueError(
          'Expected a tensor binding, found {}.'.format(binding_oneof))
    tensor_name = binding.tensor.tensor_name
    if tensor_name not in output_map:
      raise ValueError(
          'Tensor named {} not found in the output map.'.format(tensor_name))
    tensor = output_map[tensor_name]
    try:
      type_analysis.check_type(tensor, type_spec)
    except TypeError as te:
      raise ValueError(
          f'Type mismatch loading graph result tensor {tensor} '
          f'(named "{tensor_name}").\n'
          'This may have been caused by a use of `tf.set_shape`.\n'
          'Prefer usage of `tf.ensure_shape` to `tf.set_shape`.') from te
    return tensor

  if type_spec.is_struct():
    if binding_oneof != 'struct':
      raise ValueError(
          'Expected a struct binding, found {}.'.format(binding_oneof))
    type_elements = structure.to_elements(type_spec)
    element_bindings = binding.struct.element
    if len(element_bindings) != len(type_elements):
      raise ValueError(
          'Mismatching tuple sizes in type ({}) and binding ({}).'.format(
              len(type_elements), len(element_bindings)))
    result_elements = [
        (element_name,
         _assemble_result_from_validated_output_map(element_type,
                                                    element_binding,
                                                    output_map))
        for (element_name, element_type), element_binding in zip(
            type_elements, element_bindings)
    ]
    container_type = type_spec.python_container
    if container_type is None:
      return structure.Struct(result_elements)
    if (py_typecheck.is_named_tuple(container_type) or
        py_typecheck.is_attrs(container_type)):
      return container_type(**dict(result_elements))
    if (isinstance(container_type, type) and
        issubclass(container_type, (list, tuple))):
      # Plain sequence containers hold values only. Passing the
      # (name, value) pairs here would build a sequence of pairs instead.
      return container_type(value for _, value in result_elements)
    # Mapping-like containers (e.g. `dict`, `collections.OrderedDict`) are
    # built directly from the (name, value) pairs.
    return container_type(result_elements)

  if type_spec.is_sequence():
    if binding_oneof != 'sequence':
      raise ValueError(
          'Expected a sequence binding, found {}.'.format(binding_oneof))
    sequence_oneof = binding.sequence.WhichOneof('binding')
    if sequence_oneof != 'variant_tensor_name':
      raise ValueError(
          'Unsupported sequence binding \'{}\'.'.format(sequence_oneof))
    variant_tensor_name = binding.sequence.variant_tensor_name
    # Report a missing tensor as `ValueError`, matching the tensor branch and
    # the documented contract, rather than letting a bare `KeyError` escape.
    if variant_tensor_name not in output_map:
      raise ValueError('Tensor named {} not found in the output map.'.format(
          variant_tensor_name))
    return make_dataset_from_variant_tensor(output_map[variant_tensor_name],
                                            type_spec.element)

  raise ValueError('Unsupported type \'{}\'.'.format(type_spec))