def compute_stats()

in baselines/acktr/kfac.py


    def compute_stats(self, loss_sampled, var_list=None):
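        """Build the ops that estimate the Kronecker-factored statistics:
        for each trainable variable, the covariance of its forward
        activations (fprop factor) and of the gradients of the sampled
        loss (bprop factor)."""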
        varlist = var_list
        if varlist is None:
            varlist = tf.trainable_variables()

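        # gradients of the *sampled* loss (labels drawn from the model's own
        # predictive distribution), so the statistics estimate the true
        # Fisher rather than the empirical Fisher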
        gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
        self.gs = gs
        factors = self.getFactors(gs, varlist)
        stats = self.getStats(factors, varlist)
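        # factors maps each variable to its op type and the (concatenated)
        # forward activations / backprop gradients feeding it; stats holds
        # the covariance variables those factors update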

        updateOps = []
        statsUpdates = {}
        statsUpdates_cache = {}
        for var in varlist:
            opType = factors[var]['opName']
            fops = factors[var]['op']
            fpropFactor = factors[var]['fpropFactors_concat']
            fpropStats_vars = stats[var]['fprop_concat_stats']
            bpropFactor = factors[var]['bpropFactors_concat']
            bpropStats_vars = stats[var]['bprop_concat_stats']
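            # K-FAC approximates this variable's Fisher block by a Kronecker
            # product A (x) G, where A is the input-activation covariance and
            # G the pre-activation gradient covariance; the two loops below
            # build the update ops estimating A and G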
            SVD_factors = {}
            for stats_var in fpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_fpropFactor = fpropFactor
                    B = tf.shape(fpropFactor)[0]  # batch size
                    if opType == 'Conv2D':
                        strides = fops.get_attr("strides")
                        padding = fops.get_attr("padding")
                        convkernel_size = var.get_shape()[0:3]

                        KH = int(convkernel_size[0])
                        KW = int(convkernel_size[1])
                        C = int(convkernel_size[2])
                        flatten_size = int(KH * KW * C)

                        Oh = int(bpropFactor.get_shape()[1])
                        Ow = int(bpropFactor.get_shape()[2])

                        if Oh == 1 and Ow == 1 and self._channel_fac:
                            # factorization along the channels:
                            # assume independence among input channels
                            # factor  = B x 1 x 1 x (KH x KW x C)
                            # patches = B x Oh x Ow x (KH x KW x C)
                            if len(SVD_factors) == 0:
                                if KFAC_DEBUG:
                                    print('approx %s act factor with rank-1 SVD factors' % var.name)
                                # find closest rank-1 approx to the feature map
                                # note: tf.batch_svd was merged into tf.svd in
                                # TF 1.x; it batches over the leading dim and
                                # returns (s, u, v)
                                S, U, V = tf.svd(tf.reshape(
                                    fpropFactor, [-1, KH * KW, C]))
                                # get the rank-1 approximation slices
                                sqrtS1 = tf.expand_dims(
                                    tf.sqrt(S[:, 0]), 1)  # S: [B, P] singular values
                                patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                                full_factor_shape = fpropFactor.get_shape()
                                patches_k.set_shape(
                                    [full_factor_shape[0], KH * KW])
                                patches_c = V[:, :, 0] * sqrtS1  # B x C
                                patches_c.set_shape([full_factor_shape[0], C])
                                SVD_factors[C] = patches_c
                                SVD_factors[KH * KW] = patches_k
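                                # the two rank-1 factors approximate each
                                # B x (KH*KW) x C feature map as an outer
                                # product, so separate covariances can be
                                # kept for spatial and channel dimensions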
                            fpropFactor = SVD_factors[stats_var_dim]

                        else:
                            # memory-hungry implementation: materializes all
                            # image patches explicitly
                            patches = tf.extract_image_patches(
                                fpropFactor, ksizes=[1, KH, KW, 1],
                                strides=strides, rates=[1, 1, 1, 1], padding=padding)

                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print('approxT2 act fisher for %s' % var.name)
                                # T^2 terms * 1/T^2, size: B x C
                                fpropFactor = tf.reduce_mean(patches, [1, 2])
                            else:
                                # size: (B x Oh x Ow) x C
                                fpropFactor = tf.reshape(
                                    patches, [-1, flatten_size]) / Oh / Ow
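                            # T = Oh*Ow spatial locations: approxT2 keeps one
                            # averaged row per image, while the default keeps
                            # one row per location, scaled by 1/T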
                    fpropFactor_size = int(fpropFactor.get_shape()[-1])
                    if stats_var_dim == (fpropFactor_size + 1) and not self._blockdiag_bias:
                        if opType == 'Conv2D' and not self._approxT2:
                            # rescale the appended ones column accordingly (the
                            # activations were divided by Oh*Ow for the T1 approx)
                            fpropFactor = tf.concat([fpropFactor, tf.ones(
                                [tf.shape(fpropFactor)[0], 1]) / Oh / Ow], 1)
                        else:
                            # use homogeneous coordinates
                            fpropFactor = tf.concat(
                                [fpropFactor, tf.ones([tf.shape(fpropFactor)[0], 1])], 1)
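                    # the constant-1 column folds the bias into the weight
                    # factor, so one pair of Kronecker factors covers both
                    # weight and bias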

                    # second moment of the activations, averaged over the
                    # batch (divide by B)
                    cov = tf.matmul(fpropFactor, fpropFactor,
                                    transpose_a=True) / tf.cast(B, tf.float32)
                    updateOps.append(cov)
                    statsUpdates[stats_var] = cov
                    if opType != 'Conv2D':
                        # HACK: conv fprop stats are recomputed for every
                        # layer, including layers forking off the same input,
                        # so only non-conv stats are cached
                        statsUpdates_cache[stats_var] = cov

            for stats_var in bpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_bpropFactor = bpropFactor
                    bpropFactor_shape = bpropFactor.get_shape()
                    B = tf.shape(bpropFactor)[0]  # batch size
                    C = int(bpropFactor_shape[-1])  # num channels
                    if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                        if fpropFactor is not None:
                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print('approxT2 grad fisher for %s' % var.name)
                                bpropFactor = tf.reduce_sum(
                                    bpropFactor, [1, 2])  # T^2 terms * 1/T^2
                            else:
                                bpropFactor = tf.reshape(
                                    bpropFactor, [-1, C]) * Oh * Ow  # T * 1/T terms
                        else:
                            # block-diagonal approximation only; the spatially
                            # independent structure does not apply here, so sum
                            # over spatial locations
                            if KFAC_DEBUG:
                                print('block diag approx fisher for %s' % var.name)
                            bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

                    # assume the sampled loss is averaged over the batch;
                    # multiply by B to undo the 1/B scaling.
                    # TODO: find a better way to handle this
                    bpropFactor *= tf.to_float(B)

                    cov_b = tf.matmul(
                        bpropFactor, bpropFactor, transpose_a=True) / tf.to_float(tf.shape(bpropFactor)[0])
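                    # the division normalizes over rows: B in the T2 /
                    # block-diagonal cases, B*Oh*Ow after the spatial
                    # reshape above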

                    updateOps.append(cov_b)
                    statsUpdates[stats_var] = cov_b
                    statsUpdates_cache[stats_var] = cov_b

        if KFAC_DEBUG:
            aKey = list(statsUpdates.keys())[0]
            statsUpdates[aKey] = tf.Print(statsUpdates[aKey],
                                          [tf.convert_to_tensor('step:'),
                                           self.global_step,
                                           tf.convert_to_tensor('computing stats')])
        self.statsUpdates = statsUpdates
        return statsUpdates
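
A minimal NumPy sketch (not part of kfac.py; names and shapes are illustrative) of the statistics this function estimates for a single dense layer, and how K-FAC combines them into its Fisher-block approximation (up to the vectorization convention for the weights):

    import numpy as np

    B, n_in, n_out = 4, 3, 2          # batch size, layer fan-in / fan-out
    rng = np.random.default_rng(0)
    a = rng.normal(size=(B, n_in))    # forward activations (fprop factor)
    g = rng.normal(size=(B, n_out))   # sampled-loss gradients w.r.t. the
                                      # layer pre-activations (bprop factor)

    A = a.T @ a / B                   # fprop covariance, as in the first loop
    G = g.T @ g / B                   # bprop covariance, as in the second loop
    fisher_block = np.kron(A, G)      # Kronecker-factored Fisher approximation
    assert fisher_block.shape == (n_in * n_out, n_in * n_out)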