def _create_tpu_train_op()

in adanet/core/iteration.py [0:0]


  def _create_tpu_train_op(self, base_global_step, subnetwork_specs, candidates,
                           mode, num_subnetworks, config):
    """Builds one train op that covers every candidate in this iteration.

    The returned op does the following, in order:

    1. Runs the subnetwork train ops, but only when the placement strategy
       reports that subnetworks should be trained.
    2. Runs the candidate ensembles' train ops.
    3. Increments each subnetwork's and each ensemble's own step counter.
    4. On the chief only, writes the global step once those counters have
       been updated.

    Args:
      base_global_step: Integer global step at the beginning of this iteration.
      subnetwork_specs: List of `_SubnetworkSpec` instances for this iteration.
      candidates: List of `_Candidate` instances to train.
      mode: Defines whether this is training, evaluation or inference. See
        `ModeKeys`. Only `TRAIN` produces a train op.
      num_subnetworks: Integer number of subnetwork builders generated for the
        current iteration.
      config: The `tf.estimator.RunConfig` to use this iteration.

    Returns:
      `None` when `mode` is not `TRAIN`. Otherwise, on a non-chief worker, the
      `tf.group` of the step-increment ops; on the chief, the result of
      `global_step.assign(...)`.
    """
    if mode != tf.estimator.ModeKeys.TRAIN:
      return None

    ensembles = [candidate.ensemble_spec for candidate in candidates]
    with tf_compat.v1.variable_scope("train_op"):
      ops_to_run = []
      if self._placement_strategy.should_train_subnetworks(num_subnetworks):
        ops_to_run.extend(spec.train_op.train_op for spec in subnetwork_specs
                          if spec.train_op is not None)
      # A previous ensemble carries a None train op even during `TRAIN`, so
      # only ensembles that actually have one contribute here.
      ops_to_run.extend(spec.train_op.train_op for spec in ensembles
                        if spec.train_op is not None)

      with tf.control_dependencies(ops_to_run):
        # Step counters advance only after every train op has completed, which
        # avoids non-determinism in the values they hold.
        step_updates = ([spec.step.assign_add(1) for spec in subnetwork_specs] +
                        [spec.step.assign_add(1) for spec in ensembles])

        if config.is_chief:
          # Only AdaNet's chief sets the global step; the candidates it trains
          # never do. Writing it is the final action of the train op.
          with tf.control_dependencies(step_updates):
            subnetwork_steps = [
                spec.step.read_value() for spec in subnetwork_specs
            ]
            global_step = tf_compat.v1.train.get_global_step()
            combined_steps = self._global_step_combiner_fn(subnetwork_steps)
            return global_step.assign(
                tf.cast(base_global_step + combined_steps, dtype=tf.int64))
        return tf.group(*step_updates)