def client_update()

in tensorflow_federated/python/learning/reconstruction/training_process.py [0:0]


  def client_update(dataset, initial_model_weights):
    """Performs client local model optimization.

    Training runs in two phases over the two datasets returned by
    `dataset_split_fn(dataset)`:

    1. Reconstruction (skipped if the model has no trainable local
       variables): the first dataset is reduced with the reconstruction
       optimizer, which is applied only to the local trainable variables.
       The resulting optimizer state and example count are discarded.
    2. Training: the second dataset is reduced with the client optimizer,
       which is applied only to the global trainable variables. Metrics are
       updated only in this phase.

    Args:
      dataset: A `tf.data.Dataset` that provides training examples.
      initial_model_weights: A `tff.learning.ModelWeights` containing the
        starting global trainable and non-trainable weights. These are
        assigned into the model's global variables before either phase runs.

    Returns:
      A `ClientOutput` built (positionally) from:
        * `weights_delta`: the change of the global trainable weights relative
          to `initial_model_weights.trainable`, with every entry zeroed if any
          entry is non-finite;
        * `client_weight`: 0.0 if the delta was non-finite, otherwise
          determined by `client_weighting`;
        * `model_local_outputs`: the values read from the metric variables.
    """
    # Build the model, metrics and losses under `tf.init_scope`, which lifts
    # their (variable-creating) ops out of any function-building graph that
    # may be tracing this code.
    with tf.init_scope():
      model = model_fn()

      metrics = []
      if metrics_fn is not None:
        metrics.extend(metrics_fn())
      # To be used to calculate example-weighted mean across batches and
      # clients.
      metrics.append(keras_utils.MeanLossMetric(loss_fn()))
      # To be used to calculate batch loss for model updates.
      client_loss = loss_fn()

    global_model_weights = reconstruction_utils.get_global_variables(model)
    local_model_weights = reconstruction_utils.get_local_variables(model)
    # Load the starting global weights (trainable and non-trainable) into the
    # freshly built model; local variables keep whatever `model_fn` gave them.
    tf.nest.map_structure(lambda a, b: a.assign(b), global_model_weights,
                          initial_model_weights)
    # Two separate optimizers: the client optimizer only ever sees the global
    # trainable variables, the reconstruction optimizer only the local ones.
    client_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
        client_optimizer_fn,
        global_model_weights.trainable,
        disjoint_init_and_next=False)
    reconstruction_optimizer = keras_optimizer.build_or_verify_tff_optimizer(
        reconstruction_optimizer_fn,
        local_model_weights.trainable,
        disjoint_init_and_next=False)

    @tf.function
    def reconstruction_reduce_fn(state, batch):
      """Runs reconstruction training on local client batch."""
      # `state` is `(num_examples_sum, optimizer_state)`.
      num_examples_sum, optimizer_state = state
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch, training=True)
        batch_loss = client_loss(
            y_true=output.labels, y_pred=output.predictions)

      # Gradients are taken only w.r.t. the local trainable variables, so the
      # global variables are not touched by the optimizer in this phase.
      gradients = tape.gradient(batch_loss, local_model_weights.trainable)
      optimizer_state, updated_weights = reconstruction_optimizer.next(
          optimizer_state, _flat_tuple(local_model_weights.trainable),
          _flat_tuple(gradients))
      updated_weights = tf.nest.pack_sequence_as(local_model_weights.trainable,
                                                 updated_weights)
      if not isinstance(reconstruction_optimizer,
                        keras_optimizer.KerasOptimizer):
        # Keras optimizer mutates model variables within the `next` step.
        tf.nest.map_structure(lambda a, b: a.assign(b),
                              local_model_weights.trainable, updated_weights)

      # Metrics are intentionally not updated during reconstruction.
      return num_examples_sum + output.num_examples, optimizer_state

    @tf.function
    def train_reduce_fn(state, batch):
      """Runs one step of client optimizer on local client batch."""
      # `state` is `(num_examples_sum, optimizer_state)`.
      num_examples_sum, optimizer_state = state
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch, training=True)
        batch_loss = client_loss(
            y_true=output.labels, y_pred=output.predictions)

      # Gradients are taken only w.r.t. the global trainable variables, so the
      # local variables are not touched by the optimizer in this phase.
      gradients = tape.gradient(batch_loss, global_model_weights.trainable)
      optimizer_state, updated_weights = client_optimizer.next(
          optimizer_state, _flat_tuple(global_model_weights.trainable),
          _flat_tuple(gradients))
      updated_weights = tf.nest.pack_sequence_as(global_model_weights.trainable,
                                                 updated_weights)
      if not isinstance(client_optimizer, keras_optimizer.KerasOptimizer):
        # Keras optimizer mutates model variables within the `next` step.
        tf.nest.map_structure(lambda a, b: a.assign(b),
                              global_model_weights.trainable, updated_weights)

      # Update each metric.
      for metric in metrics:
        metric.update_state(y_true=output.labels, y_pred=output.predictions)

      return num_examples_sum + output.num_examples, optimizer_state

    # First split is used for reconstruction, second for global training.
    recon_dataset, post_recon_dataset = dataset_split_fn(dataset)

    # If needed, do reconstruction, training the local variables while keeping
    # the global ones frozen.
    if local_model_weights.trainable:
      # Ignore output number of examples used in reconstruction, since this
      # isn't included in `client_weight`.
      def initial_state_reconstruction_reduce():
        # The optimizer is initialized from TensorSpecs describing the
        # variables, not from the variables themselves.
        trainable_tensor_specs = tf.nest.map_structure(
            lambda v: tf.TensorSpec(v.shape, v.dtype),
            local_model_weights.trainable)
        # TODO(b/161529310): We flatten and convert the trainable specs to
        # tuple, as the data iteration pattern would try to stack the tensors
        # in a list.
        return tf.constant(0), reconstruction_optimizer.initialize(
            _flat_tuple(trainable_tensor_specs))

      recon_dataset.reduce(
          initial_state=initial_state_reconstruction_reduce(),
          reduce_func=reconstruction_reduce_fn)

    # Train the global variables, keeping local variables frozen.
    def initial_state_train_reduce():
      # The optimizer is initialized from TensorSpecs describing the
      # variables, not from the variables themselves.
      trainable_tensor_specs = tf.nest.map_structure(
          lambda v: tf.TensorSpec(v.shape, v.dtype),
          global_model_weights.trainable)
      # TODO(b/161529310): We flatten and convert the trainable specs to
      # tuple, as the data iteration pattern would try to stack the tensors
      # in a list.
      return tf.constant(0), client_optimizer.initialize(
          _flat_tuple(trainable_tensor_specs))

    # Only the post-reconstruction example count is kept; it feeds the
    # NUM_EXAMPLES client weighting below.
    num_examples_sum, _ = post_recon_dataset.reduce(
        initial_state=initial_state_train_reduce(), reduce_func=train_reduce_fn)

    # The delta covers only the global trainable weights; local variables are
    # not included in the update.
    weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                          global_model_weights.trainable,
                                          initial_model_weights.trainable)

    # We ignore the update if the weights_delta is non finite.
    weights_delta, has_non_finite_weight = (
        tensor_utils.zero_all_if_any_non_finite(weights_delta))

    model_local_outputs = keras_utils.read_metric_variables(metrics)

    # A non-finite update (already zeroed above) also gets zero weight.
    # Otherwise, enum members select a built-in weighting and anything else is
    # treated as a callable over the metric outputs.
    if has_non_finite_weight > 0:
      client_weight = tf.constant(0.0, dtype=tf.float32)
    elif client_weighting is client_weight_lib.ClientWeighting.NUM_EXAMPLES:
      client_weight = tf.cast(num_examples_sum, dtype=tf.float32)
    elif client_weighting is client_weight_lib.ClientWeighting.UNIFORM:
      client_weight = tf.constant(1.0, dtype=tf.float32)
    else:
      client_weight = client_weighting(model_local_outputs)

    return ClientOutput(weights_delta, client_weight, model_local_outputs)