def embed_tensorflow_computation()

in tensorflow_federated/python/core/impl/executors/eager_tf_executor.py [0:0]


def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
      If supplied, it must be equivalent to the type serialized in `comp`.
    device: An optional `tf.config.LogicalDevice`.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode. The zero-argument form is returned when the
    computation has no parameter.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation, or if `type_spec` is not equivalent to the type
      of `comp`.
  """
  # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
  # since it deals exclusively with eager mode. Incubate here, and potentially
  # move there, once stable.

  py_typecheck.check_type(comp, pb.Computation)
  comp = _ensure_comp_runtime_compatible(comp)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is None:
    type_spec = comp_type
  elif not type_spec.is_equivalent_to(comp_type):
    raise TypeError('Expected a computation of type {}, got {}.'.format(
        type_spec, comp_type))

  # Validate the kind of computation before touching any attribute that only
  # exists on functional types, so that a non-TensorFlow `comp` is always
  # reported with the documented `TypeError`.
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    unexpected_building_block = building_blocks.ComputationBuildingBlock.from_proto(
        comp)
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        unexpected_building_block))

  if type_spec.is_function():
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec

  # TODO(b/156302055): Currently, TF will raise on any function returning a
  # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
  # this gating when we can.
  # Computed from `result_type` (rather than `type_spec.result`) so that a
  # non-functional `type_spec` is handled instead of failing on `.result`.
  must_pin_function_to_cpu = type_analysis.contains(result_type,
                                                    lambda t: t.is_sequence())

  wrapped_fn = _get_wrapped_function_from_comp(comp, must_pin_function_to_cpu,
                                               param_type, device)

  # One converter per flattened parameter element: tensors pass through
  # unchanged, sequences are converted with `to_variant`.
  param_fns = []
  if param_type is not None:
    for spec in structure.flatten(param_type):
      if spec.is_tensor():
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        param_fns.append(tf.data.experimental.to_variant)

  # One converter per flattened result element: tensors pass through
  # unchanged, sequences are rebuilt with `from_variant` using the TF
  # structure of their element type.
  result_fns = []
  for spec in structure.flatten(result_type):
    if spec.is_tensor():
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      tf_structure = type_conversions.type_to_tf_structure(spec.element)

      # `tf_structure` is bound as a default to avoid late binding in the loop.
      def fn(x, tf_structure=tf_structure):
        return tf.data.experimental.from_variant(x, tf_structure)

      result_fns.append(fn)

  ops = wrapped_fn.graph.get_operations()

  def _resources_for_op_type(op_type):
    """Returns the pruned outputs of all graph ops of the given `op_type`."""
    handles = [out for op in ops if op.type == op_type for out in op.outputs]
    if not handles:
      return ()
    return tuple(wrapped_fn.prune(feeds={}, fetches=handles)())

  # Handles handed to `_call_embedded_tf` as `destroy_before_invocation` and
  # `destroy_after_invocation` respectively.
  destroy_before_invocation = _resources_for_op_type('HashTableV2')
  destroy_after_invocation = _resources_for_op_type('VarHandleOp')

  def fn_to_return(arg,
                   param_fns=tuple(param_fns),
                   result_fns=tuple(result_fns),
                   result_type=result_type,
                   wrapped_fn=wrapped_fn,
                   destroy_before=destroy_before_invocation,
                   destroy_after=destroy_after_invocation):
    # This double-function pattern works around python late binding, forcing the
    # variables to bind eagerly.
    return _call_embedded_tf(
        arg=arg,
        param_fns=param_fns,
        result_fns=result_fns,
        result_type=result_type,
        wrapped_fn=wrapped_fn,
        destroy_before_invocation=destroy_before,
        destroy_after_invocation=destroy_after)

  # Wrap the call in a device scope: CPU pinning takes precedence over the
  # caller-supplied `device`.
  # pylint: disable=function-redefined
  if must_pin_function_to_cpu:
    old_fn_to_return = fn_to_return

    def fn_to_return(x):
      with tf.device('cpu'):
        return old_fn_to_return(x)
  elif device is not None:
    old_fn_to_return = fn_to_return

    def fn_to_return(x):
      with tf.device(device.name):
        return old_fn_to_return(x)

  # pylint: enable=function-redefined

  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)