def _serialize_assets()

in src/python/tensorflow_cloud/tuner/cloud_fit_client.py [0:0]


def _serialize_assets(remote_dir: Text,
                      model: tf.keras.Model,
                      **fit_kwargs) -> None:
    """Serialize the model and `fit()` arguments and store them in `remote_dir`.

    Two artifacts are written under `remote_dir`:
      * `training_assets`: a SavedModel of a `tf.Module` carrying the
        `tf.data.Dataset` inputs (`x`, `validation_data`), the pickled
        callbacks and the remaining `fit()` kwargs, each exposed through an
        argument-less `tf.function`.
      * `model`: the Keras model, saved with `model.save()`.

    If no `ModelCheckpoint` callback (or subclass) is provided, one writing to
    `<remote_dir>/checkpoint` every epoch is added so that the trained model
    is persisted by the remote job.

    Args:
        remote_dir: A Google Cloud Storage path for assets and outputs.
        model: A compiled Keras Model.
        **fit_kwargs: Args to pass to model.fit(). `x` and `validation_data`
            are exported separately when they are `tf.data.Dataset`s,
            `callbacks` are pickled, and everything else is exported as-is.
            The caller's `callbacks` list is never modified.

    Raises:
        NotImplementedError: If `x` is a Generator.
        Any error raised by `pickle.dumps` if a callback cannot be pickled.
    """
    to_export = tf.Module()

    # Datasets are serialized differently from the other fit kwargs: each is
    # attached to the module and exposed through an argument-less tf.function.
    x = fit_kwargs.get("x")
    if isinstance(x, tf.data.Dataset):
        to_export.x = fit_kwargs.pop("x")
        to_export.x_fn = tf.function(lambda: to_export.x, input_signature=())
        logging.info("x was serialized successfully.")
    elif isinstance(x, Generator):
        raise NotImplementedError("Generators are not currently supported!")

    if isinstance(fit_kwargs.get("validation_data"), tf.data.Dataset):
        to_export.validation_data = fit_kwargs.pop("validation_data")
        to_export.validation_data_fn = tf.function(
            lambda: to_export.validation_data, input_signature=()
        )
        logging.info("validation_data was serialized successfully.")

    # Copy so the append below never mutates the caller's list; `None`
    # (model.fit's default) and tuples are normalized to a fresh list.
    callbacks = list(fit_kwargs.pop("callbacks", None) or [])

    # The remote component does not save the model after training. To ensure
    # the model is saved after training completes, add a ModelCheckpoint
    # callback unless the user already provided one (or a subclass of it).
    has_model_checkpoint = any(
        isinstance(callback, tf.keras.callbacks.ModelCheckpoint)
        for callback in callbacks
    )
    if not has_model_checkpoint:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(remote_dir, "checkpoint"),
            save_freq="epoch"))

    # Callbacks are exported as a pickled byte string.
    to_export.callbacks = pickle.dumps(callbacks)
    to_export.callbacks_fn = tf.function(
        lambda: to_export.callbacks, input_signature=()
    )
    logging.info("callbacks were serialized successfully.")

    # All remaining items can be directly serialized as a dict.
    to_export.fit_kwargs = fit_kwargs
    to_export.fit_kwargs_fn = tf.function(
        lambda: to_export.fit_kwargs, input_signature=()
    )

    tf.saved_model.save(
        to_export, os.path.join(remote_dir, "training_assets"), signatures={}
    )

    # Save the model itself.
    model.save(os.path.join(remote_dir, "model"))