def get_cumsum_and_update()

in tensorflow_privacy/privacy/dp_query/tree_aggregation.py [0:0]


  def get_cumsum_and_update(self,
                            state: TreeState) -> Tuple[tf.Tensor, TreeState]:
    """Returns tree aggregated noise and updates `TreeState` for the next step.

    `TreeState` is updated to prepare for accepting the *next* leaf node. Note
    that `get_step_idx` can be called to get the current index of the leaf node
    before calling this function. This function accepts the state for the
    current leaf node and prepares it for the next leaf node because TFF
    prefers to know the types of state at initialization.

    The returned noise is computed from `state.level_buffer` *before* the
    buffer is updated. The rebuilt buffer keeps the dtype of each entry of the
    incoming `state.level_buffer`.

    Args:
      state: `TreeState` for the current leaf node, index can be queried by
        `tree_aggregation.get_step_idx(state.level_buffer_idx)`.

    Returns:
      Tuple of (noise, state) where `noise` is generated by tree aggregated
      protocol for the cumulative sum of streaming data, and `state` is the
      updated `TreeState`.
    """

    level_buffer_idx, level_buffer, value_generator_state = (
        state.level_buffer_idx, state.level_buffer, state.value_generator_state)
    # We only publicize a combined function for updating state and returning
    # noised results because this DPQuery is designed for the streaming data,
    # and we only maintain a dynamic memory buffer of max size logT. Only the
    # most recent noised results can be queried, and the queries are expected
    # to happen for every step in the streaming setting.
    cumsum = self._get_cumsum(level_buffer)

    # Fresh growable arrays to hold the rebuilt buffer. The dtype follows the
    # existing buffer entry (rather than a hard-coded float32) so that value
    # generators producing other floating dtypes can be written back without a
    # TensorArray dtype mismatch, and the state's dtype stays stable.
    new_level_buffer = tf.nest.map_structure(
        lambda x: tf.TensorArray(  # pylint: disable=g-long-lambda
            dtype=x.dtype,
            size=0,
            dynamic_size=True),
        level_buffer)
    new_level_buffer_idx = tf.TensorArray(
        dtype=tf.int32, size=0, dynamic_size=True)
    # `TreeState` stores the left child node necessary for computing the cumsum
    # noise. To update the buffer, let us find the lowest level that will switch
    # from a right child (not in the buffer) to a left child.
    # The loop advances while the buffer entry at position `level_idx` holds
    # level `level_idx`, i.e. across the leading run of consecutive levels
    # 0, 1, 2, ... The `and` keeps the index access in bounds.
    level_idx = 0  # new leaf node starts from level 0
    while tf.less(level_idx, len(level_buffer_idx)) and tf.equal(
        level_idx, level_buffer_idx[level_idx]):
      level_idx += 1
    # Left child nodes for the level lower than `level_idx` will be removed
    # and a new node will be created at `level_idx`.
    write_buffer_idx = 0
    new_level_buffer_idx = new_level_buffer_idx.write(write_buffer_idx,
                                                      level_idx)
    # The new node's value is drawn from the value generator, whose state is
    # advanced and carried into the returned `TreeState`.
    new_value, value_generator_state = self.value_generator.next(
        value_generator_state)
    new_level_buffer = tf.nest.map_structure(
        lambda x, y: x.write(write_buffer_idx, y), new_level_buffer, new_value)
    write_buffer_idx += 1
    # Buffer index will now be different from level index for the old
    # `TreeState`, i.e., `level_buffer_idx[level_idx] != level_idx`. Rename the
    # parameter to buffer index for clarity.
    # Copy every remaining old entry (positions `level_idx` onward) after the
    # new node, preserving their order.
    buffer_idx = level_idx
    while tf.less(buffer_idx, len(level_buffer_idx)):
      new_level_buffer_idx = new_level_buffer_idx.write(
          write_buffer_idx, level_buffer_idx[buffer_idx])
      new_level_buffer = tf.nest.map_structure(
          lambda nb, b: nb.write(write_buffer_idx, b[buffer_idx]),
          new_level_buffer, level_buffer)
      buffer_idx += 1
      write_buffer_idx += 1
    new_level_buffer_idx = new_level_buffer_idx.stack()
    new_level_buffer = tf.nest.map_structure(lambda x: x.stack(),
                                             new_level_buffer)
    new_state = TreeState(
        level_buffer=new_level_buffer,
        level_buffer_idx=new_level_buffer_idx,
        value_generator_state=value_generator_state)
    return cumsum, new_state