tensorflow_federated/python/aggregators/encoded.py [183:237]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  def get_params_fn(state):
    """Derives the per-encoder params from `state`, grouped by their use.

    Args:
      state: Structure that is mapped alongside `encoders`; each encoder's
        `get_params` is called on its matching piece of `state`.

    Returns:
      A 3-tuple `(encode_params, decode_before_sum_params,
      decode_after_sum_params)`, i.e. `_slice` applied to the `get_params`
      results at indices 0, 1 and 2 respectively.
    """
    per_encoder_params = tree.map_structure_up_to(
        encoders,
        lambda encoder, encoder_state: encoder.get_params(encoder_state),
        encoders, state)
    return tuple(
        _slice(encoders, per_encoder_params, index) for index in range(3))

  # get_params_fn returns (encode, decode_before_sum, decode_after_sum) params,
  # so its result type is indexed in that same order.
  get_params_result_type = get_params_fn.type_signature.result
  encode_params_type = get_params_result_type[0]
  decode_before_sum_params_type = get_params_result_type[1]
  decode_after_sum_params_type = get_params_result_type[2]

  # TODO(b/139844355): Get rid of decode_before_sum_params.
  # We pass decode_before_sum_params to the encode method, because TFF currently
  # does not have a mechanism to make a tff.SERVER placed value available inside
  # of intrinsics.federated_aggregate - in production, this could mean an
  # intermediary aggregator node. So currently, we send the params to clients,
  # and ask them to send them back as part of the encoded structure.
  @computations.tf_computation(value_type, encode_params_type,
                               decode_before_sum_params_type)
  def encode_fn(x, encode_params, decode_before_sum_params):
    """Encodes `x` with every encoder and passes the decode params through.

    Args:
      x: Value structure to encode; mapped alongside `encoders`.
      encode_params: Params handed to each encoder's `encode`, mapped
        alongside `encoders`.
      decode_before_sum_params: Returned unchanged as part of the output.

    Returns:
      A 3-tuple `(encoded_x, decode_before_sum_params, state_update_tensors)`,
      where `encoded_x` and `state_update_tensors` are `_slice` of the
      `encode` results at indices 0 and 1 respectively.
    """

    def _encode_one(encoder, *encode_args):
      return encoder.encode(*encode_args)

    encode_results = tree.map_structure_up_to(encoders, _encode_one, encoders,
                                              x, encode_params)
    return (_slice(encoders, encode_results, 0), decode_before_sum_params,
            _slice(encoders, encode_results, 1))

  # Index 2 of encode_fn's result is `state_update_tensors`.
  state_update_tensors_type = encode_fn.type_signature.result[2]

  # This is not a @computations.tf_computation because it will be used below
  # when building the computations.tf_computations that will compose an
  # intrinsics.federated_aggregate...
  def decode_before_sum_tf_function(encoded_x, decode_before_sum_params):
    """Runs every encoder's `decode_before_sum` and appends a constant one.

    Args:
      encoded_x: Encoded structure, mapped alongside `encoders`.
      decode_before_sum_params: Params for `decode_before_sum`, mapped
        alongside `encoders`.

    Returns:
      A pair `(part_decoded_x, one)`, where `one` is an int32 constant built
      from `(1,)`.
    """

    def _decode_one(encoder, *decode_args):
      return encoder.decode_before_sum(*decode_args)

    partially_decoded = tree.map_structure_up_to(
        encoders, _decode_one, encoders, encoded_x, decode_before_sum_params)
    # NOTE(review): presumably summed together with the decoded values so the
    # total becomes the `num_summands` read in decode_after_sum_fn -- confirm.
    unit_count = tf.constant((1,), tf.int32)
    return partially_decoded, unit_count

  # ...however, result type is needed to build the subsequent tf_computations.
  @computations.tf_computation(encode_fn.type_signature.result[0:2])
  def tmp_decode_before_sum_fn(encoded_x, decode_before_sum_params):
    """Wraps decode_before_sum_tf_function so its result type can be read."""
    decoded_with_count = decode_before_sum_tf_function(
        encoded_x, decode_before_sum_params)
    return decoded_with_count

  # Type of the `(part_decoded_x, one)` pair returned by
  # decode_before_sum_tf_function, as traced by the temporary computation.
  part_decoded_x_type = tmp_decode_before_sum_fn.type_signature.result
  del tmp_decode_before_sum_fn  # Only needed for result type.

  @computations.tf_computation(part_decoded_x_type,
                               decode_after_sum_params_type)
  def decode_after_sum_fn(summed_values, decode_after_sum_params):
    """Finishes decoding with every encoder's `decode_after_sum`.

    Args:
      summed_values: A pair that unpacks into the partially decoded aggregate
        and the number of summands.
      decode_after_sum_params: Params for `decode_after_sum`, mapped alongside
        `encoders`.

    Returns:
      The structure of each encoder's `decode_after_sum` result.
    """
    aggregated_x, summand_count = summed_values

    def _finish_decoding(encoder, aggregated_value, params):
      # Every encoder receives the same summand count.
      return encoder.decode_after_sum(aggregated_value, params, summand_count)

    return tree.map_structure_up_to(encoders, _finish_decoding, encoders,
                                    aggregated_x, decode_after_sum_params)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



tensorflow_federated/python/learning/framework/encoding_utils.py [202:257]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  def get_params_fn(state):
    """Derives the per-encoder params from `state`, grouped by their use.

    Args:
      state: Structure that is mapped alongside `encoders`; each encoder's
        `get_params` is called on its matching piece of `state`.

    Returns:
      A 3-tuple `(encode_params, decode_before_sum_params,
      decode_after_sum_params)`, i.e. `_slice` applied to the `get_params`
      results at indices 0, 1 and 2 respectively.
    """
    per_encoder_params = tree.map_structure_up_to(
        encoders,
        lambda encoder, encoder_state: encoder.get_params(encoder_state),
        encoders, state)
    return tuple(
        _slice(encoders, per_encoder_params, index) for index in range(3))

  # get_params_fn returns (encode, decode_before_sum, decode_after_sum) params,
  # so its result type is indexed in that same order.
  get_params_result_type = get_params_fn.type_signature.result
  encode_params_type = get_params_result_type[0]
  decode_before_sum_params_type = get_params_result_type[1]
  decode_after_sum_params_type = get_params_result_type[2]

  # TODO(b/139844355): Get rid of decode_before_sum_params.
  # We pass decode_before_sum_params to the encode method, because TFF currently
  # does not have a mechanism to make a tff.SERVER placed value available inside
  # of intrinsics.federated_aggregate - in production, this could mean an
  # intermediary aggregator node. So currently, we send the params to clients,
  # and ask them to send them back as part of the encoded structure.
  @computations.tf_computation(value_type, encode_params_type,
                               decode_before_sum_params_type)
  def encode_fn(x, encode_params, decode_before_sum_params):
    """Encodes `x` with every encoder and passes the decode params through.

    Args:
      x: Value structure to encode; mapped alongside `encoders`.
      encode_params: Params handed to each encoder's `encode`, mapped
        alongside `encoders`.
      decode_before_sum_params: Returned unchanged as part of the output.

    Returns:
      A 3-tuple `(encoded_x, decode_before_sum_params, state_update_tensors)`,
      where `encoded_x` and `state_update_tensors` are `_slice` of the
      `encode` results at indices 0 and 1 respectively.
    """

    def _encode_one(encoder, *encode_args):
      return encoder.encode(*encode_args)

    encode_results = tree.map_structure_up_to(encoders, _encode_one, encoders,
                                              x, encode_params)
    return (_slice(encoders, encode_results, 0), decode_before_sum_params,
            _slice(encoders, encode_results, 1))

  # Index 2 of encode_fn's result is `state_update_tensors`.
  state_update_tensors_type = encode_fn.type_signature.result[2]

  # This is not a @computations.tf_computation because it will be used below
  # when building the computations.tf_computations that will compose an
  # intrinsics.federated_aggregate...
  # @tf.function
  def decode_before_sum_tf_function(encoded_x, decode_before_sum_params):
    """Runs every encoder's `decode_before_sum` and appends a constant one.

    Args:
      encoded_x: Encoded structure, mapped alongside `encoders`.
      decode_before_sum_params: Params for `decode_before_sum`, mapped
        alongside `encoders`.

    Returns:
      A pair `(part_decoded_x, one)`, where `one` is an int32 constant built
      from `(1,)`.
    """

    def _decode_one(encoder, *decode_args):
      return encoder.decode_before_sum(*decode_args)

    partially_decoded = tree.map_structure_up_to(
        encoders, _decode_one, encoders, encoded_x, decode_before_sum_params)
    # NOTE(review): presumably summed together with the decoded values so the
    # total becomes the `num_summands` read in decode_after_sum_fn -- confirm.
    unit_count = tf.constant((1,), tf.int32)
    return partially_decoded, unit_count

  # ...however, result type is needed to build the subsequent tf_computations.
  @computations.tf_computation(encode_fn.type_signature.result[0:2])
  def tmp_decode_before_sum_fn(encoded_x, decode_before_sum_params):
    """Wraps decode_before_sum_tf_function so its result type can be read."""
    decoded_with_count = decode_before_sum_tf_function(
        encoded_x, decode_before_sum_params)
    return decoded_with_count

  # Type of the `(part_decoded_x, one)` pair returned by
  # decode_before_sum_tf_function, as traced by the temporary computation.
  part_decoded_x_type = tmp_decode_before_sum_fn.type_signature.result
  del tmp_decode_before_sum_fn  # Only needed for result type.

  @computations.tf_computation(part_decoded_x_type,
                               decode_after_sum_params_type)
  def decode_after_sum_fn(summed_values, decode_after_sum_params):
    """Finishes decoding with every encoder's `decode_after_sum`.

    Args:
      summed_values: A pair that unpacks into the partially decoded aggregate
        and the number of summands.
      decode_after_sum_params: Params for `decode_after_sum`, mapped alongside
        `encoders`.

    Returns:
      The structure of each encoder's `decode_after_sum` result.
    """
    aggregated_x, summand_count = summed_values

    def _finish_decoding(encoder, aggregated_value, params):
      # Every encoder receives the same summand count.
      return encoder.decode_after_sum(aggregated_value, params, summand_count)

    return tree.map_structure_up_to(encoders, _finish_decoding, encoders,
                                    aggregated_x, decode_after_sum_params)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



