in baselines/acktr/kfac.py
def apply_gradients_kfac(self, grads):
    g, varlist = list(zip(*grads))
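
    # Build the eigen-decomposition tensors for the factor statistics if they
    # have not been constructed yet.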
    if len(self.stats_eigen) == 0:
        self.getStatsEigen()

    qr = None
    # launch eigen-decomp on a queue thread
    if self._async:
        print('Use async eigen decomp')
        # get a list of factor loading tensors
        factorOps_dummy = self.computeStatsEigen()

        # define a queue for the list of factor loading tensors
        queue = tf.FIFOQueue(1, [item.dtype for item in factorOps_dummy],
                             shapes=[item.get_shape() for item in factorOps_dummy])
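        # Enqueue a freshly computed eigen decomposition every `_kfac_update`
        # stats steps, once `_stats_accum_iter` accumulation steps have passed.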
        enqueue_op = tf.cond(
            tf.logical_and(
                tf.equal(tf.mod(self.stats_step, self._kfac_update),
                         tf.convert_to_tensor(0)),
                tf.greater_equal(self.stats_step, self._stats_accum_iter)),
            lambda: queue.enqueue(self.computeStatsEigen()), tf.no_op)

        def dequeue_op():
            return queue.dequeue()

        qr = tf.train.QueueRunner(queue, [enqueue_op])
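        # The QueueRunner runs enqueue_op on background threads, so the
        # (expensive) eigen decomposition can overlap with training steps.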

    updateOps = []
    global_step_op = tf.assign_add(self.global_step, 1)
    updateOps.append(global_step_op)

    with tf.control_dependencies([global_step_op]):
        # compute updates
        assert self._update_stats_op is not None
        updateOps.append(self._update_stats_op)
        dependency_list = []
        if not self._async:
            dependency_list.append(self._update_stats_op)
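
        # In synchronous mode the factor updates below must wait for the stats
        # update; in async mode the eigen decomposition arrives through the
        # queue from a separate thread, so that dependency is not enforced here.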
        with tf.control_dependencies(dependency_list):
            def no_op_wrapper():
                # fallback when no factor update runs this step: just increment cold_step
                return tf.group(*[tf.assign_add(self.cold_step, 1)])

            if not self._async:
                # synchronous eigen-decomp updates
                updateFactorOps = tf.cond(
                    tf.logical_and(
                        tf.equal(tf.mod(self.stats_step, self._kfac_update), tf.convert_to_tensor(0)),
                        tf.greater_equal(self.stats_step, self._stats_accum_iter)),
                    lambda: tf.group(*self.applyStatsEigen(self.computeStatsEigen())),
                    no_op_wrapper)
            else:
                # asynchronous eigen-decomp updates using queue
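                # Apply a factor update only if the background thread has
                # already pushed a finished eigen decomposition onto the queue;
                # otherwise proceed without refreshing the factors.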
                updateFactorOps = tf.cond(
                    tf.greater_equal(self.stats_step, self._stats_accum_iter),
                    lambda: tf.cond(tf.equal(queue.size(), tf.convert_to_tensor(0)),
                                    tf.no_op,
                                    lambda: tf.group(*self.applyStatsEigen(dequeue_op()))),
                    no_op_wrapper)

            updateOps.append(updateFactorOps)

            with tf.control_dependencies([updateFactorOps]):
                def gradOp():
                    return list(g)

                def getKfacGradOp():
                    return self.getKfacPrecondUpdates(g, varlist)
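
                # Use the K-FAC preconditioned updates once at least one eigen
                # decomposition has been applied (factor_step > 0); before
                # that, fall back to the raw gradients.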
                u = tf.cond(tf.greater(self.factor_step, tf.convert_to_tensor(0)),
                            getKfacGradOp, gradOp)
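
                # The learning rate is scaled by (1 - momentum) so the
                # steady-state step size of the momentum update stays close to _lr.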
                optim = tf.train.MomentumOptimizer(
                    self._lr * (1. - self._momentum), self._momentum)
                # optim = tf.train.AdamOptimizer(self._lr, epsilon=0.01)
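
                # Run the optimizer only after warm-up: with _full_stats_init,
                # wait until the stats have been accumulated and at least one
                # factor update has been applied; otherwise wait until
                # _cold_iter cold-start SGD steps have passed.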
                def optimOp():
                    def updateOptimOp():
                        if self._full_stats_init:
                            return tf.cond(tf.greater(self.factor_step, tf.convert_to_tensor(0)),
                                           lambda: optim.apply_gradients(list(zip(u, varlist))),
                                           tf.no_op)
                        else:
                            return optim.apply_gradients(list(zip(u, varlist)))

                    if self._full_stats_init:
                        return tf.cond(tf.greater_equal(self.stats_step, self._stats_accum_iter),
                                       updateOptimOp, tf.no_op)
                    else:
                        return tf.cond(tf.greater_equal(self.sgd_step, self._cold_iter),
                                       updateOptimOp, tf.no_op)

                updateOps.append(optimOp())

    return tf.group(*updateOps), qr
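
# --- Hypothetical usage sketch (not part of the original file) ---
# Assuming a KfacOptimizer instance `opt` built with asynchronous eigen
# decomposition enabled, and `loss`, `params`, and `num_steps` defined
# elsewhere, the (update op, queue runner) pair returned above could be
# driven roughly like this:
#
#     grads = tf.gradients(loss, params)
#     update_op, qr = opt.apply_gradients_kfac(list(zip(grads, params)))
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         coord = tf.train.Coordinator()
#         threads = qr.create_threads(sess, coord=coord, start=True) if qr else []
#         for _ in range(num_steps):
#             sess.run(update_op)
#         coord.request_stop()
#         coord.join(threads)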