in tensorflow_privacy/privacy/analysis/tree_aggregation_accountant.py [0:0]
def _tree_sensitivity_square_sum(num_participation: int, min_separation: int,
                                 start: int, end: int, size: int) -> float:
  """Compute the worst-case sum of sensitivity square for `num_participation`.

  This is the key algorithm for DP accounting for DP-FTRL tree aggregation
  without restart, which recursively counts the worst-case occurrence of a
  sample in all the nodes in a tree. This implements a dynamic programming
  algorithm that exhausts the possible `num_participation` appearances of a
  sample in `size` leaf nodes. See Appendix D.2 (DP-FTRL-NoTreeRestart) of
  "Practical and Private (Deep) Learning without Sampling or Shuffling"
  https://arxiv.org/abs/2103.00039.

  Each level is split as follows:
  - If `size` is a power of two, the whole range is one complete tree. Its
    root counts every appearance, contributing `num_participation**2`, and
    the leaves are split into two equal halves.
  - Otherwise no term is added at this level. The leaves are split into the
    largest complete subtree (`2**floor(log2(size))` leaves) plus the
    remainder.
  The maximum is then taken over every way of distributing the appearances
  and the boundary gap between the two parts.

  NOTE(review): subproblems overlap heavily. Unless this function is memoized
  (e.g. by a `functools.lru_cache` decorator outside this view), the runtime
  grows exponentially with the tree depth.

  Args:
    num_participation: The number of times a sample will appear.
    min_separation: The minimum number of nodes between two appearances of a
      sample. If a sample appears in consecutive x, y steps in a streaming
      setting, then `min_separation=y-x-1`.
    start: The first appearance of the sample is after `start` steps.
    end: The sample won't appear in the `end` steps after the given `size`
      steps.
    size: Total number of steps (leaf nodes in tree aggregation).

  Returns:
    The worst-case sum of sensitivity square for the given input, or
    `-np.inf` if the requested participation pattern is infeasible.

  Raises:
    ValueError: If a split is required but `size` is not positive.
  """
  if not _check_possible_tree_participation(num_participation, min_separation,
                                            start, end, size):
    return -np.inf
  if num_participation == 0:
    return 0.
  if num_participation == 1 and size == 1:
    return 1.
  if size < 1:
    # A non-positive number of leaves cannot be split into subtrees.
    raise ValueError(f'`size` must be positive, got {size}.')

  # Largest power of two not exceeding `size`, using exact integer arithmetic.
  # `math.floor(math.log2(size))` converts `size` to a double first, so a size
  # just below (or above) a large power of two (e.g. 2**60 - 1) can be
  # misclassified as a complete tree.
  max_2power = int(size).bit_length() - 1
  if 2**max_2power == size:
    # Complete tree: its root contains every appearance of the sample.
    sum_value = num_participation**2
    max_2power -= 1  # Split into two equal halves below the root.
  else:
    sum_value = 0.
  left_size = 2**max_2power
  right_size = size - left_size

  best_split = -np.inf
  # i is the `num_participation` in the right subtree.
  for i in range(num_participation + 1):
    # j is the `start` of the right subtree, which is also the `end` of the
    # left one, so both sides agree on the gap across the boundary.
    for j in range(min_separation + 1):
      left_sum = _tree_sensitivity_square_sum(
          num_participation=num_participation - i,
          min_separation=min_separation,
          start=start,
          end=j,
          size=left_size)
      if np.isinf(left_sum):
        continue  # Early pruning: infeasible left part, skip the right side.
      right_sum = _tree_sensitivity_square_sum(
          num_participation=i,
          min_separation=min_separation,
          start=j,
          end=end,
          size=right_size)
      best_split = max(best_split, left_sum + right_sum)
  return sum_value + best_split