def create_hierarchical_histogram_aggregation_factory()

in tensorflow_federated/python/analytics/hierarchical_histogram/hierarchical_histogram_factory.py [0:0]


def create_hierarchical_histogram_aggregation_factory(
    num_bins: int,
    arity: int = 2,
    clip_mechanism: str = 'sub-sampling',
    max_records_per_user: int = 10,
    dp_mechanism: str = 'no-noise',
    noise_multiplier: float = 0.0,
    expected_clients_per_round: int = 10,
    bits: int = 22):
  """Creates hierarchical histogram aggregation factory.

  Hierarchical histogram factory is constructed by composing 3 aggregation
  factories.
  (1) The inner-most factory is `SumFactory` (or, for distributed DP
      mechanisms, a `SecureSumFactory` wrapped in two
      `ModularClippingSumFactory` layers).
  (2) The middle factory is `DifferentiallyPrivateFactory` whose inner query is
      `TreeRangeSumQuery`. This factory 1) takes in a clipped histogram,
      constructs the hierarchical histogram and checks the norm bound of the
      hierarchical histogram at clients, 2) adds noise either at clients or at
      server according to `dp_mechanism`.
  (3) The outer-most factory is `HistogramClippingSumFactory` which clips the
      input histogram to bound each user's contribution.

  Args:
    num_bins: An `int` representing the input histogram size.
    arity: An `int` representing the branching factor of the tree. Defaults to
      2.
    clip_mechanism: A `str` representing the clipping mechanism. Currently
      supported mechanisms are
      - 'sub-sampling': (Default) Uniformly sample up to `max_records_per_user`
        records without replacement from the client dataset.
      - 'distinct': Uniquify client dataset and uniformly sample up to
        `max_records_per_user` records without replacement from it.
    max_records_per_user: An `int` representing the maximum of records each user
      can include in their local histogram. Defaults to 10.
    dp_mechanism: A `str` representing the differentially private mechanism to
      use. Currently supported mechanisms are
      - 'no-noise': (Default) Tree aggregation mechanism without noise.
      - 'central-gaussian': Tree aggregation with central Gaussian mechanism.
      - 'distributed-discrete-gaussian': Tree aggregation mechanism with
        distributed discrete Gaussian mechanism in "The Distributed Discrete
        Gaussian Mechanism for Federated Learning with Secure Aggregation. Peter
        Kairouz, Ziyu Liu, Thomas Steinke".
    noise_multiplier: A `float` specifying the noise multiplier (central noise
      stddev / L2 clip norm) for model updates. Only needed when `dp_mechanism`
      is not 'no-noise'. Defaults to 0.0.
    expected_clients_per_round: An `int` specifying the lower bound of the
      expected number of clients. Only needed when `dp_mechanism` is
      'distributed-discrete-gaussian'. Defaults to 10.
    bits: A positive integer specifying the communication bit-width B (where
      2**B will be the field size for SecAgg operations). Only needed when
      `dp_mechanism` is 'distributed-discrete-gaussian'. Please read the below
      precautions carefully and set `bits` accordingly. Otherwise, unexpected
      overflow or accuracy degradation might happen.
      (1) Should be in the inclusive range [1, 22] to avoid overflow inside
      secure aggregation;
      (2) Should be at least as large as
      `log2(4 * sqrt(expected_clients_per_round)* noise_multiplier *
      l2_norm_bound + expected_clients_per_round * max_records_per_user) + 1`
      to avoid accuracy degradation caused by frequent modular clipping. Here
      `l2_norm_bound` is the L2 norm bound of a client's hierarchical
      histogram, derived internally from `max_records_per_user`,
      `clip_mechanism`, `num_bins` and `arity`;
      (3) If the number of clients exceeds `expected_clients_per_round`,
      overflow might happen.

  Returns:
    `tff.aggregators.UnweightedAggregationFactory`.

  Raises:
    TypeError: If arguments have the wrong type(s).
    ValueError: If arguments have invalid value(s), or if `bits` is too small
      for the given parameters when `dp_mechanism` is
      'distributed-discrete-gaussian'.
  """
  _check_positive(num_bins, 'num_bins')
  _check_greater_equal(arity, 2, 'arity')
  _check_membership(clip_mechanism, clipping_factory.CLIP_MECHANISMS,
                    'clip_mechanism')
  _check_positive(max_records_per_user, 'max_records_per_user')
  _check_membership(dp_mechanism, DP_MECHANISMS, 'dp_mechanism')
  _check_non_negative(noise_multiplier, 'noise_multiplier')
  _check_positive(expected_clients_per_round, 'expected_clients_per_round')
  _check_in_range(bits, 'bits', 1, 22)

  # Converts `max_records_per_user` to the corresponding norm bound according to
  # the chosen `clip_mechanism` and `dp_mechanism`. The bound is only needed by
  # the noise-adding mechanisms below.
  if dp_mechanism in ['central-gaussian', 'distributed-discrete-gaussian']:
    if clip_mechanism == 'sub-sampling':
      # Each layer of the tree has L1 norm at most `max_records_per_user`, so
      # its L2 norm is at most that too; summing squares over all layers gives
      # this bound.
      l2_norm_bound = max_records_per_user * math.sqrt(
          _tree_depth(num_bins, arity))
    elif clip_mechanism == 'distinct':
      # The following code block converts `max_records_per_user` to L2 norm
      # bound of the hierarchical histogram layer by layer. For the bottom
      # layer with only 0s and at most `max_records_per_user` 1s, the L2 norm
      # bound is `sqrt(max_records_per_user)`. For the second layer from bottom,
      # the worst case is only 0s and `max_records_per_user/2` 2s. And so on
      # until the root node. Another natural L2 norm bound on each layer is
      # `max_records_per_user` so we take the minimum between the two bounds.
      square_l2_norm_bound = 0.
      square_layer_l2_norm_bound = max_records_per_user
      for _ in range(_tree_depth(num_bins, arity)):
        square_l2_norm_bound += min(max_records_per_user**2,
                                    square_layer_l2_norm_bound)
        square_layer_l2_norm_bound *= arity
      l2_norm_bound = math.sqrt(square_l2_norm_bound)
    else:
      # Fail loudly instead of leaving `l2_norm_bound` unbound, which would
      # otherwise surface later as an opaque `UnboundLocalError`.
      raise ValueError(
          f'Unsupported clip_mechanism for norm bound: {clip_mechanism}.')

  # Build nested aggregation factory from innermost to outermost.
  # 1. Sum factory. The most inner factory that sums the preprocessed records.
  # (1) If `dp_mechanism` is in `CENTRAL_DP_MECHANISMS` or
  #     `NO_NOISE_MECHANISMS`, should be `SumFactory`.
  if dp_mechanism in CENTRAL_DP_MECHANISMS + NO_NOISE_MECHANISMS:
    nested_factory = sum_factory.SumFactory()
  # (2) If `dp_mechanism` is in `DISTRIBUTED_DP_MECHANISMS`, should be
  #     `SecureSumFactory`. To preserve DP and avoid overflow, we have 4 modular
  #     clips from nesting two modular clip aggregators:
  #    #1. outer-client: clips to [-2**(bits-1), 2**(bits-1))
  #        Bounds the client values.
  #    #2. inner-client: clips to [0, 2**bits)
  #        Similar to applying a two's complement to the values such that
  #        frequent values (post-rotation) are now near 0 (representing small
  #        positives) and 2**bits (small negatives). 0 also always map to 0, and
  #        we do not require another explicit value range shift from
  #        [-2**(bits-1), 2**(bits-1)] to [0, 2**bits] to make sure that values
  #        are compatible with SecAgg's mod m = 2**bits. This can be reverted at
  #        #4.
  #    #3. inner-server: clips to [0, 2**bits)
  #        Ensures the aggregated value range does not grow by
  #        `log2(expected_clients_per_round)`.
  #        NOTE: If underlying SecAgg is implemented using the new
  #        `tff.federated_secure_modular_sum()` operator with the same
  #        modular clipping range, then this would correspond to a no-op.
  #    #4. outer-server: clips to [-2**(bits-1), 2**(bits-1))
  #        Keeps aggregated values centered near 0 out of the logical SecAgg
  #        black box for outer aggregators.
  elif dp_mechanism in DISTRIBUTED_DP_MECHANISMS:
    # TODO(b/196312838): Please add scaling to the distributed case once we have
    # a stable guideline for setting scaling factor to improve performance and
    # avoid overflow. The below test is to make sure that modular clipping
    # happens with small probability so the accuracy of the result won't be
    # harmed. However, if the number of clients exceeds
    # `expected_clients_per_round`, overflow still might happen. It is the
    # caller's responsibility to carefully choose `bits` according to system
    # details to avoid overflow or performance degradation.
    if bits < math.log2(4 * math.sqrt(expected_clients_per_round) *
                        noise_multiplier * l2_norm_bound +
                        expected_clients_per_round * max_records_per_user) + 1:
      raise ValueError(f'The selected bit-width ({bits}) is too small for the '
                       f'given parameters (expected_clients_per_round = '
                       f'{expected_clients_per_round}, max_records_per_user = '
                       f'{max_records_per_user}, noise_multiplier = '
                       f'{noise_multiplier}) and will harm the accuracy of the '
                       f'result. Please decrease the '
                       f'`expected_clients_per_round` / `max_records_per_user` '
                       f'/ `noise_multiplier`, or increase `bits`.')
    nested_factory = secure.SecureSumFactory(
        upper_bound_threshold=2**bits - 1, lower_bound_threshold=0)
    # Modular clips #2 (client) and #3 (server) described above.
    nested_factory = modular_clipping_factory.ModularClippingSumFactory(
        clip_range_lower=0,
        clip_range_upper=2**bits,
        inner_agg_factory=nested_factory)
    # Modular clips #1 (client) and #4 (server) described above.
    nested_factory = modular_clipping_factory.ModularClippingSumFactory(
        clip_range_lower=-2**(bits - 1),
        clip_range_upper=2**(bits - 1),
        inner_agg_factory=nested_factory)

  # 2. DP operations.
  # Constructs `DifferentiallyPrivateFactory` according to the chosen
  # `dp_mechanism`.
  if dp_mechanism == 'central-gaussian':
    query = tfp.TreeRangeSumQuery.build_central_gaussian_query(
        l2_norm_bound, noise_multiplier * l2_norm_bound, arity)
    # If the inner `DifferentiallyPrivateFactory` uses `GaussianSumQuery`, then
    # the record is cast to `tf.float32` before feeding to the DP factory.
    cast_to_float = True
  elif dp_mechanism == 'distributed-discrete-gaussian':
    # Each client contributes a 1/sqrt(expected_clients_per_round) share of the
    # central noise stddev, so the summed noise matches the central level.
    query = tfp.TreeRangeSumQuery.build_distributed_discrete_gaussian_query(
        l2_norm_bound, noise_multiplier * l2_norm_bound /
        math.sqrt(expected_clients_per_round), arity)
    # If the inner `DifferentiallyPrivateFactory` uses
    # `DistributedDiscreteGaussianQuery`, then the record is kept as `tf.int32`
    # before feeding to the DP factory.
    cast_to_float = False
  elif dp_mechanism == 'no-noise':
    inner_query = tfp.NoPrivacySumQuery()
    query = tfp.TreeRangeSumQuery(arity=arity, inner_query=inner_query)
    # If the inner `DifferentiallyPrivateFactory` uses `NoPrivacyQuery`, then
    # the record is kept as `tf.int32` before feeding to the DP factory.
    cast_to_float = False
  else:
    raise ValueError('Unexpected dp_mechanism.')
  nested_factory = differential_privacy.DifferentiallyPrivateFactory(
      query, nested_factory)

  # 3. Clip as specified by `clip_mechanism`.
  nested_factory = clipping_factory.HistogramClippingSumFactory(
      clip_mechanism=clip_mechanism,
      max_records_per_user=max_records_per_user,
      inner_agg_factory=nested_factory,
      cast_to_float=cast_to_float)

  return nested_factory