in tensorflow_federated/python/learning/models/serialization.py [0:0]
def save(model: model_lib.Model, path: str, input_type=None) -> None:
  """Serializes `model` as a TensorFlow SavedModel to `path`.

  The resulting SavedModel will contain the default serving signature, which
  can be used with the TFLite converter to create a TFLite flatbuffer for
  inference.

  NOTE: The model returned by `tff.learning.models.load` will _not_ be the same
  Python type as the saved model. If the model serialized using this method is
  a subclass of `tff.learning.Model`, that subclass is _not_ returned. All
  method behavior is retained, but the Python type does not cross serialization
  boundaries. The return type of `metric_finalizers` will be an OrderedDict of
  str to `tff.tf_computation` (annotated TFF computations) which could be
  different from that of the model before serialization.

  Args:
    model: The `tff.learning.Model` to save.
    path: The `str` directory path to serialize the model to.
    input_type: An optional structure of `tf.TensorSpec`s representing the
      expected input of `model.predict_on_batch`, to override reading from
      `model.input_spec`. Typically this will be similar to `model.input_spec`,
      with any example labels removed. If None, default to
      `model.input_spec['x']` if the input_spec is a mapping, otherwise default
      to `model.input_spec[0]`.

  Raises:
    ValueError: If `path` is an empty string.

  Any error raised by `py_typecheck.check_type` for a `model` that is not a
  `model_lib.Model`, or a `path` that is not a `str`, propagates to the caller.
  """
  py_typecheck.check_type(model, model_lib.Model)
  py_typecheck.check_type(path, str)
  if not path:
    raise ValueError('`path` must be a non-empty string, cannot serialize '
                     'models without an output path.')
  if isinstance(model, _LoadedSavedModel):
    # If we're saving a previously loaded model, we can simply use the module
    # already internal to the Model.
    _save_tensorflow_module(model._loaded_module, path)  # pylint: disable=protected-access
    return

  def serialized_proto_variable(proto):
    """Stores the deterministic serialization of `proto` in a `tf.Variable`."""
    return tf.Variable(
        proto.SerializeToString(deterministic=True), trainable=False)

  m = tf.Module()
  # We prefixed with `tff_` because `trainable_variables` is an attribute
  # reserved by `tf.Module`.
  m.tff_trainable_variables = model.trainable_variables
  m.tff_non_trainable_variables = model.non_trainable_variables
  m.tff_local_variables = model.local_variables

  # Serialize forward_pass. We must get two concrete versions of the
  # function, as the `training` argument is a Python value that changes the
  # graph computation. We serialize the output type so that we can repack the
  # flattened values after loading the saved model.
  forward_pass_training = _make_concrete_flat_output_fn(
      functools.partial(model.forward_pass, training=True), model.input_spec)
  m.flat_forward_pass_training = forward_pass_training[0]
  m.forward_pass_training_type_spec = serialized_proto_variable(
      forward_pass_training[1])

  forward_pass_inference = _make_concrete_flat_output_fn(
      functools.partial(model.forward_pass, training=False), model.input_spec)
  m.flat_forward_pass_inference = forward_pass_inference[0]
  m.forward_pass_inference_type_spec = serialized_proto_variable(
      forward_pass_inference[1])

  # Get model prediction input type. If `None`, default to assuming the 'x' key
  # or first element of the model input spec is the input.
  if input_type is None:
    if isinstance(model.input_spec, collections.abc.Mapping):
      input_type = model.input_spec['x']
    else:
      input_type = model.input_spec[0]

  # Serialize predict_on_batch. We must get two concrete versions of the
  # function, as the `training` argument is a Python value that changes the
  # graph computation.
  predict_on_batch_training = _make_concrete_flat_output_fn(
      functools.partial(model.predict_on_batch, training=True), input_type)
  m.predict_on_batch_training = predict_on_batch_training[0]
  m.predict_on_batch_training_type_spec = serialized_proto_variable(
      predict_on_batch_training[1])

  predict_on_batch_inference = _make_concrete_flat_output_fn(
      functools.partial(model.predict_on_batch, training=False), input_type)
  m.predict_on_batch_inference = predict_on_batch_inference[0]
  m.predict_on_batch_inference_type_spec = serialized_proto_variable(
      predict_on_batch_inference[1])

  # Serialize the report_local_outputs tf.function. Models that do not
  # implement it signal so with `NotImplementedError`; record `None` instead.
  try:
    m.report_local_outputs = model.report_local_outputs.get_concrete_function()
  except NotImplementedError:
    m.report_local_outputs = None

  # Serialize the report_local_unfinalized_metrics tf.function.
  m.report_local_unfinalized_metrics = (
      model.report_local_unfinalized_metrics.get_concrete_function())

  # Serialize the metric_finalizers as `tf.Variable`s. The dictionary is
  # attached to the module first and then populated through the module
  # attribute, mirroring how the entries are later looked up by name.
  m.serialized_metric_finalizers = collections.OrderedDict()
  metric_finalizers = model.metric_finalizers()
  if metric_finalizers:
    # The unfinalized metrics do not depend on the finalizer being processed,
    # so evaluate them once rather than once per metric.
    unfinalized_metrics = model.report_local_unfinalized_metrics()
    for metric_name, finalizer in metric_finalizers.items():
      metric_type = type_conversions.type_from_tensors(
          unfinalized_metrics[metric_name])
      finalizer_computation = computations.tf_computation(
          finalizer, metric_type)
      m.serialized_metric_finalizers[metric_name] = serialized_proto_variable(
          computation_serialization.serialize_computation(
              finalizer_computation))

  # Serialize the TFF values as string variables that contain the serialized
  # protos from the computation or the type.
  m.serialized_input_spec = serialized_proto_variable(
      type_serialization.serialize_type(
          computation_types.to_type(model.input_spec)))

  # `federated_output_computation` is optional on the model; when accessing it
  # raises `NotImplementedError`, record `None` instead.
  try:
    m.serialized_federated_output_computation = serialized_proto_variable(
        computation_serialization.serialize_computation(
            model.federated_output_computation))
  except NotImplementedError:
    m.serialized_federated_output_computation = None

  _save_tensorflow_module(m, path)