in baselines/acktr/kfac.py [0:0]
def getKfacPrecondUpdates(self, gradlist, varlist):
    """Precondition gradients with the K-FAC eigen-decomposition and KL-clip them.

    For every variable that has factored Fisher statistics
    (``fprop_concat_stats`` / ``bprop_concat_stats`` in ``self.stats``):

    1. The gradient is projected into the eigenbasis of each Kronecker
       factor (the ``Q`` tensors in ``self.stats_eigen``).
    2. It is divided elementwise by the (damped) product of the factor
       eigenvalues.
    3. It is projected back.

    Variables without such statistics keep their raw gradient. The one
    exception is when a variable is the associated bias (``assnBias``) of a
    factored variable and ``self._blockdiag_bias`` is false. In that case its
    entry is replaced by the bias part of that variable's preconditioned
    gradient.

    All resulting gradients are multiplied by
    ``min(1, sqrt(self._clip_kl / vFv))``, where
    ``vFv = sum_v reduce_sum(precond_grad_v * grad_v * self._lr ** 2)``.
    The ``vFv`` value is assigned to ``self.vFv`` as a control dependency of
    those multiply ops.

    Args:
        gradlist: gradient tensors, aligned element-wise with ``varlist``.
        varlist: the variables the gradients belong to. Used as keys into
            ``self.stats`` and ``self._weight_decay_dict``.

    Returns:
        A list of scaled, preconditioned gradient tensors in ``varlist`` order.
    """
    vg = 0.

    assert len(self.stats) > 0
    assert len(self.stats_eigen) > 0
    assert len(self.factors) > 0
    counter = 0

    # Start from the raw gradients. Entries are overwritten as variables are
    # preconditioned (including the associated-bias entries below).
    grad_dict = {var: grad for grad, var in zip(gradlist, varlist)}

    for grad, var in zip(gradlist, varlist):
        GRAD_RESHAPE = False

        fpropFactoredFishers = self.stats[var]['fprop_concat_stats']
        bpropFactoredFishers = self.stats[var]['bprop_concat_stats']

        if (len(fpropFactoredFishers) + len(bpropFactoredFishers)) > 0:
            counter += 1
            GRAD_SHAPE = grad.get_shape()
            if len(grad.get_shape()) > 2:
                # Reshape conv kernel parameters. The dims are read as
                # (KW, KH, C, D); presumably (kw, kh, in, out) -- TODO confirm.
                KW = int(grad.get_shape()[0])
                KH = int(grad.get_shape()[1])
                C = int(grad.get_shape()[2])
                D = int(grad.get_shape()[3])

                if len(fpropFactoredFishers) > 1 and self._channel_fac:
                    # reshape conv kernel parameters into tensor
                    grad = tf.reshape(grad, [KW * KH, C, D])
                else:
                    # reshape conv kernel parameters into 2D grad
                    grad = tf.reshape(grad, [-1, D])
                GRAD_RESHAPE = True
            elif len(grad.get_shape()) == 1:
                # reshape bias or 1D parameters into a single row
                D = int(grad.get_shape()[0])

                grad = tf.expand_dims(grad, 0)
                GRAD_RESHAPE = True
            else:
                # 2D parameters are used as-is
                C = int(grad.get_shape()[0])
                D = int(grad.get_shape()[1])

            if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
                # use homogeneous coordinates only works for 2D grad.
                # TO-DO: figure out how to factorize bias grad
                # stack bias grad as an extra last row
                var_assnBias = self.stats[var]['assnBias']
                grad = tf.concat(
                    [grad, tf.expand_dims(grad_dict[var_assnBias], 0)], 0)

            # project gradient to eigen space and reshape the eigenvalues
            # for broadcasting
            eigVals = []

            for idx, stats in enumerate(fpropFactoredFishers):
                Q = self.stats_eigen[stats]['Q']
                e = detectMinVal(self.stats_eigen[stats][
                    'e'], var, name='act', debug=KFAC_DEBUG)

                Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='act')
                eigVals.append(e)
                grad = gmatmul(Q, grad, transpose_a=True, reduce_dim=idx)

            for idx, stats in enumerate(bpropFactoredFishers):
                Q = self.stats_eigen[stats]['Q']
                e = detectMinVal(self.stats_eigen[stats][
                    'e'], var, name='grad', debug=KFAC_DEBUG)

                Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='grad')
                eigVals.append(e)
                grad = gmatmul(grad, Q, transpose_b=False, reduce_dim=idx)

            # whiten using eigenvalues
            weightDecayCoeff = 0.
            if var in self._weight_decay_dict:
                weightDecayCoeff = self._weight_decay_dict[var]
                if KFAC_DEBUG:
                    print(('weight decay coeff for %s is %f' % (var.name, weightDecayCoeff)))

            if self._factored_damping:
                if KFAC_DEBUG:
                    print(('use factored damping for %s' % (var.name)))
                coeffs = 1.
                num_factors = len(eigVals)
                # compute the ratio of two trace norm of the left and right
                # KFac matrices, and their generalization
                if num_factors == 1:
                    damping = self._epsilon + weightDecayCoeff
                else:
                    # each factor gets the num_factors-th root of the damping
                    damping = tf.pow(
                        self._epsilon + weightDecayCoeff, 1. / num_factors)
                eigVals_tnorm_avg = [tf.reduce_mean(tf.abs(e)) for e in eigVals]
                for i, (e, e_tnorm) in enumerate(zip(eigVals, eigVals_tnorm_avg)):
                    # Trace norms of every *other* factor. Exclude by position:
                    # `tensor != tensor` is identity-based only in TF1 graph
                    # mode and becomes elementwise once tensor equality is on.
                    eig_tnorm_negList = [
                        item for j, item in enumerate(eigVals_tnorm_avg) if j != i]
                    if num_factors == 1:
                        adjustment = 1.
                    elif num_factors == 2:
                        adjustment = tf.sqrt(
                            e_tnorm / eig_tnorm_negList[0])
                    else:
                        eig_tnorm_negList_prod = reduce(
                            lambda x, y: x * y, eig_tnorm_negList)
                        adjustment = tf.pow(
                            tf.pow(e_tnorm, num_factors - 1.) / eig_tnorm_negList_prod,
                            1. / num_factors)
                    coeffs *= (e + adjustment * damping)
            else:
                coeffs = 1.
                damping = (self._epsilon + weightDecayCoeff)
                for e in eigVals:
                    coeffs *= e
                coeffs += damping

            grad /= coeffs

            # project gradient back to euclidean space
            for idx, stats in enumerate(fpropFactoredFishers):
                Q = self.stats_eigen[stats]['Q']
                grad = gmatmul(Q, grad, transpose_a=False, reduce_dim=idx)

            for idx, stats in enumerate(bpropFactoredFishers):
                Q = self.stats_eigen[stats]['Q']
                grad = gmatmul(grad, Q, transpose_b=True, reduce_dim=idx)

            if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
                # use homogeneous coordinates only works for 2D grad.
                # TO-DO: figure out how to factorize bias grad
                # un-stack bias grad: the last row is the bias, the rest the weights
                var_assnBias = self.stats[var]['assnBias']
                C_plus_one = int(grad.get_shape()[0])
                grad_assnBias = tf.reshape(
                    tf.slice(grad, begin=[C_plus_one - 1, 0], size=[1, -1]),
                    var_assnBias.get_shape())
                grad_assnWeights = tf.slice(grad,
                                            begin=[0, 0],
                                            size=[C_plus_one - 1, -1])
                grad_dict[var_assnBias] = grad_assnBias
                grad = grad_assnWeights

            if GRAD_RESHAPE:
                grad = tf.reshape(grad, GRAD_SHAPE)

            grad_dict[var] = grad

    print(('projecting %d gradient matrices' % counter))

    for g, var in zip(gradlist, varlist):
        grad = grad_dict[var]
        ### clipping ###
        if KFAC_DEBUG:
            print(('apply clipping to %s' % (var.name)))
            # tf.Print is an identity op that only prints when its *output* is
            # consumed, so the result must be used (here via local_vg).
            grad = tf.Print(grad, [tf.sqrt(tf.reduce_sum(tf.pow(grad, 2)))],
                            "Euclidean norm of new grad")
        local_vg = tf.reduce_sum(grad * g * (self._lr * self._lr))
        vg += local_vg

    # rescale everything
    if KFAC_DEBUG:
        print('apply vFv clipping')

    scaling = tf.minimum(1., tf.sqrt(self._clip_kl / vg))
    if KFAC_DEBUG:
        scaling = tf.Print(scaling, [tf.convert_to_tensor(
            'clip: '), scaling, tf.convert_to_tensor(' vFv: '), vg])
    with tf.control_dependencies([tf.assign(self.vFv, vg)]):
        # The multiply ops must be created inside this context so that they
        # run after self.vFv has been updated.
        updatelist = [scaling * grad_dict[var] for var in varlist]

    return updatelist