def _monkey_patch_context()

in adanet/core/ensemble_builder.py [0:0]


def _monkey_patch_context(iteration_step_scope, scoped_summary, trainable_vars):
  """Monkey-patches global attributes with subnetwork-specific ones.

  This is a generator that yields exactly once, so it is meant to be used as
  a context manager. While the context is active:

    * `get_global_step` and `get_or_create_global_step` on
      `tf_compat.v1.train`, `tf_v1.train`, `tf.train` and TensorFlow's internal
      `training` / `training_util` modules are replaced with a function that
      gets or creates a scalar, non-trainable, zero-initialized `int64`
      variable named "iteration_step" under `iteration_step_scope`.
    * The trainable-variables collection is set to `trainable_vars` plus any
      variables in the "tpu_embedding_dummy_table_variables" collection.
    * Summaries are routed through `monkey_patched_summaries(scoped_summary)`.

  On exit, including when setup or the body raises, the global-step functions
  on every patched module are reset to the functions that
  `tf_compat.v1.train` exposed on entry. If setup got as far as replacing the
  trainable-variables collection, that collection is reset to its entry
  contents plus the "trainable" variables returned by `_get_current_vars`
  when diffed against `trainable_vars`. Presumably these are the trainable
  variables created inside the context; TODO confirm against
  `_get_current_vars`.

  Args:
    iteration_step_scope: Variable scope (name or scope object) under which the
      replacement "iteration_step" variable is created.
    scoped_summary: Summary object passed to `monkey_patched_summaries`.
    trainable_vars: List of variables to expose as trainable inside the
      context. It must support `+` with the list returned by `get_collection`.

  Yields:
    None.
  """

  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
  from tensorflow.python.training import training as train
  from tensorflow.python.training import training_util
  # pylint: enable=g-direct-tensorflow-import,g-import-not-at-top

  old_get_global_step_fn = tf_compat.v1.train.get_global_step
  old_get_or_create_global_step_fn = tf_compat.v1.train.get_or_create_global_step
  old_trainable_vars = tf_compat.v1.trainable_variables()

  def iteration_step(graph=None):
    """Gets or creates the per-iteration step variable in `graph`."""
    graph = graph or tf_compat.v1.get_default_graph()
    # Reset the name scope so the variable's name does not depend on the
    # caller's current name scope.
    with graph.as_default() as g, g.name_scope(None):
      with tf_compat.v1.variable_scope(
          iteration_step_scope, reuse=tf_compat.v1.AUTO_REUSE):
        return tf_compat.v1.get_variable(
            "iteration_step",
            shape=[],
            initializer=tf_compat.v1.zeros_initializer(),
            trainable=False,
            dtype=tf.int64)

  # Every module through which callers may reach the global-step helpers.
  # NOTE(review): on exit each of these receives the functions read from
  # `tf_compat.v1.train`, even if a module did not define these names before
  # entry (e.g. possibly `tf.train` under TF 2.x). Confirm this is acceptable.
  patch_targets = (
      tf_compat.v1.train,
      tf_v1.train,
      tf.train,
      train,
      training_util,
  )

  def set_global_step_fns(get_fn, get_or_create_fn):
    """Installs the given global-step functions on every patch target."""
    for module in patch_targets:
      setattr(module, "get_global_step", get_fn)
      setattr(module, "get_or_create_global_step", get_or_create_fn)

  set_global_step_fns(iteration_step, iteration_step)
  # The outer try guarantees the global-step functions are restored even if
  # the remaining setup below raises, or if reverting the trainable
  # variables raises.
  try:
    # The TPUEmbedding uses dummy variables to coordinate sending and receiving
    # gradients. If no gradients are computed on these dummy variables, the
    # TPUEmbedding will throw an error.
    embedding_variables = tf_compat.v1.get_collection(
        "tpu_embedding_dummy_table_variables")
    _set_trainable_variables(trainable_vars + embedding_variables)

    try:
      with monkey_patched_summaries(scoped_summary):
        yield
    finally:
      # Revert the trainable-variables collection first (as before), while the
      # global-step patches are still active.
      new_trainable_vars = _get_current_vars(
          diffbase={"trainable": trainable_vars})["trainable"]
      _set_trainable_variables(old_trainable_vars + new_trainable_vars)
  finally:
    # Revert the global-step monkey-patches.
    set_global_step_fns(old_get_global_step_fn,
                        old_get_or_create_global_step_fn)