in tensorflow_federated/python/learning/models/serialization.py [0:0]
def save_functional_model(functional_model: functional.FunctionalModel,
                          path: str):
  """Serializes a `FunctionalModel` as a `tf.SavedModel` to `path`.

  The saved `tf.Module` contains concrete `tf.function`s whose outputs are
  flattened with `tf.nest.flatten`. Each is paired with a non-trainable string
  `tf.Variable` holding the serialized `tff.Type` of the un-flattened result:

  *   `create_initial_weights` / `create_initial_weights_type_spec`
  *   `flat_forward_pass_training` / `forward_pass_training_type_spec`
  *   `flat_forward_pass_inference` / `forward_pass_inference_type_spec`
  *   `predict_on_batch_training` / `predict_on_batch_training_type_spec`
  *   `predict_on_batch_inference` / `predict_on_batch_inference_type_spec`

  The model's `input_spec` is stored as a serialized `tff.Type` in the
  `serialized_input_spec` variable.

  Args:
    functional_model: A `tff.learning.models.FunctionalModel`.
    path: A `str` directory path to serialize the model to.
  """
  m = tf.Module()

  def type_spec_variable(spec_structure):
    """Returns a non-trainable string `tf.Variable` with a serialized type.

    Every call site below runs outside any `tf.Graph().as_default()` block.

    Args:
      spec_structure: A value accepted by `computation_types.to_type`, e.g. a
        structure of `tf.TensorSpec`s.
    """
    type_proto = type_serialization.serialize_type(
        computation_types.to_type(spec_structure))
    return tf.Variable(
        type_proto.SerializeToString(deterministic=True), trainable=False)

  def output_tensor_specs(concrete_fn):
    """Returns the `tf.TensorSpec` structure of `concrete_fn`'s outputs."""
    return tf.nest.map_structure(tf.TensorSpec.from_tensor,
                                 concrete_fn.structured_outputs)

  # Serialize the initial_weights values as a tf.function that creates a
  # structure of tensors with the initial weights. This way we can add it to the
  # tf.SavedModel and call it to create initial weights after deserialization.
  def create_initial_weights():
    return functional_model.initial_weights

  with tf.Graph().as_default():
    concrete_structured_fn = tf.function(
        create_initial_weights).get_concrete_function()
  # This structure is also used as the `model_weights` signature when tracing
  # the forward pass and prediction functions below.
  model_weights_tensor_specs = output_tensor_specs(concrete_structured_fn)
  m.create_initial_weights_type_spec = type_spec_variable(
      model_weights_tensor_specs)

  def flat_initial_weights():
    return tf.nest.flatten(create_initial_weights())

  with tf.Graph().as_default():
    m.create_initial_weights = tf.function(
        flat_initial_weights).get_concrete_function()

  # Serialize forward pass concretely, once for training and once for
  # non-training.
  # TODO(b/198150431): try making `training` a `tf.Tensor` parameter to remove
  # the need for serializing two different function graphs.
  def make_concrete_flat_forward_pass(training: bool):
    """Creates a concrete forward_pass function that has flattened output.

    Args:
      training: A boolean indicating whether this is a call in a training loop,
        or evaluation loop.

    Returns:
      A 2-tuple of the concrete `tf.function` with flattened output and the
      `tf.TensorSpec` structure of the un-flattened `forward_pass` result,
      which is saved for deserialization later.
    """
    # Note: `training` is a Python boolean, which gets "curried", in a sense,
    # during function concretization. The resulting concrete functions only
    # have parameters for `model_weights` and `batch_input`, which are
    # `tf.TensorSpec` structures here.
    with tf.Graph().as_default():
      concrete_structured_fn = functional_model.forward_pass.get_concrete_function(
          model_weights_tensor_specs,
          functional_model.input_spec,
          training=training)
    result_specs = output_tensor_specs(concrete_structured_fn)

    @tf.function
    def flat_forward_pass(model_weights, batch_input, training):
      return tf.nest.flatten(
          functional_model.forward_pass(model_weights, batch_input, training))

    with tf.Graph().as_default():
      flat_concrete_fn = flat_forward_pass.get_concrete_function(
          model_weights_tensor_specs,
          functional_model.input_spec,
          training=training)
    return flat_concrete_fn, result_specs

  fw_pass_training, fw_pass_training_specs = make_concrete_flat_forward_pass(
      training=True)
  m.flat_forward_pass_training = fw_pass_training
  m.forward_pass_training_type_spec = type_spec_variable(fw_pass_training_specs)

  fw_pass_inference, fw_pass_inference_specs = make_concrete_flat_forward_pass(
      training=False)
  m.flat_forward_pass_inference = fw_pass_inference
  m.forward_pass_inference_type_spec = type_spec_variable(
      fw_pass_inference_specs)

  # Serialize predict_on_batch, once for training, once for non-training.
  # NOTE(review): assumes `input_spec` is indexable and its first element is
  # the spec of the `x` argument passed to `predict_on_batch`.
  x_type = functional_model.input_spec[0]

  # TODO(b/198150431): try making `training` a `tf.Tensor` parameter to remove
  # the need for serializing two different function graphs.
  def make_concrete_flat_predict_on_batch(training: bool):
    """Creates a concrete predict_on_batch function that has flattened output.

    Args:
      training: A boolean indicating whether this is a call in a training loop,
        or evaluation loop.

    Returns:
      A 2-tuple of the concrete `tf.function` with flattened output and the
      `tf.TensorSpec` structure of the un-flattened `predict_on_batch` result,
      which is saved for deserialization later.
    """
    # Note: `training` is a Python boolean, which gets "curried", in a sense,
    # during function concretization. The resulting concrete functions only
    # have parameters for `model_weights` and `x`, which are `tf.TensorSpec`
    # structures here.
    with tf.Graph().as_default():
      concrete_structured_fn = tf.function(
          functional_model.predict_on_batch).get_concrete_function(
              model_weights_tensor_specs, x_type, training=training)
    result_specs = output_tensor_specs(concrete_structured_fn)

    @tf.function
    def flat_predict_on_batch(model_weights, x, training):
      return tf.nest.flatten(
          functional_model.predict_on_batch(model_weights, x, training))

    # `flat_predict_on_batch` is already a `tf.function`; it does not need to be
    # wrapped a second time before concretization.
    with tf.Graph().as_default():
      flat_concrete_fn = flat_predict_on_batch.get_concrete_function(
          model_weights_tensor_specs, x_type, training=training)
    return flat_concrete_fn, result_specs

  predict_training, predict_training_specs = make_concrete_flat_predict_on_batch(
      training=True)
  m.predict_on_batch_training = predict_training
  m.predict_on_batch_training_type_spec = type_spec_variable(
      predict_training_specs)

  predict_inference, predict_inference_specs = make_concrete_flat_predict_on_batch(
      training=False)
  m.predict_on_batch_inference = predict_inference
  m.predict_on_batch_inference_type_spec = type_spec_variable(
      predict_inference_specs)

  # Serialize TFF values as string variables that contain the serialized
  # protos from the computation or the type.
  m.serialized_input_spec = type_spec_variable(functional_model.input_spec)

  # Save everything.
  _save_tensorflow_module(m, path)