in adanet/core/ensemble_builder.py [0:0]
@contextlib.contextmanager
def _monkey_patch_context(iteration_step_scope, scoped_summary, trainable_vars):
  """Monkey-patches global attributes with subnetwork-specific ones.

  See the illustrative usage sketch following this function.
  """
  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
  from tensorflow.python.training import training as train
  from tensorflow.python.training import training_util
  # pylint: enable=g-direct-tensorflow-import,g-import-not-at-top

  old_get_global_step_fn = tf_compat.v1.train.get_global_step
  old_get_or_create_global_step_fn = (
      tf_compat.v1.train.get_or_create_global_step)
  old_trainable_vars = tf_compat.v1.trainable_variables()
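
  # iteration_step lazily creates (or reuses, via AUTO_REUSE) a non-trainable
  # int64 "iteration_step" variable under `iteration_step_scope`, at the root
  # name scope of the graph. It stands in for the global step in the patches
  # below.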
  def iteration_step(graph=None):
    graph = graph or tf_compat.v1.get_default_graph()
    with graph.as_default() as g, g.name_scope(None):
      with tf_compat.v1.variable_scope(
          iteration_step_scope, reuse=tf_compat.v1.AUTO_REUSE):
        return tf_compat.v1.get_variable(
            "iteration_step",
            shape=[],
            initializer=tf_compat.v1.zeros_initializer(),
            trainable=False,
            dtype=tf.int64)

  # monkey-patch global attributes.
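  # The same `iteration_step` function is installed under every module alias
  # through which client code might look up the global step (tf.train,
  # tf_v1.train, tf_compat.v1.train, and the internal `training` and
  # `training_util` modules), so optimizers and hooks built inside the
  # subnetwork read and increment the iteration step instead of the real
  # global step.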
  setattr(tf_compat.v1.train, "get_global_step", iteration_step)
  setattr(tf_compat.v1.train, "get_or_create_global_step", iteration_step)
  setattr(tf_v1.train, "get_global_step", iteration_step)
  setattr(tf_v1.train, "get_or_create_global_step", iteration_step)
  setattr(tf.train, "get_global_step", iteration_step)
  setattr(tf.train, "get_or_create_global_step", iteration_step)
  setattr(train, "get_global_step", iteration_step)
  setattr(training_util, "get_global_step", iteration_step)
  setattr(train, "get_or_create_global_step", iteration_step)
  setattr(training_util, "get_or_create_global_step", iteration_step)

  # The TPUEmbedding uses dummy variables to coordinate sending and receiving
  # gradients. If no gradients are computed on these dummy variables, the
  # TPUEmbedding will throw an error.
  embedding_variables = tf_compat.v1.get_collection(
      "tpu_embedding_dummy_table_variables")
  _set_trainable_variables(trainable_vars + embedding_variables)
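
  # With the step lookups patched and the trainable-variable collection
  # replaced by the subnetwork's variables (plus the embedding dummies), yield
  # to the caller with summary calls routed to `scoped_summary`.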
  try:
    with monkey_patched_summaries(scoped_summary):
      yield
  finally:
    # Revert monkey-patches.
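    # Trainable variables created while the context was active (the diff
    # against the original `trainable_vars`) are preserved so the caller does
    # not lose them when the old collection is restored.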
    new_trainable_vars = _get_current_vars(
        diffbase={"trainable": trainable_vars})["trainable"]
    _set_trainable_variables(old_trainable_vars + new_trainable_vars)
    setattr(training_util, "get_or_create_global_step",
            old_get_or_create_global_step_fn)
    setattr(train, "get_or_create_global_step",
            old_get_or_create_global_step_fn)
    setattr(training_util, "get_global_step", old_get_global_step_fn)
    setattr(train, "get_global_step", old_get_global_step_fn)
    setattr(tf.train, "get_or_create_global_step",
            old_get_or_create_global_step_fn)
    setattr(tf.train, "get_global_step", old_get_global_step_fn)
    setattr(tf_v1.train, "get_or_create_global_step",
            old_get_or_create_global_step_fn)
    setattr(tf_v1.train, "get_global_step", old_get_global_step_fn)
    setattr(tf_compat.v1.train, "get_or_create_global_step",
            old_get_or_create_global_step_fn)
    setattr(tf_compat.v1.train, "get_global_step", old_get_global_step_fn)
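

# --- Illustrative usage sketch (not part of the original file) ---------------
# A minimal example of how the context manager above is meant to be entered.
# The helper name and the arguments `loss`, `optimizer`, `subnetwork_vars`,
# and `scoped_summary` are hypothetical placeholders, not the actual call
# sites in ensemble_builder.py.
def _example_build_train_op(loss, optimizer, subnetwork_vars, scoped_summary):
  with _monkey_patch_context(
      iteration_step_scope="example_iteration_step",
      scoped_summary=scoped_summary,
      trainable_vars=subnetwork_vars):
    # Inside the block, any lookup of the "global step" returns the
    # per-iteration step variable, and only `subnetwork_vars` (plus the TPU
    # embedding dummy variables) are reported as trainable.
    return optimizer.minimize(
        loss, global_step=tf_compat.v1.train.get_global_step())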