def save()

in tensorflow_federated/python/learning/models/serialization.py

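A minimal usage sketch (hedged: the Keras model, feature/label shapes, and the
'/tmp/my_model' path are illustrative placeholders; `save` and `load` are
assumed to be exposed publicly as `tff.learning.models.save` and
`tff.learning.models.load`, per the docstring below, and
`tff.learning.models.from_keras_model` is used only to build a concrete model):

    import collections

    import tensorflow as tf
    import tensorflow_federated as tff

    # Build a tff.learning model from an uncompiled Keras model.
    keras_model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(2,))])
    model = tff.learning.models.from_keras_model(
        keras_model,
        loss=tf.keras.losses.MeanSquaredError(),
        input_spec=collections.OrderedDict(
            x=tf.TensorSpec([None, 2], tf.float32),
            y=tf.TensorSpec([None, 1], tf.float32)))

    tff.learning.models.save(model, '/tmp/my_model')
    # The loaded model retains method behavior but not the Python type.
    loaded = tff.learning.models.load('/tmp/my_model')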

def save(model: model_lib.Model, path: str, input_type=None) -> None:
  """Serializes `model` as a TensorFlow SavedModel to `path`.

  The resulting SavedModel will contain the default serving signature, which
  can be used with the TFLite converter to create a TFLite flatbuffer for
  inference.

  NOTE: The model returned by `tff.learning.models.load` will _not_ be the same
  Python type as the saved model. If the model serialized using this method is
  a subclass of `tff.learning.Model`, that subclass is _not_ returned. All
  method behavior is retained, but the Python type does not cross serialization
  boundaries. The return type of `metric_finalizers` will be an `OrderedDict`
  of `str` to `tff.tf_computation` (annotated TFF computations), which may
  differ from the return type of the original model's `metric_finalizers`.

  Args:
    model: The `tff.learning.Model` to save.
    path: The `str` directory path to serialize the model to.
    input_type: An optional structure of `tf.TensorSpec`s representing the
      expected input of `model.predict_on_batch`, to override reading from
      `model.input_spec`. Typically this will be similar to `model.input_spec`,
      with any example labels removed. If `None`, defaults to
      `model.input_spec['x']` if the input spec is a mapping, otherwise to
      `model.input_spec[0]`.
  """
  py_typecheck.check_type(model, model_lib.Model)
  py_typecheck.check_type(path, str)
  if not path:
    raise ValueError('`path` must be a non-empty string, cannot serialize '
                     'models without an output path.')
  if isinstance(model, _LoadedSavedModel):
    # If we're saving a previously loaded model, we can simply use the module
    # already internal to the Model.
    _save_tensorflow_module(model._loaded_module, path)  # pylint: disable=protected-access
    return

  m = tf.Module()
  # We prefix these attributes with `tff_` because `trainable_variables` is an
  # attribute name reserved by `tf.Module`.
  m.tff_trainable_variables = model.trainable_variables
  m.tff_non_trainable_variables = model.non_trainable_variables
  m.tff_local_variables = model.local_variables
  # Serialize forward_pass. We must get two concrete versions of the
  # function, as the `training` argument is a Python value that changes the
  # graph computation. We serialize the output type so that we can repack the
  # flattened values after loading the saved model.
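  # (For example, dropout and batch normalization layers behave differently
  # when `training` is `True` vs. `False`, so each value traces a distinct
  # concrete graph.)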
  forward_pass_training = _make_concrete_flat_output_fn(
      functools.partial(model.forward_pass, training=True), model.input_spec)
  m.flat_forward_pass_training = forward_pass_training[0]
  m.forward_pass_training_type_spec = tf.Variable(
      forward_pass_training[1].SerializeToString(deterministic=True),
      trainable=False)

  forward_pass_inference = _make_concrete_flat_output_fn(
      functools.partial(model.forward_pass, training=False), model.input_spec)
  m.flat_forward_pass_inference = forward_pass_inference[0]
  m.forward_pass_inference_type_spec = tf.Variable(
      forward_pass_inference[1].SerializeToString(deterministic=True),
      trainable=False)
  # Get the model prediction input type. If `input_type` is `None`, assume the
  # 'x' key (for mappings) or the first element of the model input spec is the
  # prediction input.
  if input_type is None:
    if isinstance(model.input_spec, collections.abc.Mapping):
      input_type = model.input_spec['x']
    else:
      input_type = model.input_spec[0]
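  # For example, with input_spec = collections.OrderedDict(
  #     x=tf.TensorSpec([None, 784], tf.float32),
  #     y=tf.TensorSpec([None, 1], tf.int64)),
  # the default input_type is the 'x' spec, tf.TensorSpec([None, 784],
  # tf.float32).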
  # Serialize predict_on_batch. We must get two concrete versions of the
  # function, as the `training` argument is a Python value that changes the
  # graph computation.
  predict_on_batch_training = _make_concrete_flat_output_fn(
      functools.partial(model.predict_on_batch, training=True), input_type)
  m.predict_on_batch_training = predict_on_batch_training[0]
  m.predict_on_batch_training_type_spec = tf.Variable(
      predict_on_batch_training[1].SerializeToString(deterministic=True),
      trainable=False)
  predict_on_batch_inference = _make_concrete_flat_output_fn(
      functools.partial(model.predict_on_batch, training=False), input_type)
  m.predict_on_batch_inference = predict_on_batch_inference[0]
  m.predict_on_batch_inference_type_spec = tf.Variable(
      predict_on_batch_inference[1].SerializeToString(deterministic=True),
      trainable=False)
  # Serialize the report_local_outputs tf.function.
  try:
    m.report_local_outputs = model.report_local_outputs.get_concrete_function()
  except NotImplementedError:
    m.report_local_outputs = None
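  # Models that do not implement `report_local_outputs` raise
  # `NotImplementedError`; recording `None` preserves that absence across the
  # serialization boundary instead of failing at save time.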

  # Serialize the report_local_unfinalized_metrics tf.function.
  m.report_local_unfinalized_metrics = (
      model.report_local_unfinalized_metrics.get_concrete_function())

  # Serialize the metric_finalizers as `tf.Variable`s.
  m.serialized_metric_finalizers = collections.OrderedDict()

  def serialize_metric_finalizer(finalizer, metric_type):
    finalizer_computation = computations.tf_computation(finalizer, metric_type)
    return tf.Variable(
        computation_serialization.serialize_computation(
            finalizer_computation).SerializeToString(deterministic=True),
        trainable=False)

  for metric_name, finalizer in model.metric_finalizers().items():
    metric_type = type_conversions.type_from_tensors(
        model.report_local_unfinalized_metrics()[metric_name])
    m.serialized_metric_finalizers[metric_name] = serialize_metric_finalizer(
        finalizer, metric_type)
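  # For example, a Keras mean-style metric keeps unfinalized `[total, count]`
  # variables, and its finalizer divides `total` by `count`; that finalizer is
  # stored here as a serialized `tff.tf_computation` proto in a `tf.Variable`.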

  # Serialize the TFF values as string variables containing the serialized
  # protos of the computation or type.
  m.serialized_input_spec = tf.Variable(
      type_serialization.serialize_type(
          computation_types.to_type(
              model.input_spec)).SerializeToString(deterministic=True),
      trainable=False)
  try:
    m.serialized_federated_output_computation = tf.Variable(
        computation_serialization.serialize_computation(
            model.federated_output_computation).SerializeToString(
                deterministic=True),
        trainable=False)
  except NotImplementedError:
    m.serialized_federated_output_computation = None

  _save_tensorflow_module(m, path)
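
As noted in the docstring, the resulting SavedModel includes a default serving
signature and can be handed to the TFLite converter. A minimal sketch using the
standard `tf.lite.TFLiteConverter` API (the path and output filename are the
illustrative placeholders from the example above):

    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/my_model')
    tflite_flatbuffer = converter.convert()
    with open('/tmp/my_model.tflite', 'wb') as f:
      f.write(tflite_flatbuffer)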