in tensorflow_hub/keras_layer.py
def _check_trainability(self):
  """Raises or logs errors for unsupported uses of trainable=True."""
  if not self.trainable: return  # Nothing to do.
  # Training is only supported when calling a reusable TF2 SavedModel through
  # its @tf.function __call__. Trying to train through a signature is likely
  # to go wrong beyond the simplest cases due to a number of pitfalls:
  #   - No good support for train vs inference mode. The TF1 Hub format used
  #     graph versions identified by tags, but this was not a general
  #     standard for SavedModels, and TF2 can no longer save with tags.
  #   - No support for update ops. The TF1 Hub format had them in the
  #     UPDATE_OPS collection, but collections are no longer loaded in TF2.
  #     General SavedModel signatures had no support for them.
  #   - No support for regularization losses (same story).
  #   - A SavedModel without a @tf.function __call__ will likely also not
  #     provide a trainable_variables attribute.
  if self._is_hub_module_v1:
    raise ValueError(
        "Setting hub.KerasLayer.trainable = True is unsupported when "
        "loading from the TF1 Hub format.")
  elif self._signature:
    raise ValueError(
        "Setting hub.KerasLayer.trainable = True is unsupported when "
        "calling a SavedModel signature.")
  # Having zero trainable variables in an otherwise trainable model is
  # suspicious but may be valid as a boundary case, so we just log an error,
  # at most once per layer instance.
  if not self.trainable_weights:
    if not hasattr(self, "_already_logged_trainable_with_zero_weights"):
      logging.error(
          "hub.KerasLayer is trainable but has zero trainable weights.")
      setattr(self, "_already_logged_trainable_with_zero_weights", True)
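
For context, here is a hedged usage sketch of what this check allows and rejects. The model handle is a placeholder (substitute any reusable TF2 SavedModel), and the assumption that the ValueError surfaces when the layer is used, rather than at construction, follows from this method being a runtime check:

import tensorflow_hub as hub

# Placeholder handle: substitute any reusable TF2 SavedModel.
HANDLE = "https://tfhub.dev/some/tf2_model/1"

# Supported: a reusable TF2 SavedModel is invoked through its @tf.function
# __call__, which carries trainable_variables, so fine-tuning can work.
fine_tunable = hub.KerasLayer(HANDLE, trainable=True)

# Unsupported: routing calls through a named signature loses update ops,
# regularization losses, and train/inference mode, so _check_trainability
# raises ValueError when this layer is used with trainable=True.
via_signature = hub.KerasLayer(HANDLE,
                               signature="serving_default",
                               signature_outputs_as_dict=True,
                               trainable=True)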
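
The zero-weights branch uses a per-instance sentinel attribute so the log message fires at most once no matter how often the layer is called. A minimal standalone sketch of that idiom (the class and names here are illustrative, not part of the library):

import logging

class Layer:
  trainable = True
  trainable_weights = []  # Empty on purpose: triggers the one-time warning.

  def _warn_once(self):
    # The sentinel attribute is set on the instance the first time we log,
    # so hasattr() is False exactly once per instance.
    if not hasattr(self, "_already_warned"):
      logging.error("Layer is trainable but has zero trainable weights.")
      self._already_warned = True

layer = Layer()
layer._warn_once()  # Logs the error.
layer._warn_once()  # Silent: the sentinel is already set.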