in baselines/acktr/kfac.py [0:0]
def compute_stats(self, loss_sampled, var_list=None):
    """Build the tensors that estimate the Kronecker-factor statistics.

    Differentiates ``loss_sampled`` with respect to the variables, asks
    ``self.getFactors`` / ``self.getStats`` for the per-variable forward
    (activation) and backward (gradient) factor tensors and their
    statistics variables, and builds one second-moment matrix
    (``factor^T @ factor`` divided by a row count) per statistics variable.

    Side effects: sets ``self.gs`` (the sampled gradients) and
    ``self.statsUpdates`` (the returned dict).

    Args:
        loss_sampled: Tensor to differentiate. The code assumes it is
            averaged over the batch (backward factors are rescaled by
            the batch size to undo that averaging).
        var_list: Variables to compute statistics for. Defaults to
            ``tf.trainable_variables()`` when ``None``.

    Returns:
        dict mapping each statistics variable to the tensor holding its
        freshly computed covariance estimate.
    """
    varlist = var_list
    if varlist is None:
        varlist = tf.trainable_variables()

    gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
    self.gs = gs
    factors = self.getFactors(gs, varlist)
    stats = self.getStats(factors, varlist)

    statsUpdates = {}
    # Statistics variables whose update has already been built; a variable
    # that shows up again (shared between parameters) is not rebuilt.
    statsUpdates_cache = {}

    for var in varlist:
        opType = factors[var]['opName']
        fops = factors[var]['op']
        fpropFactor = factors[var]['fpropFactors_concat']
        fpropStats_vars = stats[var]['fprop_concat_stats']
        bpropFactor = factors[var]['bpropFactors_concat']
        bpropStats_vars = stats[var]['bprop_concat_stats']
        # Rank-1 SVD pieces of the activation factor, keyed by their width,
        # computed at most once per variable.
        SVD_factors = {}

        # ---- forward (activation) statistics ----
        # NOTE: fpropFactor is deliberately rebound inside this loop and
        # the rebound value is what later iterations (and the bprop loop's
        # ``is not None`` test) observe.
        for stats_var in fpropStats_vars:
            stats_var_dim = int(stats_var.get_shape()[0])
            if stats_var in statsUpdates_cache:
                continue

            B = tf.shape(fpropFactor)[0]  # batch size
            if opType == 'Conv2D':
                strides = fops.get_attr("strides")
                padding = fops.get_attr("padding")
                convkernel_size = var.get_shape()[0:3]

                KH = int(convkernel_size[0])
                KW = int(convkernel_size[1])
                in_channels = int(convkernel_size[2])
                flatten_size = int(KH * KW * in_channels)

                # Output spatial extent, read from the backward factor.
                # Oh/Ow are reused by the bprop loop further down.
                Oh = int(bpropFactor.get_shape()[1])
                Ow = int(bpropFactor.get_shape()[2])

                if Oh == 1 and Ow == 1 and self._channel_fac:
                    # factorization along the channels
                    # assume independence among input channels
                    # factor = B x 1 x 1 x (KH x KW x C)
                    # patches = B x Oh x Ow x (KH x KW x C)
                    if not SVD_factors:
                        if KFAC_DEBUG:
                            print(('approx %s act factor with rank-1 SVD factors' % (var.name)))
                        # find closest rank-1 approx to the feature map
                        S, U, V = tf.batch_svd(tf.reshape(
                            fpropFactor, [-1, KH * KW, in_channels]))
                        # get rank-1 approx slides
                        sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
                        patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                        full_factor_shape = fpropFactor.get_shape()
                        patches_k.set_shape(
                            [full_factor_shape[0], KH * KW])
                        patches_c = V[:, :, 0] * sqrtS1  # B x C
                        patches_c.set_shape(
                            [full_factor_shape[0], in_channels])
                        SVD_factors[in_channels] = patches_c
                        SVD_factors[KH * KW] = patches_k
                    # pick the piece whose width matches this stats variable
                    fpropFactor = SVD_factors[stats_var_dim]
                else:
                    # poor mem usage implementation
                    patches = tf.extract_image_patches(
                        fpropFactor,
                        ksizes=[1, KH, KW, 1],
                        strides=strides,
                        rates=[1, 1, 1, 1],
                        padding=padding)

                    if self._approxT2:
                        if KFAC_DEBUG:
                            print(('approxT2 act fisher for %s' % (var.name)))
                        # T^2 terms * 1/T^2, size: B x C
                        fpropFactor = tf.reduce_mean(patches, [1, 2])
                    else:
                        # size: (B x Oh x Ow) x C
                        fpropFactor = tf.reshape(
                            patches, [-1, flatten_size]) / Oh / Ow

            fpropFactor_size = int(fpropFactor.get_shape()[-1])
            if stats_var_dim == (fpropFactor_size + 1) and not self._blockdiag_bias:
                # The stats variable is one wider than the factor: append a
                # constant column (homogeneous coordinates) for the bias.
                bias_column = tf.ones([tf.shape(fpropFactor)[0], 1])
                if opType == 'Conv2D' and not self._approxT2:
                    # correct padding for numerical stability (we
                    # divided out OhxOw from activations for T1 approx)
                    bias_column = bias_column / Oh / Ow
                fpropFactor = tf.concat([fpropFactor, bias_column], 1)

            # average over the number of data points in a batch
            # divided by B
            cov = tf.matmul(fpropFactor, fpropFactor,
                            transpose_a=True) / tf.cast(B, tf.float32)
            statsUpdates[stats_var] = cov
            if opType != 'Conv2D':
                # HACK: for convolution we recompute fprop stats for
                # every layer including forking layers
                statsUpdates_cache[stats_var] = cov

        # ---- backward (gradient) statistics ----
        for stats_var in bpropStats_vars:
            if stats_var in statsUpdates_cache:
                continue

            bpropFactor_shape = bpropFactor.get_shape()
            B = tf.shape(bpropFactor)[0]  # batch size
            out_channels = int(bpropFactor_shape[-1])  # num channels
            if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                if fpropFactor is not None:
                    if self._approxT2:
                        if KFAC_DEBUG:
                            print(('approxT2 grad fisher for %s' % (var.name)))
                        bpropFactor = tf.reduce_sum(
                            bpropFactor, [1, 2])  # T^2 terms * 1/T^2
                    else:
                        # NOTE(review): Oh/Ow come from the Conv2D branch
                        # of the fprop loop above; this relies on that
                        # branch having run for this variable -- confirm.
                        bpropFactor = tf.reshape(
                            bpropFactor, [-1, out_channels]) * Oh * Ow  # T * 1/T terms
                else:
                    # just doing block diag approx. spatial independent
                    # structure does not apply here. summing over
                    # spatial locations
                    if KFAC_DEBUG:
                        print(('block diag approx fisher for %s' % (var.name)))
                    bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

            # assume sampled loss is averaged. TO-DO: figure out better
            # way to handle this
            bpropFactor *= tf.to_float(B)
            # Normalise by the current row count of the (possibly reshaped)
            # factor rather than by B.
            cov_b = tf.matmul(
                bpropFactor, bpropFactor, transpose_a=True) / tf.to_float(tf.shape(bpropFactor)[0])

            statsUpdates[stats_var] = cov_b
            statsUpdates_cache[stats_var] = cov_b

    # Attach a debug print to the first update; skipped when nothing was
    # produced so that debug mode cannot fail on an empty dict.
    if KFAC_DEBUG and statsUpdates:
        aKey = next(iter(statsUpdates))
        statsUpdates[aKey] = tf.Print(statsUpdates[aKey],
                                      [tf.convert_to_tensor('step:'),
                                       self.global_step,
                                       tf.convert_to_tensor(
                                           'computing stats'),
                                       ])
    self.statsUpdates = statsUpdates
    return statsUpdates