in src/python/tensorflow_cloud/tuner/cloud_fit_client.py [0:0]
def _serialize_assets(remote_dir: Text,
                      model: tf.keras.Model,
                      **fit_kwargs) -> None:
    """Serialize the model and its fit() arguments and save them to remote_dir.

    Two artifacts are written:
      * ``<remote_dir>/training_assets``: a SavedModel of a ``tf.Module`` that
        exposes zero-argument ``tf.function``s returning the captured training
        inputs. ``callbacks_fn`` (pickled callbacks) and ``fit_kwargs_fn``
        (remaining fit() kwargs) are always present. ``x_fn`` and
        ``validation_data_fn`` are present only when the corresponding argument
        is a ``tf.data.Dataset``.
      * ``<remote_dir>/model``: the model, saved via ``model.save``.

    Args:
      remote_dir: A Google Cloud Storage path for assets and outputs.
      model: A compiled Keras Model.
      **fit_kwargs: Args to pass to model.fit(). The ``callbacks`` sequence
        passed by the caller is copied, never mutated. If it contains no
        ``ModelCheckpoint`` (or subclass), one writing to
        ``<remote_dir>/checkpoint`` every epoch is added.

    Raises:
      NotImplementedError: If ``x`` or ``validation_data`` is a Generator.
    """
    to_export = tf.Module()

    # If x is an instance of a dataset or a generator, it cannot be stored in
    # the plain fit_kwargs dict below and needs to be serialized differently.
    if "x" in fit_kwargs:
        if isinstance(fit_kwargs["x"], tf.data.Dataset):
            to_export.x = fit_kwargs.pop("x")
            x_fn = lambda: to_export.x
            to_export.x_fn = tf.function(x_fn, input_signature=())
        elif isinstance(fit_kwargs["x"], Generator):
            raise NotImplementedError("Generators are not currently supported!")
        logging.info("x was serialized successfully.")

    # validation_data gets the same treatment as x.
    if "validation_data" in fit_kwargs:
        if isinstance(fit_kwargs["validation_data"], tf.data.Dataset):
            to_export.validation_data = fit_kwargs.pop("validation_data")
            validation_data_fn = lambda: to_export.validation_data
            to_export.validation_data_fn = tf.function(
                validation_data_fn, input_signature=()
            )
            logging.info("validation_data was serialized successfully.")
        elif isinstance(fit_kwargs["validation_data"], Generator):
            raise NotImplementedError("Generators are not currently supported!")

    # Copy into a fresh list so the append below never mutates the caller's
    # object. This also accepts callbacks=None or a tuple, both valid in fit().
    callbacks = list(fit_kwargs.pop("callbacks", None) or [])

    # The remote component does not save the model after training. To ensure
    # the model is saved after training completes, add a ModelCheckpoint
    # callback if the user did not provide one. isinstance() also matches
    # user subclasses of ModelCheckpoint.
    if not any(isinstance(callback, tf.keras.callbacks.ModelCheckpoint)
               for callback in callbacks):
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(remote_dir, "checkpoint"),
            save_freq="epoch"))

    # Add all serializable callbacks to assets.
    to_export.callbacks = pickle.dumps(callbacks)
    callbacks_fn = lambda: to_export.callbacks
    to_export.callbacks_fn = tf.function(callbacks_fn, input_signature=())
    logging.info("callbacks were serialized successfully.")

    # All remaining items can be directly serialized as a dict.
    to_export.fit_kwargs = fit_kwargs
    fit_kwargs_fn = lambda: to_export.fit_kwargs
    to_export.fit_kwargs_fn = tf.function(fit_kwargs_fn, input_signature=())
    tf.saved_model.save(
        to_export, os.path.join(remote_dir, "training_assets"), signatures={}
    )

    # Saving the model.
    model.save(os.path.join(remote_dir, "model"))