def apply_gradients()

in easy_rec/python/compat/sync_replicas_optimizer.py


  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This contains most of the synchronization implementation and also wraps the
    apply_gradients() from the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Defaults to the
        name passed to the Optimizer constructor.

    Returns:
      train_op: The op that dequeues a token so the replicas can exit this
        batch and start the next one. This op is executed by each replica.

    Raises:
      ValueError: If grads_and_vars is empty.
      ValueError: If global_step is not provided, since gradient staleness
        cannot be checked without it.
    """
    if not grads_and_vars:
      raise ValueError('Must supply at least one variable')

    if global_step is None:
      raise ValueError('Global step is required to check staleness')

    self._global_step = global_step
    train_ops = []
    aggregated_grad = []
    var_list = []

    # local_anchor op will be placed on this worker task by default.
    local_anchor = control_flow_ops.no_op()
    # Colocating the local_step variable with it prevents local_step from
    # being placed on the PS.
    with ops.colocate_with(local_anchor):
      self._local_step = variable_scope.variable(
          initial_value=0,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          dtype=global_step.dtype.base_dtype,
          name='sync_rep_local_step')

    self.local_step_init_op = state_ops.assign(self._local_step, global_step)
    chief_init_ops = [self.local_step_init_op]
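    # Non-chief workers may run local_step_init_op only once the chief has
    # initialized all global variables; this op reports any that are still
    # uninitialized.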
    self.ready_for_local_init_op = variables.report_uninitialized_variables(
        variables.global_variables())

    with ops.name_scope(None, self._name):
      for grad, var in grads_and_vars:
        var_list.append(var)
        with ops.device(var.device):
          # Dense gradients.
          if grad is None:
            aggregated_grad.append(None)  # pass-through.
            continue
          elif isinstance(grad, ops.Tensor):
            grad_accum = data_flow_ops.ConditionalAccumulator(
                grad.dtype,
                shape=var.get_shape(),
                shared_name=var.name + '/grad_accum')
            train_ops.append(
                grad_accum.apply_grad(grad, local_step=self._local_step))
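            # take_grad blocks until `replicas_to_aggregate` gradients have
            # been accumulated (gradients tagged with a stale local_step are
            # silently dropped) and then returns their average.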
            aggregated_grad.append(
                grad_accum.take_grad(self._replicas_to_aggregate))
          else:
            if not isinstance(grad, ops.IndexedSlices):
              raise ValueError('Unknown grad type!')
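            # Sparse (IndexedSlices) gradients are accumulated and averaged
            # the same way, via the sparse flavor of the conditional
            # accumulator.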
            grad_accum = data_flow_ops.SparseConditionalAccumulator(
                grad.dtype, shape=(), shared_name=var.name + '/grad_accum')
            train_ops.append(
                grad_accum.apply_indexed_slices_grad(
                    grad, local_step=self._local_step))
            aggregated_grad.append(
                grad_accum.take_indexed_slices_grad(
                    self._replicas_to_aggregate))

          self._accumulator_list.append((grad_accum, var.device))

      aggregated_grads_and_vars = zip(aggregated_grad, var_list)

      # sync_op will be assigned to the same device as the global step.
      with ops.device(global_step.device), ops.name_scope(''):
        update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                              global_step)

      def _get_token_qname():
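        # A class-level counter gives each apply_gradients() call a distinct
        # shared token-queue name, so separate optimizer instances do not
        # collide on one shared queue.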
        SyncReplicasOptimizer.sync_que_id += 1
        if SyncReplicasOptimizer.sync_que_id == 0:
          return 'sync_token_q'
        else:
          return 'sync_token_q_' + str(SyncReplicasOptimizer.sync_que_id)

      # Create token queue.
      token_qname = _get_token_qname()
      logging.info('create sync_token_queue[%s]' % token_qname)
      with ops.device(global_step.device), ops.name_scope(''):
        sync_token_queue = (
            data_flow_ops.FIFOQueue(
                -1,
                global_step.dtype.base_dtype,
                shapes=(),
                name=token_qname,
                shared_name=token_qname))
        self._sync_token_queue = sync_token_queue
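        # These ops let callers check whether the token queue has been closed
        # and close it with pending enqueues cancelled, so replicas blocked on
        # a dequeue fail fast instead of hanging.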
        self._is_sync_que_closed = sync_token_queue.is_closed()
        self._close_sync_que = sync_token_queue.close(
            cancel_pending_enqueues=True, name='close_sync_token_queue')

        # dummy_queue is passed to the queue runner instead of the real token
        # queue, because the queue runner does not automatically reopen queues
        # on PS devices once they have been closed.
        dummy_queue = (
            data_flow_ops.FIFOQueue(
                1,
                types_pb2.DT_INT32,
                shapes=(),
                name='dummy_queue',
                shared_name='dummy_queue'))

      with ops.device(global_step.device), ops.name_scope(''):
        # Replicas have to wait until they can get a token from the token queue.
        with ops.control_dependencies(train_ops):
          token = sync_token_queue.dequeue()
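        # The token value is the global_step at which the aggregated update
        # was applied; assigning it to local_step marks this replica's next
        # gradients as fresh for the accumulators.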
        train_op = state_ops.assign(self._local_step, token)

        with ops.control_dependencies([update_op]):
          # Sync_op needs to insert tokens to the token queue at the end of the
          # step so the replicas can fetch them to start the next step.
          tokens = array_ops.fill([self._tokens_per_step], global_step)
          sync_op = sync_token_queue.enqueue_many((tokens,))

        if self._variable_averages is not None:
          with ops.control_dependencies([sync_op]), ops.name_scope(''):
            sync_op = self._variable_averages.apply(self._variables_to_average)

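        # The chief's queue runner repeatedly runs sync_op: through its
        # control dependency on update_op it applies the aggregated gradients,
        # then refills the token queue so replicas can start the next step.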
        self._chief_queue_runner = queue_runner.QueueRunner(
            dummy_queue, [sync_op])
        ops.add_to_collection(ops.GraphKeys.QUEUE_RUNNERS,
                              self._chief_queue_runner)
      for accum, dev in self._accumulator_list:
        with ops.device(dev):
          chief_init_ops.append(
              accum.set_global_step(global_step, name='SetGlobalStep'))
      self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
      self._gradients_applied = True
      return train_op
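
For reference, here is a minimal single-process sketch of how apply_gradients() is typically driven. It is an illustration only: it assumes this class keeps the constructor and make_session_run_hook() interface of the upstream tf.train.SyncReplicasOptimizer it is adapted from, and num_workers, is_chief and the toy loss below are placeholders rather than easy_rec code.

import tensorflow.compat.v1 as tf

from easy_rec.python.compat.sync_replicas_optimizer import SyncReplicasOptimizer

tf.disable_eager_execution()

num_workers = 1   # set to the number of workers in a real between-graph job
is_chief = True   # normally task_index == 0

# Toy model so the sketch is self-contained.
w = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())
loss = tf.square(w - 3.0)
global_step = tf.train.get_or_create_global_step()

base_opt = tf.train.GradientDescentOptimizer(0.1)
opt = SyncReplicasOptimizer(
    base_opt,
    replicas_to_aggregate=num_workers,
    total_num_replicas=num_workers)

grads_and_vars = opt.compute_gradients(loss)
# Builds the accumulators, token queue and chief queue runner shown above and
# returns the per-replica train_op (dequeue a token, update local_step).
train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)

# The hook runs the init ops and, on the chief, enqueues the initial tokens
# and starts the chief queue runner when the session is created.
hooks = [
    opt.make_session_run_hook(is_chief),
    tf.train.StopAtStepHook(last_step=100),
]
with tf.train.MonitoredTrainingSession(is_chief=is_chief, hooks=hooks) as sess:
  while not sess.should_stop():
    sess.run(train_op)

In a real between-graph job every worker builds this same graph and replicas_to_aggregate is set to the number of workers; the hook runs chief_init_op on the chief and local_step_init_op on the other workers.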