def assemble_result_from_graph()

in tensorflow_federated/python/core/impl/utils/tensorflow_utils.py [0:0]


def assemble_result_from_graph(type_spec, binding, output_map):
  """Assembles a result stamped into a `tf.Graph` given type signature/binding.

  This method does roughly the opposite of `capture_result_from_graph`, in that
  whereas `capture_result_from_graph` starts with a single structured object
  made up of tensors and computes its type and bindings, this method starts
  with the type/bindings and constructs a structured object made up of tensors.

  Args:
    type_spec: The type signature of the result to assemble, an instance of
      `types.Type` or something convertible to it.
    binding: The binding that relates the type signature to names of tensors in
      the graph, an instance of `pb.TensorFlow.Binding`.
    output_map: The mapping from tensor names that appear in the binding to
      actual stamped tensors (possibly renamed during import).

  Returns:
    The assembled result, a Python object that is composed of tensors, possibly
    nested within Python structures such as anonymous tuples. Struct types
    whose `python_container` is `None` are returned as `structure.Struct`;
    otherwise an instance of the `python_container` is built.

  Raises:
    TypeError: If the argument or any of its parts are of an unexpected type.
    ValueError: If the arguments are invalid or inconsistent with each other,
      e.g., the type and binding don't match, or a tensor (including a dataset
      variant tensor) named in the binding is not found in the map.
  """
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)
  py_typecheck.check_type(binding, pb.TensorFlow.Binding)
  py_typecheck.check_type(output_map, dict)
  for k, v in output_map.items():
    py_typecheck.check_type(k, str)
    if not tf.is_tensor(v):
      raise TypeError(
          'Element with key {} in the output map is {}, not a tensor.'.format(
              k, py_typecheck.type_string(type(v))))
  # Validate `output_map` once here; the recursive worker trusts it, so nested
  # structs do not re-scan the entire map at every level (which was O(n^2)).
  return _assemble_validated_result(type_spec, binding, output_map)


def _assemble_validated_result(type_spec, binding, output_map):
  """Recursive worker for `assemble_result_from_graph` on validated inputs.

  Args:
    type_spec: A `computation_types.Type` describing the value to assemble.
    binding: The `pb.TensorFlow.Binding` matching `type_spec`.
    output_map: A dict of tensor names to tensors, already validated by
      `assemble_result_from_graph`.

  Returns:
    The assembled value, as described in `assemble_result_from_graph`.

  Raises:
    ValueError: If the type and binding disagree, or a name in the binding is
      missing from `output_map`.
  """
  binding_oneof = binding.WhichOneof('binding')

  if type_spec.is_tensor():
    if binding_oneof != 'tensor':
      raise ValueError(
          'Expected a tensor binding, found {}.'.format(binding_oneof))
    tensor_name = binding.tensor.tensor_name
    if tensor_name not in output_map:
      raise ValueError(
          'Tensor named {} not found in the output map.'.format(tensor_name))
    tensor = output_map[tensor_name]
    try:
      type_analysis.check_type(tensor, type_spec)
    except TypeError as te:
      raise ValueError(
          f'Type mismatch loading graph result tensor {tensor} '
          f'(named "{tensor_name}").\n'
          'This may have been caused by a use of `tf.set_shape`.\n'
          'Prefer usage of `tf.ensure_shape` to `tf.set_shape`.') from te
    return tensor

  if type_spec.is_struct():
    if binding_oneof != 'struct':
      raise ValueError(
          'Expected a struct binding, found {}.'.format(binding_oneof))
    type_elements = structure.to_elements(type_spec)
    element_bindings = binding.struct.element
    if len(element_bindings) != len(type_elements):
      raise ValueError(
          'Mismatching tuple sizes in type ({}) and binding ({}).'.format(
              len(type_elements), len(element_bindings)))
    result_elements = [
        (element_name,
         _assemble_validated_result(element_type, element_binding, output_map))
        for (element_name, element_type), element_binding in zip(
            type_elements, element_bindings)
    ]
    container_type = type_spec.python_container
    if container_type is None:
      return structure.Struct(result_elements)
    if (py_typecheck.is_named_tuple(container_type) or
        py_typecheck.is_attrs(container_type)):
      return container_type(**dict(result_elements))
    if issubclass(container_type, (list, tuple)):
      # Plain lists/tuples are positional: build them from the values only,
      # otherwise the result would be a sequence of (name, value) pairs.
      return container_type(value for _, value in result_elements)
    # Mapping containers (e.g. dict/OrderedDict) accept (name, value) pairs.
    return container_type(result_elements)

  if type_spec.is_sequence():
    if binding_oneof != 'sequence':
      raise ValueError(
          'Expected a sequence binding, found {}.'.format(binding_oneof))
    sequence_oneof = binding.sequence.WhichOneof('binding')
    if sequence_oneof != 'variant_tensor_name':
      raise ValueError(
          'Unsupported sequence binding \'{}\'.'.format(sequence_oneof))
    variant_tensor_name = binding.sequence.variant_tensor_name
    # Mirror the tensor branch: a missing name is a ValueError, not KeyError.
    if variant_tensor_name not in output_map:
      raise ValueError('Tensor named {} not found in the output map.'.format(
          variant_tensor_name))
    return make_dataset_from_variant_tensor(output_map[variant_tensor_name],
                                            type_spec.element)

  raise ValueError('Unsupported type \'{}\'.'.format(type_spec))