# tensorflow_federated/python/learning/reconstruction/training_process.py
def client_update(dataset, initial_model_weights):
  """Performs client local model optimization.

  Args:
    dataset: A `tf.data.Dataset` that provides training examples.
    initial_model_weights: A `tff.learning.ModelWeights` containing the
      starting global trainable and non-trainable weights.

  Returns:
    A `ClientOutput`.
  """
  with tf.init_scope():
    model = model_fn()

    metrics = []
    if metrics_fn is not None:
      metrics.extend(metrics_fn())
    # To be used to calculate example-weighted mean across batches and
    # clients.
    metrics.append(keras_utils.MeanLossMetric(loss_fn()))
    # To be used to calculate batch loss for model updates.
    client_loss = loss_fn()
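
  # Reconstruction models partition their variables into global weights
  # (aggregated across clients by the server) and local weights (reconstructed
  # on each client and never sent to the server). Only the global weights are
  # seeded from `initial_model_weights` here.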
  global_model_weights = reconstruction_utils.get_global_variables(model)
  local_model_weights = reconstruction_utils.get_local_variables(model)
  tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                        initial_model_weights)

  client_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
      client_optimizer_fn,
      global_model_weights.trainable,
      disjoint_init_and_next=False)
  reconstruction_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
      reconstruction_optimizer_fn,
      local_model_weights.trainable,
      disjoint_init_and_next=False)
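
  # `build_or_verify_tff_optimizer` accepts either a
  # `tff.learning.optimizers.Optimizer` or a no-arg callable returning a Keras
  # optimizer (wrapped in `keras_optimizer.KerasOptimizer`), so both optimizers
  # can be driven through the same `initialize`/`next` interface below.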

  @tf.function
  def reconstruction_reduce_fn(state, batch):
    """Runs reconstruction training on local client batch."""
    num_examples_sum, optimizer_state = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
      batch_loss = client_loss(
          y_true=output.labels, y_pred=output.predictions)

    gradients = tape.gradient(batch_loss, local_model_weights.trainable)
    optimizer_state, updated_weights = reconstruction_optimizer.next(
        optimizer_state, _flat_tuple(local_model_weights.trainable),
        _flat_tuple(gradients))
    updated_weights = tf.nest.pack_sequence_as(local_model_weights.trainable,
                                               updated_weights)
    if not isinstance(reconstruction_optimizer,
                      keras_optimizer.KerasOptimizer):
      # A Keras optimizer mutates the model variables within the `next` step,
      # so we only need to assign the updated weights back for TFF optimizers.
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            local_model_weights.trainable, updated_weights)
    return num_examples_sum + output.num_examples, optimizer_state
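
  # `train_reduce_fn` mirrors `reconstruction_reduce_fn`, but differentiates
  # with respect to the global trainable weights and additionally updates the
  # client metrics.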
  @tf.function
  def train_reduce_fn(state, batch):
    """Runs one step of client optimizer on local client batch."""
    num_examples_sum, optimizer_state = state
    with tf.GradientTape() as tape:
      output = model.forward_pass(batch, training=True)
      batch_loss = client_loss(
          y_true=output.labels, y_pred=output.predictions)

    gradients = tape.gradient(batch_loss, global_model_weights.trainable)
    optimizer_state, updated_weights = client_optimizer.next(
        optimizer_state, _flat_tuple(global_model_weights.trainable),
        _flat_tuple(gradients))
    updated_weights = tf.nest.pack_sequence_as(global_model_weights.trainable,
                                               updated_weights)
    if not isinstance(client_optimizer, keras_optimizer.KerasOptimizer):
      # A Keras optimizer mutates the model variables within the `next` step,
      # so we only need to assign the updated weights back for TFF optimizers.
      tf.nest.map_structure(lambda a, b: a.assign(b),
                            global_model_weights.trainable, updated_weights)

    # Update each metric.
    for metric in metrics:
      metric.update_state(y_true=output.labels, y_pred=output.predictions)
    return num_examples_sum + output.num_examples, optimizer_state
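
  # Both reduce functions above flatten variables and gradients with
  # `_flat_tuple`, which elsewhere in this module is (roughly)
  # `tuple(tf.nest.flatten(struct))`; the flat tuples keep `Dataset.reduce`
  # from trying to stack the per-variable tensors into a single list.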

  recon_dataset, post_recon_dataset = dataset_split_fn(dataset)
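  # `dataset_split_fn` partitions the client dataset into examples used to
  # reconstruct the local variables (`recon_dataset`) and examples used to
  # train the global variables (`post_recon_dataset`); depending on how the
  # process was configured, the two splits may overlap or repeat for several
  # epochs.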

  # If needed, do reconstruction, training the local variables while keeping
  # the global ones frozen.
  if local_model_weights.trainable:
    # Ignore output number of examples used in reconstruction, since this
    # isn't included in `client_weight`.

    def initial_state_reconstruction_reduce():
      trainable_tensor_specs = tf.nest.map_structure(
          lambda v: tf.TensorSpec(v.shape, v.dtype),
          local_model_weights.trainable)
      # TODO(b/161529310): We flatten and convert the trainable specs to
      # tuple, as the data iteration pattern would try to stack the tensors
      # in a list.
      return tf.constant(0), reconstruction_optimizer.initialize(
          _flat_tuple(trainable_tensor_specs))

    recon_dataset.reduce(
        initial_state=initial_state_reconstruction_reduce(),
        reduce_func=reconstruction_reduce_fn)
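
    # The reduce result is intentionally discarded: the reconstruction pass
    # matters only for its side effect of updating the local trainable
    # variables.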

  # Train the global variables, keeping local variables frozen.
  def initial_state_train_reduce():
    trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype),
        global_model_weights.trainable)
    # TODO(b/161529310): We flatten and convert the trainable specs to
    # tuple, as the data iteration pattern would try to stack the tensors
    # in a list.
    return tf.constant(0), client_optimizer.initialize(
        _flat_tuple(trainable_tensor_specs))

  num_examples_sum, _ = post_recon_dataset.reduce(
      initial_state=initial_state_train_reduce(), reduce_func=train_reduce_fn)
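
  # `num_examples_sum` counts only post-reconstruction examples; it feeds the
  # NUM_EXAMPLES client weighting below, which is why the reconstruction
  # example count was discarded above.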
  weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                        global_model_weights.trainable,
                                        initial_model_weights.trainable)

  # Zero out the update if `weights_delta` contains any non-finite values.
  weights_delta, has_non_finite_weight = (
      tensor_utils.zero_all_if_any_non_finite(weights_delta))
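  # `zero_all_if_any_non_finite` returns the delta unchanged with a flag of 0
  # when every entry is finite, and an all-zeros delta with a flag of 1
  # otherwise; the flag drives the zero client weight below.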
  model_local_outputs = keras_utils.read_metric_variables(metrics)

  if has_non_finite_weight > 0:
    client_weight = tf.constant(0.0, dtype=tf.float32)
  elif client_weighting is client_weight_lib.ClientWeighting.NUM_EXAMPLES:
    client_weight = tf.cast(num_examples_sum, dtype=tf.float32)
  elif client_weighting is client_weight_lib.ClientWeighting.UNIFORM:
    client_weight = tf.constant(1.0, dtype=tf.float32)
  else:
    client_weight = client_weighting(model_local_outputs)

  return ClientOutput(weights_delta, client_weight, model_local_outputs)
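

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this module: the same reconstruct-then-
# train pattern written in plain TensorFlow, assuming a hypothetical toy
# model with one "global" kernel (aggregated by the server) and one "local"
# bias (reconstructed on-device and never uploaded). All names below
# (`toy_client_update`, `mse`, the optimizers) are invented for illustration.
def toy_client_update(dataset, initial_global_kernel):
  """Reconstructs a local bias, then trains the global kernel; returns delta."""
  global_kernel = tf.Variable(initial_global_kernel)  # Shared across clients.
  local_bias = tf.Variable(tf.zeros([1]))  # Per-client, never aggregated.
  mse = tf.keras.losses.MeanSquaredError()
  recon_opt = tf.keras.optimizers.SGD(0.1)   # Updates only the local bias.
  train_opt = tf.keras.optimizers.SGD(0.01)  # Updates only the global kernel.

  # Phase 1 (reconstruction): train the local variable, globals frozen.
  for x, y in dataset:
    with tf.GradientTape() as tape:
      loss = mse(y, tf.matmul(x, global_kernel) + local_bias)
    grads = tape.gradient(loss, [local_bias])
    recon_opt.apply_gradients(zip(grads, [local_bias]))

  # Phase 2 (training): train the global variable, locals frozen.
  for x, y in dataset:
    with tf.GradientTape() as tape:
      loss = mse(y, tf.matmul(x, global_kernel) + local_bias)
    grads = tape.gradient(loss, [global_kernel])
    train_opt.apply_gradients(zip(grads, [global_kernel]))

  # Only the global delta leaves the device; the reconstructed bias stays
  # local, which is the personalization/privacy point of this algorithm.
  return global_kernel - initial_global_kernel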