in tensorflow_privacy/privacy/dp_query/tree_aggregation.py [0:0]
def get_cumsum_and_update(self,
state: TreeState) -> Tuple[tf.Tensor, TreeState]:
"""Returns tree aggregated noise and updates `TreeState` for the next step.

`TreeState` is updated to prepare for accepting the *next* leaf node. Note
that `get_step_idx` can be called to get the current index of the leaf node
before calling this function. This function accepts the state for the
current leaf node and prepares for the next leaf node because TFF prefers
to know the types of the state at initialization. Note that the value of a
new node in `TreeState.level_buffer` depends on its two children and is
updated from the bottom up when its right child arrives.

Args:
  state: `TreeState` for the current leaf node; its index can be queried by
    `tree_aggregation.get_step_idx(state.level_buffer_idx)`.

Returns:
  Tuple of (noise, state) where `noise` is generated by the tree
  aggregation protocol for the cumulative sum of the streaming data, and
  `state` is the updated `TreeState`.
"""
# We only publicize a combined function for updating the state and returning
# the noised result because this DPQuery is designed for streaming data,
# and we only maintain a dynamic memory buffer of max size logT. Only the
# most recent noised result can be queried, and queries are expected to
# happen at every step in the streaming setting.
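# (Concretely, in the binary-counter view above, the buffer for leaf `t`
# holds one node value per set bit of `t + 1`, i.e. at most
# floor(log2(t + 1)) + 1 values; this elaborates the logT bound and is not
# a statement from the original comment.)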
cumsum = self._get_cumsum(state)
level_buffer_idx, level_buffer, value_generator_state = (
state.level_buffer_idx, state.level_buffer, state.value_generator_state)
new_level_buffer = tf.nest.map_structure(
lambda x: tf.TensorArray( # pylint: disable=g-long-lambda
dtype=tf.float32,
size=0,
dynamic_size=True),
level_buffer)
new_level_buffer_idx = tf.TensorArray(
dtype=tf.int32, size=0, dynamic_size=True)
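# The new buffers are built as dynamically sized `tf.TensorArray`s because
# the number of stored nodes can change from one step to the next.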
# `TreeState` stores the left child node necessary for computing the cumsum
# noise. To update the buffer, let us find the lowest level that will switch
# from a right child (not in the buffer) to a left child.
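# (In the binary-counter view above, this is the lowest unset bit of
# `t + 1` for the current leaf `t`.)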
level_idx = 0 # new leaf node starts from level 0
new_value, value_generator_state = self.value_generator.next(
value_generator_state)
while tf.less(level_idx, len(level_buffer_idx)) and tf.equal(
level_idx, level_buffer_idx[level_idx]):
# Recursively update if the current node is a right child.
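# Here `new_value` is the right child being propagated upward and
# `level_buffer` holds the matching left child; the parent's value is the
# average of its two children plus fresh noise drawn for the parent node,
# i.e. 0.5 * (left + right) + node_noise.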
node_value, value_generator_state = self.value_generator.next(
value_generator_state)
new_value = tf.nest.map_structure(
lambda l, r, n: 0.5 * (l[level_idx] + r) + n, level_buffer, new_value,
node_value)
level_idx += 1
# A new (left) node will be created at `level_idx`.
write_buffer_idx = 0
new_level_buffer_idx = new_level_buffer_idx.write(write_buffer_idx,
level_idx)
new_level_buffer = tf.nest.map_structure(
lambda x, y: x.write(write_buffer_idx, y), new_level_buffer, new_value)
write_buffer_idx += 1
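# The new node has the lowest level in the updated buffer, so it is written
# at position 0; the surviving higher-level nodes are appended after it
# below.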
# The buffer index will now be different from the level index of the old
# `TreeState`, i.e., `level_buffer_idx[level_idx] != level_idx`. Rename the
# variable to buffer index for clarity.
buffer_idx = level_idx
while tf.less(buffer_idx, len(level_buffer_idx)):
new_level_buffer_idx = new_level_buffer_idx.write(
write_buffer_idx, level_buffer_idx[buffer_idx])
new_level_buffer = tf.nest.map_structure(
lambda nb, b: nb.write(write_buffer_idx, b[buffer_idx]),
new_level_buffer, level_buffer)
buffer_idx += 1
write_buffer_idx += 1
new_level_buffer_idx = new_level_buffer_idx.stack()
new_level_buffer = tf.nest.map_structure(lambda x: x.stack(),
new_level_buffer)
new_state = TreeState(
level_buffer=new_level_buffer,
level_buffer_idx=new_level_buffer_idx,
value_generator_state=value_generator_state)
return cumsum, new_state
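# Example per-step usage sketch (illustrative only; `tree_aggregator`,
# `init_state`, and the cumulative-sum bookkeeping helpers are assumed
# names, not part of this module's documented API):
#
#   state = tree_aggregator.init_state()
#   for sample in stream:
#     step_idx = tree_aggregation.get_step_idx(state.level_buffer_idx)
#     noise, state = tree_aggregator.get_cumsum_and_update(state)
#     raw_cumsum = update_raw_cumulative_sum(sample)  # hypothetical helper
#     release(raw_cumsum + noise, step=step_idx)      # hypothetical helper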