in adanet/core/iteration.py [0:0]
  def _create_tpu_train_op(self, base_global_step, subnetwork_specs, candidates,
                           mode, num_subnetworks, config):
"""Returns the train op for this set of candidates.
This train op combines the train ops from all the candidates into a single
train op. Additionally, it is responsible for incrementing the global step.
The train op is only non-None during the `TRAIN` mode.
Args:
base_global_step: Integer global step at the beginning of this iteration.
subnetwork_specs: List of `_SubnetworkSpec` instances for this iteration.
candidates: List of `_Candidate` instances to train.
mode: Defines whether this is training, evaluation or inference. The train
op is only non-None during `TRAIN`. See `ModeKeys`.
num_subnetworks: Integer number of subnetwork builders generated for the
current iteration.
config: The `tf.estimator.RunConfig` to use this iteration.
Returns:
A `Tensor` train op.
"""
    if mode != tf.estimator.ModeKeys.TRAIN:
      return None
    ensemble_specs = [c.ensemble_spec for c in candidates]
    with tf_compat.v1.variable_scope("train_op"):
      train_ops = []
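      # The placement strategy decides whether this worker trains the
      # subnetworks in addition to the candidate ensembles.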
      if self._placement_strategy.should_train_subnetworks(num_subnetworks):
        for subnetwork_spec in subnetwork_specs:
          if subnetwork_spec.train_op is not None:
            train_ops.append(subnetwork_spec.train_op.train_op)
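      # Ensemble train ops are collected regardless of the placement strategy.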
      for ensemble_spec in ensemble_specs:
        if ensemble_spec.train_op is not None:
          # The train op of a previous ensemble is None even during `TRAIN`.
          train_ops.append(ensemble_spec.train_op.train_op)
      with tf.control_dependencies(train_ops):
        # Increment steps after train ops complete to avoid non-determinism.
        increment_ops = [s.step.assign_add(1) for s in subnetwork_specs]
        increment_ops += [e.step.assign_add(1) for e in ensemble_specs]
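        # Non-chief workers only advance the per-candidate iteration steps; the
        # chief additionally assigns the global step below.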
        if not config.is_chief:
          return tf.group(*increment_ops)

        # AdaNet's chief worker is responsible for setting the global step, not
        # the candidates it trains. Assigning the global step is the final
        # action performed in the train op.
        with tf.control_dependencies(increment_ops):
          steps = [s.step.read_value() for s in subnetwork_specs]
          global_step = tf_compat.v1.train.get_global_step()
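          # The combined step may be fractional (e.g. with a mean combiner), so
          # cast it back to the global step's int64 dtype.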
          return global_step.assign(
              tf.cast(
                  base_global_step + self._global_step_combiner_fn(steps),
                  dtype=tf.int64))
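  # A minimal sketch of the chief's global step arithmetic above, assuming a
  # mean combiner and hypothetical step values:
  #   base_global_step = 100, subnetwork steps = [10, 12]
  #   global_step <- 100 + mean([10, 12]) = 111, cast to int64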