in tensorflow_privacy/privacy/optimizers/dp_optimizer.py [0:0]
def compute_gradients(self,
                      loss,
                      var_list,
                      gate_gradients=GATE_OP,
                      aggregation_method=None,
                      colocate_gradients_with_ops=False,
                      grad_loss=None,
                      gradient_tape=None):
  """DP-SGD version of base class method."""
  self._was_compute_gradients_called = True
  if self._global_state is None:
    self._global_state = self._dp_sum_query.initial_global_state()
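
  # Two paths follow: if `loss` is a callable we assume eager execution and
  # rely on the supplied GradientTape; otherwise `loss` is a tensor of
  # per-example losses and gradients come either from the tape (if given) or
  # from the parent optimizer's compute_gradients.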
  if callable(loss):
    # TF is running in Eager mode; check that we received a vanilla tape.
    if not gradient_tape:
      raise ValueError('When in Eager mode, a tape needs to be passed.')

    vector_loss = loss()
    if self._num_microbatches is None:
      self._num_microbatches = tf.shape(input=vector_loss)[0]
    sample_state = self._dp_sum_query.initial_sample_state(var_list)
    microbatches_losses = tf.reshape(vector_loss,
                                     [self._num_microbatches, -1])
    sample_params = (
        self._dp_sum_query.derive_sample_params(self._global_state))
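
    # `sample_state` accumulates the per-microbatch gradient records, while
    # `sample_params` carries per-record parameters derived from the query's
    # global state (for the Gaussian sum query, the clipping norm).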
    def process_microbatch(i, sample_state):
      """Process one microbatch (record) with privacy helper."""
      microbatch_loss = tf.reduce_mean(
          input_tensor=tf.gather(microbatches_losses, [i]))
      with gradient_tape.stop_recording():
        grads = gradient_tape.gradient(microbatch_loss, var_list)
      sample_state = self._dp_sum_query.accumulate_record(
          sample_params, sample_state, grads)
      return sample_state
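
    # Accumulate each microbatch's gradient into the running sample state,
    # one microbatch at a time (the query applies its per-record
    # preprocessing, e.g. clipping).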
    for idx in range(self._num_microbatches):
      sample_state = process_microbatch(idx, sample_state)

    grad_sums, self._global_state, _ = (
        self._dp_sum_query.get_noised_result(sample_state,
                                             self._global_state))
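
    # `grad_sums` is the noised sum over microbatches; dividing by the number
    # of microbatches yields the averaged DP gradient.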
    def normalize(v):
      return v / tf.cast(self._num_microbatches, tf.float32)

    final_grads = tf.nest.map_structure(normalize, grad_sums)

    grads_and_vars = list(zip(final_grads, var_list))
    return grads_and_vars

  else:
    # Note: it would be closer to the correct i.i.d. sampling of records if
    # we sampled each microbatch from the appropriate binomial distribution,
    # although that still wouldn't be quite correct because it would be
    # sampling from the dataset without replacement.
    if self._num_microbatches is None:
      self._num_microbatches = tf.shape(input=loss)[0]

    microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
    sample_params = (
        self._dp_sum_query.derive_sample_params(self._global_state))

    def process_microbatch(i, sample_state):
      """Process one microbatch (record) with privacy helper."""
      self_super = super(DPOptimizerClass, self)
      mean_loss = tf.reduce_mean(
          input_tensor=tf.gather(microbatches_losses, [i]))

      if hasattr(self_super, 'compute_gradients'):
        # This case covers optimizers in tf.train.
        compute_gradients_fn = self_super.compute_gradients
      else:
        # This case covers Keras optimizers from optimizers_v2.
        compute_gradients_fn = self_super._compute_gradients  # pylint: disable=protected-access

      if gradient_tape:
        # This is intended to work for TF2 and may not work for TF1.
        with gradient_tape.stop_recording():
          grads_list = list(gradient_tape.gradient(mean_loss, var_list))
      else:
        grads, _ = zip(*compute_gradients_fn(
            mean_loss, var_list, gate_gradients, aggregation_method,
            colocate_gradients_with_ops, grad_loss))
        grads_list = list(grads)

      sample_state = self._dp_sum_query.accumulate_record(
          sample_params, sample_state, grads_list)
      return sample_state
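
    # If no variables were supplied, default to all trainable variables,
    # including TF1-style trainable resource variables.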
    if var_list is None:
      var_list = (
          tf.trainable_variables() +
          tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))

    sample_state = self._dp_sum_query.initial_sample_state(var_list)
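
    # Run the microbatch loop either unrolled in Python or as a tf.while_loop,
    # depending on the `unroll_microbatches` option.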
    if self._unroll_microbatches:
      for idx in range(self._num_microbatches):
        sample_state = process_microbatch(idx, sample_state)
    else:
      # Use of while_loop here requires that sample_state be a nested
      # structure of tensors. In general, we would prefer to allow it to be
      # an arbitrary opaque type.
      cond_fn = lambda i, _: tf.less(i, self._num_microbatches)
      body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)]  # pylint: disable=line-too-long
      idx = tf.constant(0)
      _, sample_state = tf.while_loop(
          cond=cond_fn,
          body=body_fn,
          loop_vars=[idx, sample_state],
          parallel_iterations=self._while_loop_parallel_iterations)

    grad_sums, self._global_state, _ = (
        self._dp_sum_query.get_noised_result(sample_state,
                                             self._global_state))
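
    # Average the noised sum over microbatches; entries for which the division
    # raises TypeError are mapped to None.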
    def normalize(v):
      try:
        return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
      except TypeError:
        return None

    final_grads = tf.nest.map_structure(normalize, grad_sums)

    return list(zip(final_grads, var_list))
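

A minimal usage sketch (not part of this file) for the eager-mode path above:
`DPGradientDescentGaussianOptimizer` is built by this module, while `model`,
`features`, and `labels` are hypothetical placeholders for an already-built
Keras model and one batch of data. The loss is passed as a callable returning
one loss value per example, and a persistent GradientTape is required because
the tape is queried once per microbatch.

import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer

opt = dp_optimizer.DPGradientDescentGaussianOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=32,  # must evenly divide the batch size
    learning_rate=0.15)

with tf.GradientTape(persistent=True) as tape:

  def loss_fn():
    logits = model(features, training=True)
    # No reduction: keep one loss per example so compute_gradients can
    # reshape the vector loss into [num_microbatches, -1].
    return tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True)

  grads_and_vars = opt.compute_gradients(
      loss_fn, model.trainable_variables, gradient_tape=tape)

opt.apply_gradients(grads_and_vars)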