def _tree_sensitivity_square_sum()

in tensorflow_privacy/privacy/analysis/tree_aggregation_accountant.py [0:0]


def _tree_sensitivity_square_sum(num_participation: int, min_separation: int,
                                 start: int, end: int, size: int) -> float:
  """Returns the worst-case sum of squared sensitivity for a sample.

  Core routine of the DP accounting for DP-FTRL tree aggregation without
  restart. It recursively counts the worst-case occurrence of a sample across
  all nodes of the aggregation tree by trying every way of distributing
  `num_participation` appearances over the `size` leaf nodes. See Appendix D.2
  (DP-FTRL-NoTreeRestart) of
  "Practical and Private (Deep) Learning without Sampling or Shuffling"
  https://arxiv.org/abs/2103.00039.

  The leaves are split into a left part whose size is a power of two and a
  right part holding the remainder; every split of the participation count and
  every boundary offset in `[0, min_separation]` is evaluated recursively and
  the best combination is kept.

  Args:
    num_participation: The number of times a sample will appear.
    min_separation: The minimum number of nodes between two appearances of a
      sample. If a sample appears in consecutive x, y steps in a streaming
      setting, then `min_separation=y-x-1`.
    start: The first appearance of the sample is after `start` steps.
    end: The sample won't appear in the `end` steps after the given `size`
      steps.
    size: Total number of steps (leaf nodes in tree aggregation).

  Returns:
    The worst-case sum of squared sensitivity for the given input, or
    `-np.inf` when `_check_possible_tree_participation` rejects the
    configuration.
  """
  feasible = _check_possible_tree_participation(num_participation,
                                                min_separation, start, end,
                                                size)
  if not feasible:
    return -np.inf
  if num_participation == 0:
    return 0.
  if num_participation == 1 and size == 1:
    return 1.

  exponent = math.log2(size)
  if exponent.is_integer():
    # `size` is a power of two: this call accounts for a node covering all
    # `size` leaves, and the leaves are then halved for the recursion.
    root_term = num_participation**2
    left_size = 2**(int(exponent) - 1)
  else:
    # Otherwise no node is added here; the left part takes the largest power
    # of two that fits and the right part takes whatever is left.
    root_term = 0.
    left_size = 2**math.floor(exponent)
  right_size = size - left_size

  def _split_score(right_count: int, boundary: int) -> float:
    """Scores one split: `right_count` appearances go to the right part."""
    left_part = _tree_sensitivity_square_sum(
        num_participation=num_participation - right_count,
        min_separation=min_separation,
        start=start,
        end=boundary,
        size=left_size)
    if np.isinf(left_part):
      # Infeasible left part: skip the right-part recursion entirely.
      return -np.inf
    right_part = _tree_sensitivity_square_sum(
        num_participation=right_count,
        min_separation=min_separation,
        start=boundary,
        end=end,
        size=right_size)
    return left_part + right_part

  # `boundary` serves both as the `end` of the left part and the `start` of
  # the right part.
  best_split = max(
      _split_score(right_count, boundary)
      for right_count in range(num_participation + 1)
      for boundary in range(min_separation + 1))
  return root_term + best_split