compute_gradients() — method excerpt

From: tensorflow_privacy/privacy/optimizers/dp_optimizer.py


    def compute_gradients(self,
                          loss,
                          var_list,
                          gate_gradients=GATE_OP,
                          aggregation_method=None,
                          colocate_gradients_with_ops=False,
                          grad_loss=None,
                          gradient_tape=None):
      """DP-SGD version of base class method.

      Splits the vector loss into microbatches, computes one gradient per
      microbatch, accumulates them through `self._dp_sum_query` (which is
      responsible for clipping), obtains a noised sum via
      `get_noised_result`, and returns the sum normalized by the number of
      microbatches.

      Args:
        loss: Either a callable returning a vector (per-example) loss
          tensor — used on the Eager path — or the vector loss tensor
          itself (graph path).
        var_list: Variables to differentiate with respect to. On the graph
          path it may be None, in which case trainable variables are
          collected from the default graph.
        gate_gradients: Forwarded to the base optimizer (graph path only).
        aggregation_method: Forwarded to the base optimizer (graph path
          only).
        colocate_gradients_with_ops: Forwarded to the base optimizer
          (graph path only).
        grad_loss: Forwarded to the base optimizer (graph path only).
        gradient_tape: A tf.GradientTape that recorded the loss; required
          when `loss` is callable, optional otherwise.

      Returns:
        A list of (gradient, variable) pairs, matching the base class
        contract. Gradients may be None where the noised sum for a
        variable is None (see `normalize` on the graph path).

      Raises:
        ValueError: If `loss` is callable but `gradient_tape` is not
          provided.
      """
      # Flag consulted elsewhere (presumably by apply_gradients) to verify
      # the DP path was used — TODO confirm against the rest of the class.
      self._was_compute_gradients_called = True
      # Lazily initialize the DP query's global state on first call.
      if self._global_state is None:
        self._global_state = self._dp_sum_query.initial_global_state()

      if callable(loss):
        # TF is running in Eager mode, check we received a vanilla tape.
        if not gradient_tape:
          raise ValueError('When in Eager mode, a tape needs to be passed.')

        vector_loss = loss()
        # Default the microbatch count to the batch size; note this caches
        # the value on self, so later calls reuse the first batch's size.
        if self._num_microbatches is None:
          self._num_microbatches = tf.shape(input=vector_loss)[0]
        sample_state = self._dp_sum_query.initial_sample_state(var_list)
        # One row per microbatch; examples within a row are averaged below.
        microbatches_losses = tf.reshape(vector_loss,
                                         [self._num_microbatches, -1])
        sample_params = (
            self._dp_sum_query.derive_sample_params(self._global_state))

        def process_microbatch(i, sample_state):
          """Process one microbatch (record) with privacy helper."""
          microbatch_loss = tf.reduce_mean(
              input_tensor=tf.gather(microbatches_losses, [i]))
          # stop_recording: the per-microbatch gradient computation itself
          # must not be traced onto the tape.
          with gradient_tape.stop_recording():
            grads = gradient_tape.gradient(microbatch_loss, var_list)
          # accumulate_record clips and adds this record to the running sum.
          sample_state = self._dp_sum_query.accumulate_record(
              sample_params, sample_state, grads)
          return sample_state

        for idx in range(self._num_microbatches):
          sample_state = process_microbatch(idx, sample_state)

        # Noised sum of clipped per-microbatch gradients; also advances the
        # query's global state.
        grad_sums, self._global_state, _ = (
            self._dp_sum_query.get_noised_result(sample_state,
                                                 self._global_state))

        def normalize(v):
          # Convert the noised sum into a noised mean over microbatches.
          return v / tf.cast(self._num_microbatches, tf.float32)

        final_grads = tf.nest.map_structure(normalize, grad_sums)

        grads_and_vars = list(zip(final_grads, var_list))
        return grads_and_vars

      else:
        # Graph-mode path: `loss` is a vector loss tensor.
        # Note: it would be closer to the correct i.i.d. sampling of records if
        # we sampled each microbatch from the appropriate binomial distribution,
        # although that still wouldn't be quite correct because it would be
        # sampling from the dataset without replacement.
        if self._num_microbatches is None:
          self._num_microbatches = tf.shape(input=loss)[0]

        microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
        sample_params = (
            self._dp_sum_query.derive_sample_params(self._global_state))

        def process_microbatch(i, sample_state):
          """Process one microbatch (record) with privacy helper."""
          self_super = super(DPOptimizerClass, self)

          mean_loss = tf.reduce_mean(
              input_tensor=tf.gather(microbatches_losses, [i]))

          # Dispatch to whichever gradient entry point the wrapped
          # optimizer class provides.
          if hasattr(self_super, 'compute_gradients'):
            # This case covers optimizers in tf.train.
            compute_gradients_fn = self_super.compute_gradients
          else:
            # This case covers Keras optimizers from optimizers_v2.
            compute_gradients_fn = self_super._compute_gradients  # pylint: disable=protected-access

          if gradient_tape:
            # This is intended to work for TF2 and may not work for TF1.
            with gradient_tape.stop_recording():
              grads_list = list(gradient_tape.gradient(mean_loss, var_list))
          else:
            # Base optimizer returns (grad, var) pairs; keep only grads —
            # variable pairing is re-done at the end of this method.
            grads, _ = zip(*compute_gradients_fn(
                mean_loss, var_list, gate_gradients, aggregation_method,
                colocate_gradients_with_ops, grad_loss))
            grads_list = list(grads)

          sample_state = self._dp_sum_query.accumulate_record(
              sample_params, sample_state, grads_list)
          return sample_state

        if var_list is None:
          var_list = (
              tf.trainable_variables() +
              tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))

        sample_state = self._dp_sum_query.initial_sample_state(var_list)

        if self._unroll_microbatches:
          # Unrolled Python loop: builds one graph copy per microbatch.
          for idx in range(self._num_microbatches):
            sample_state = process_microbatch(idx, sample_state)
        else:
          # Use of while_loop here requires that sample_state be a nested
          # structure of tensors. In general, we would prefer to allow it to be
          # an arbitrary opaque type.
          cond_fn = lambda i, _: tf.less(i, self._num_microbatches)
          body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)]  # pylint: disable=line-too-long
          idx = tf.constant(0)
          _, sample_state = tf.while_loop(
              cond=cond_fn,
              body=body_fn,
              loop_vars=[idx, sample_state],
              parallel_iterations=self._while_loop_parallel_iterations)

        grad_sums, self._global_state, _ = (
            self._dp_sum_query.get_noised_result(sample_state,
                                                 self._global_state))

        def normalize(v):
          # Convert sums to means; a TypeError here indicates a non-tensor
          # entry (e.g. a None gradient sum), which is propagated as None.
          try:
            return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
          except TypeError:
            return None

        final_grads = tf.nest.map_structure(normalize, grad_sums)

        return list(zip(final_grads, var_list))