def save_functional_model()

in tensorflow_federated/python/learning/models/serialization.py
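
For reference, the function below relies on imports from the enclosing module.
A likely reconstruction (the exact import block in `serialization.py` may
differ; `_save_tensorflow_module` is a private helper defined elsewhere in the
same file):

import tensorflow as tf

from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import type_serialization
from tensorflow_federated.python.learning.models import functional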


def save_functional_model(functional_model: functional.FunctionalModel,
                          path: str):
  """Serializes a `FunctionalModel` as a `tf.SavedModel` to `path`.

  Args:
    functional_model: A `tff.learning.models.FunctionalModel`.
    path: A `str` directory path to serialize the model to.
  """
  m = tf.Module()
  # Serialize the initial_weights values as a tf.function that creates a
  # structure of tensors with the initial weights. This way we can add it to the
  # tf.SavedModel and call it to create initial weights after deserialization.
  create_initial_weights = lambda: functional_model.initial_weights
  with tf.Graph().as_default():
    concrete_structured_fn = tf.function(
        create_initial_weights).get_concrete_function()
  model_weights_tensor_specs = tf.nest.map_structure(
      tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs)
  initial_weights_result_type_spec = type_serialization.serialize_type(
      computation_types.to_type(model_weights_tensor_specs))
  m.create_initial_weights_type_spec = tf.Variable(
      initial_weights_result_type_spec.SerializeToString(deterministic=True))

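  # The weights are flattened before serialization; arbitrary Python container
  # structures do not reliably survive a SavedModel round-trip, so the type
  # spec saved above is used to re-nest the flat tensors at load time.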
  def flat_initial_weights():
    return tf.nest.flatten(create_initial_weights())

  with tf.Graph().as_default():
    m.create_initial_weights = tf.function(
        flat_initial_weights).get_concrete_function()

  # Serialize forward pass concretely, once for training and once for
  # non-training.
  # TODO(b/198150431): try making `training` a `tf.Tensor` parameter to remove
  # the need for serializing two different function graphs.
  def make_concrete_flat_forward_pass(training: bool):
    """Create a concrete forward_pass function that has flattened output.

    Args:
      training: A boolean indicating whether the call occurs in a training
        loop or an evaluation loop.

    Returns:
      A 2-tuple of a concrete `tf.function` instance and a `tff.Type` protocol
      buffer message documenting the result structure returned by the concrete
      function.
    """
    # Save the un-flattened type spec for deserialization later.
    # Note: `training` is a Python boolean, which gets "curried", in a sense,
    # during function concretization. The resulting concrete function only has
    # parameters for `model_weights` and `batch_input`, which are
    # `tf.TensorSpec` structures here.
    with tf.Graph().as_default():
      concrete_structured_fn = functional_model.forward_pass.get_concrete_function(
          model_weights_tensor_specs,
          functional_model.input_spec,
          # Note: training does not appear in the resulting concrete function.
          training=training)
    output_tensor_spec_structure = tf.nest.map_structure(
        tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs)
    result_type_spec = type_serialization.serialize_type(
        computation_types.to_type(output_tensor_spec_structure))

    @tf.function
    def flat_forward_pass(model_weights, batch_input, training):
      return tf.nest.flatten(
          functional_model.forward_pass(model_weights, batch_input, training))

    with tf.Graph().as_default():
      flat_concrete_fn = flat_forward_pass.get_concrete_function(
          model_weights_tensor_specs,
          functional_model.input_spec,
          # Note: training does not appear in the resulting concrete function.
          training=training)
    return flat_concrete_fn, result_type_spec

  fw_pass_training, fw_pass_training_type_spec = make_concrete_flat_forward_pass(
      training=True)
  m.flat_forward_pass_training = fw_pass_training
  m.forward_pass_training_type_spec = tf.Variable(
      fw_pass_training_type_spec.SerializeToString(deterministic=True),
      trainable=False)

  fw_pass_inference, fw_pass_inference_type_spec = make_concrete_flat_forward_pass(
      training=False)
  m.flat_forward_pass_inference = fw_pass_inference
  m.forward_pass_inference_type_spec = tf.Variable(
      fw_pass_inference_type_spec.SerializeToString(deterministic=True),
      trainable=False)

  # Serialize predict_on_batch, once for training, once for non-training.
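  # `input_spec` is an `(x, y)` structure; `predict_on_batch` consumes only
  # the model input `x`, hence the first element.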
  x_type = functional_model.input_spec[0]

  # TODO(b/198150431): try making `training` a `tf.Tensor` parameter to remove
  # the need for serializing two different function graphs.
  def make_concrete_flat_predict_on_batch(training: bool):
    """Create a concrete predict_on_batch function that has flattened output.

    Args:
      training: A boolean indicating whether the call occurs in a training
        loop or an evaluation loop.

    Returns:
      A 2-tuple of a concrete `tf.function` instance and a `tff.Type` protocol
      buffer message documenting the result structure returned by the concrete
      function.
    """
    # Save the un-flattened type spec for deserialization later.
    # Note: `training` is a Python boolean, which gets "curried", in a sense,
    # during function concretization. The resulting concrete function only has
    # parameters for `model_weights` and `batch_input`, which are
    # `tf.TensorSpec` structures here.
    concrete_structured_fn = tf.function(
        functional_model.predict_on_batch).get_concrete_function(
            model_weights_tensor_specs,
            x_type,
            # Note: training does not appear in the resulting concrete function.
            training=training)
    output_tensor_spec_structure = tf.nest.map_structure(
        tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs)
    result_type_spec = type_serialization.serialize_type(
        computation_types.to_type(output_tensor_spec_structure))

    @tf.function
    def flat_predict_on_batch(model_weights, x, training):
      return tf.nest.flatten(
          functional_model.predict_on_batch(model_weights, x, training))

    # `flat_predict_on_batch` is already a `tf.function`, so it can be
    # concretized directly.
    flat_concrete_fn = flat_predict_on_batch.get_concrete_function(
        model_weights_tensor_specs,
        x_type,
        # Note: training does not appear in the resulting concrete function.
        training=training)
    return flat_concrete_fn, result_type_spec

  with tf.Graph().as_default():
    predict_training, predict_training_type_spec = make_concrete_flat_predict_on_batch(
        training=True)
  m.predict_on_batch_training = predict_training
  m.predict_on_batch_training_type_spec = tf.Variable(
      predict_training_type_spec.SerializeToString(deterministic=True),
      trainable=False)

  with tf.Graph().as_default():
    predict_inference, predict_inference_type_spec = make_concrete_flat_predict_on_batch(
        training=False)
  m.predict_on_batch_inference = predict_inference
  m.predict_on_batch_inference_type_spec = tf.Variable(
      predict_inference_type_spec.SerializeToString(deterministic=True),
      trainable=False)
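
  # At load time, each `*_type_spec` variable above can be deserialized back
  # into a `tff.Type` and used to pack the flat outputs of its concrete
  # function back into the original nested structure.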

  # Serialize the input spec as a string variable containing the serialized
  # `tff.Type` proto so it can be recovered at load time.
  m.serialized_input_spec = tf.Variable(
      type_serialization.serialize_type(
          computation_types.to_type(
              functional_model.input_spec)).SerializeToString(
                  deterministic=True),
      trainable=False)

  # Save the assembled module as a `tf.SavedModel` at `path`.
  _save_tensorflow_module(m, path)
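
A minimal usage sketch. The helper names below
(`tff.learning.models.functional_model_from_keras` and
`tff.learning.models.load_functional_model`) are assumptions about the public
TFF API rather than part of the excerpt above:

import tensorflow as tf
import tensorflow_federated as tff

# Build a FunctionalModel from a simple Keras model (assumed helper).
keras_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, input_shape=(2,))])
functional_model = tff.learning.models.functional_model_from_keras(
    keras_model,
    loss_fn=tf.keras.losses.MeanSquaredError(),
    input_spec=(tf.TensorSpec([None, 2], tf.float32),
                tf.TensorSpec([None, 1], tf.float32)))

# Round-trip the model through a SavedModel directory.
tff.learning.models.save_functional_model(
    functional_model, '/tmp/functional_model')
loaded_model = tff.learning.models.load_functional_model(
    '/tmp/functional_model')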