def apply_gradients_kfac(self, grads)

in baselines/acktr/kfac.py
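
Builds the full KFAC update as a single grouped TensorFlow op: accumulate
curvature statistics, periodically refresh their eigen-decompositions
(synchronously, or on a background queue thread when `self._async` is set),
precondition the incoming gradients via `getKfacPrecondUpdates`, and apply
them with a momentum optimizer. Returns the grouped op together with a
`tf.train.QueueRunner` for the async eigen-decomposition (`None` in
synchronous mode).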


    def apply_gradients_kfac(self, grads):
        g, varlist = list(zip(*grads))

        # lazily construct the eigen-decomposition tensors for the stats factors
        if len(self.stats_eigen) == 0:
            self.getStatsEigen()

        qr = None
        # launch eigen-decomp on a queue thread
        if self._async:
            print('Using async eigen decomp')
            # build the factor-loading tensors once to obtain dtypes and shapes
            factorOps_dummy = self.computeStatsEigen()

            # queue holding one set of factor eigen-decomposition tensors
            queue = tf.FIFOQueue(
                1, [item.dtype for item in factorOps_dummy],
                shapes=[item.get_shape() for item in factorOps_dummy])
            # enqueue a fresh decomposition every `self._kfac_update` stats
            # steps, once enough statistics have been accumulated
            enqueue_op = tf.cond(
                tf.logical_and(
                    tf.equal(tf.mod(self.stats_step, self._kfac_update),
                             tf.convert_to_tensor(0)),
                    tf.greater_equal(self.stats_step, self._stats_accum_iter)),
                lambda: queue.enqueue(self.computeStatsEigen()),
                tf.no_op)

            def dequeue_op():
                return queue.dequeue()

            # the queue runner drives enqueue_op on a background thread;
            # the caller is responsible for starting its threads
            qr = tf.train.QueueRunner(queue, [enqueue_op])

        updateOps = []
        global_step_op = tf.assign_add(self.global_step, 1)
        updateOps.append(global_step_op)

        with tf.control_dependencies([global_step_op]):

            # compute updates
            assert self._update_stats_op is not None
            updateOps.append(self._update_stats_op)
            dependency_list = []
            if not self._async:
                # synchronous mode: the eigen update must see refreshed stats
                dependency_list.append(self._update_stats_op)

            with tf.control_dependencies(dependency_list):
                def no_op_wrapper():
                    return tf.group(*[tf.assign_add(self.cold_step, 1)])

                if not self._async:
                    # synchronous eigen-decomp updates
                    updateFactorOps = tf.cond(
                        tf.logical_and(
                            tf.equal(tf.mod(self.stats_step, self._kfac_update),
                                     tf.convert_to_tensor(0)),
                            tf.greater_equal(self.stats_step, self._stats_accum_iter)),
                        lambda: tf.group(*self.applyStatsEigen(self.computeStatsEigen())),
                        no_op_wrapper)
                else:
                    # asynchronous eigen-decomp updates via the queue; skip the
                    # update when no fresh decomposition has been enqueued yet
                    updateFactorOps = tf.cond(
                        tf.greater_equal(self.stats_step, self._stats_accum_iter),
                        lambda: tf.cond(
                            tf.equal(queue.size(), tf.convert_to_tensor(0)),
                            tf.no_op,
                            lambda: tf.group(*self.applyStatsEigen(dequeue_op()))),
                        no_op_wrapper)

                updateOps.append(updateFactorOps)

                with tf.control_dependencies([updateFactorOps]):
                    def gradOp():
                        return list(g)

                    def getKfacGradOp():
                        return self.getKfacPrecondUpdates(g, varlist)

                    # use the raw gradients until at least one factor
                    # eigen-decomposition has been applied
                    u = tf.cond(tf.greater(self.factor_step,
                                           tf.convert_to_tensor(0)),
                                getKfacGradOp, gradOp)

                    # scale the learning rate by (1 - momentum); under
                    # heavy-ball momentum this keeps the effective step
                    # size close to self._lr
                    optim = tf.train.MomentumOptimizer(
                        self._lr * (1. - self._momentum), self._momentum)
                    # optim = tf.train.AdamOptimizer(self._lr, epsilon=0.01)

                    def optimOp():
                        def updateOptimOp():
                            if self._full_stats_init:
                                # only step once preconditioning is available
                                return tf.cond(tf.greater(self.factor_step, tf.convert_to_tensor(0)),
                                               lambda: optim.apply_gradients(list(zip(u, varlist))),
                                               tf.no_op)
                            else:
                                return optim.apply_gradients(list(zip(u, varlist)))

                        if self._full_stats_init:
                            # wait until stats accumulation has finished
                            return tf.cond(tf.greater_equal(self.stats_step, self._stats_accum_iter),
                                           updateOptimOp, tf.no_op)
                        else:
                            # wait until the SGD cold-start phase has finished
                            return tf.cond(tf.greater_equal(self.sgd_step, self._cold_iter),
                                           updateOptimOp, tf.no_op)
                    updateOps.append(optimOp())

        return tf.group(*updateOps), qr
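
A minimal driver sketch (not from the repo): assuming `opt` is a constructed
KFAC optimizer instance and `grads` is a list of `(gradient, variable)` pairs
(e.g. from a hypothetical `opt.compute_gradients(loss)` call), the returned
op/queue-runner pair might be consumed in a TF1 session like this; `num_steps`
is likewise an assumption for illustration:

    # hypothetical TF1 driver loop; opt, grads, num_steps are illustrative names
    import tensorflow as tf

    update_op, qr = opt.apply_gradients_kfac(grads)
    coord = tf.train.Coordinator()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # in async mode, start the background eigen-decomposition thread;
        # in sync mode qr is None and there is nothing to start
        threads = qr.create_threads(sess, coord=coord, start=True) if qr is not None else []
        for _ in range(num_steps):
            sess.run(update_op)
        coord.request_stop()
        coord.join(threads)

In synchronous mode the queue runner is absent and the loop reduces to
repeated `sess.run(update_op)` calls.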