def _build_tf_computations_for_gather()

in tensorflow_federated/python/learning/framework/encoding_utils.py [0:0]


def _build_tf_computations_for_gather(state_type, value_type, encoders):
  """Utility for creating tf_computations for encoded sum and mean.

  This method maps a collection of `GatherEncoder` objects to partial
  computations for encoding a collection of values jointly. It also adds logic
  for counting the number of summands. The decode-before-sum step emits a
  constant count of one next to the partially decoded values. Because of this,
  the count is tracked once for the entire collection rather than per value.
  The summed count is then passed to every encoder's `decode_after_sum`.

  The returned `zero_fn`, `accumulate_fn`, `merge_fn` and `report_fn` are the
  computations intended for `intrinsics.federated_aggregate`.

  Args:
    state_type: A `tff.Type` describing the collection of states handled by
      `encoders`.
    value_type: A `tff.Type` describing the collection of values to be encoded
      by `encoders`.
    encoders: A collection of `GatherEncoder` objects.

  Returns:
    A `_NestGatherEncoder` namedtuple holding the relevant tf_computations:
    `get_params_fn`, `encode_fn`, `decode_after_sum_fn`, `update_state_fn`,
    `zero_fn`, `accumulate_fn`, `merge_fn` and `report_fn`.
  """

  # Calls `get_params` on every encoder with its matching state.
  # The per-encoder results are then split by `_slice` into three parallel
  # structures: encode params, decode-before-sum params and decode-after-sum
  # params.
  # NOTE(review): presumably `_slice(encoders, x, i)` picks the i-th element
  # of each encoder's result; confirm against its definition.
  @computations.tf_computation(state_type)
  def get_params_fn(state):
    params = tree.map_structure_up_to(encoders, lambda e, s: e.get_params(s),
                                      encoders, state)
    encode_params = _slice(encoders, params, 0)
    decode_before_sum_params = _slice(encoders, params, 1)
    decode_after_sum_params = _slice(encoders, params, 2)
    return encode_params, decode_before_sum_params, decode_after_sum_params

  # The TFF types of the three params structures, read from the result type of
  # `get_params_fn`. They are the parameter types of the computations below.
  encode_params_type = get_params_fn.type_signature.result[0]
  decode_before_sum_params_type = get_params_fn.type_signature.result[1]
  decode_after_sum_params_type = get_params_fn.type_signature.result[2]

  # TODO(b/139844355): Get rid of decode_before_sum_params.
  # We pass decode_before_sum_params to the encode method, because TFF currently
  # does not have a mechanism to make a tff.SERVER placed value available inside
  # of intrinsics.federated_aggregate - in production, this could mean an
  # intermediary aggregator node. So currently, we send the params to clients,
  # and ask them to send them back as part of the encoded structure.
  #
  # Encodes each value with its encoder. Each `e.encode` result is split by
  # `_slice` into the encoded value (index 0) and the state update tensors
  # (index 1). The result triple is (encoded_x, decode_before_sum_params,
  # state_update_tensors).
  @computations.tf_computation(value_type, encode_params_type,
                               decode_before_sum_params_type)
  def encode_fn(x, encode_params, decode_before_sum_params):
    encoded_structure = tree.map_structure_up_to(
        encoders, lambda e, *args: e.encode(*args), encoders, x, encode_params)
    encoded_x = _slice(encoders, encoded_structure, 0)
    state_update_tensors = _slice(encoders, encoded_structure, 1)
    return encoded_x, decode_before_sum_params, state_update_tensors

  # Type of the state update tensors emitted by `encode_fn`. It is used by
  # `update_state_fn` and `zero_fn`.
  state_update_tensors_type = encode_fn.type_signature.result[2]

  # This is intentionally a plain Python function rather than a
  # @computations.tf_computation. It is called directly inside the body of
  # `accumulate_fn`, one of the computations.tf_computations below that make
  # up the intrinsics.federated_aggregate...
  def decode_before_sum_tf_function(encoded_x, decode_before_sum_params):
    part_decoded_x = tree.map_structure_up_to(
        encoders, lambda e, *args: e.decode_before_sum(*args), encoders,
        encoded_x, decode_before_sum_params)
    # A count of one (a length-1 int32 vector), returned with every partially
    # decoded value. `accumulate_fn` and `merge_fn` add it up together with
    # the values. The total reaches `decode_after_sum_fn` as `num_summands`.
    one = tf.constant((1,), tf.int32)
    return part_decoded_x, one

  # ...however, its result type is needed to build the subsequent
  # tf_computations. It is therefore wrapped in a throwaway tf_computation only
  # to read that type.
  @computations.tf_computation(encode_fn.type_signature.result[0:2])
  def tmp_decode_before_sum_fn(encoded_x, decode_before_sum_params):
    return decode_before_sum_tf_function(encoded_x, decode_before_sum_params)

  # Structure (part_decoded_x, count), i.e. what gets summed during aggregation.
  part_decoded_x_type = tmp_decode_before_sum_fn.type_signature.result
  del tmp_decode_before_sum_fn  # Only needed for result type.

  # Finishes decoding after summation. `summed_values` has the structure of
  # `part_decoded_x_type`: the summed partially decoded values, paired with the
  # summed count. That count is passed to every encoder's `decode_after_sum`.
  @computations.tf_computation(part_decoded_x_type,
                               decode_after_sum_params_type)
  def decode_after_sum_fn(summed_values, decode_after_sum_params):
    part_decoded_aggregated_x, num_summands = summed_values
    return tree.map_structure_up_to(
        encoders,
        lambda e, x, params: e.decode_after_sum(x, params, num_summands),
        encoders, part_decoded_aggregated_x, decode_after_sum_params)

  # Applies each encoder's `update_state` to its state and to the matching
  # state update tensors.
  @computations.tf_computation(state_type, state_update_tensors_type)
  def update_state_fn(state, state_update_tensors):
    return tree.map_structure_up_to(encoders,
                                    lambda e, *args: e.update_state(*args),
                                    encoders, state, state_update_tensors)

  # Computations for intrinsics.federated_aggregate.
  #
  # Initial accumulator. It holds zeros shaped like the partially decoded
  # values (including the count) and zeros shaped like the state update
  # tensors, packed together by `_accumulator_value`.
  # NOTE(review): the other functions read the accumulator via the keys
  # 'values' and 'state_update_tensors', so `_accumulator_value` presumably
  # builds that mapping; confirm.
  @computations.tf_computation
  def zero_fn():
    values = tf.nest.map_structure(
        lambda s: tf.zeros(s.shape, s.dtype),
        type_conversions.type_to_tf_tensor_specs(part_decoded_x_type))
    state_update_tensors = tf.nest.map_structure(
        lambda s: tf.zeros(s.shape, s.dtype),
        type_conversions.type_to_tf_tensor_specs(state_update_tensors_type))
    return _accumulator_value(values, state_update_tensors)

  accumulator_type = zero_fn.type_signature.result
  # Per-encoder tuple of aggregation modes, computed once as Python values.
  # They are used by `_accmulate_state_update_tensor` in `accumulate_fn` and
  # `merge_fn` to decide how each state update tensor is combined.
  state_update_aggregation_modes = tf.nest.map_structure(
      lambda e: tuple(e.state_update_aggregation_modes), encoders)

  @computations.tf_computation(accumulator_type,
                               encode_fn.type_signature.result)
  def accumulate_fn(acc, encoded_x):
    """Internal accumulate function.

    Folds one output of `encode_fn` into the accumulator. The encoded value is
    partially decoded with the params sent back by the client, then added to
    the running sum (including the count of one). The state update tensors are
    combined according to `state_update_aggregation_modes`.
    """
    value, params, state_update_tensors = encoded_x
    part_decoded_value = decode_before_sum_tf_function(value, params)
    new_values = tf.nest.map_structure(tf.add, acc['values'],
                                       part_decoded_value)
    new_state_update_tensors = tf.nest.map_structure(
        _accmulate_state_update_tensor, acc['state_update_tensors'],
        state_update_tensors, state_update_aggregation_modes)
    return _accumulator_value(new_values, new_state_update_tensors)

  # Merges two accumulators. Values (and counts) are added; state update
  # tensors are combined per their aggregation modes, as in `accumulate_fn`.
  @computations.tf_computation(accumulator_type, accumulator_type)
  def merge_fn(acc1, acc2):
    new_values = tf.nest.map_structure(tf.add, acc1['values'], acc2['values'])
    new_state_update_tensors = tf.nest.map_structure(
        _accmulate_state_update_tensor, acc1['state_update_tensors'],
        acc2['state_update_tensors'], state_update_aggregation_modes)
    return _accumulator_value(new_values, new_state_update_tensors)

  # Identity: the accumulator is reported unchanged.
  @computations.tf_computation(accumulator_type)
  def report_fn(acc):
    return acc

  return _NestGatherEncoder(
      get_params_fn=get_params_fn,
      encode_fn=encode_fn,
      decode_after_sum_fn=decode_after_sum_fn,
      update_state_fn=update_state_fn,
      zero_fn=zero_fn,
      accumulate_fn=accumulate_fn,
      merge_fn=merge_fn,
      report_fn=report_fn)