def getStats()

in baselines/acktr/kfac.py [0:0]


    def getStats(self, factors, varlist):
        """Create (once) and return the running-statistics slots per variable.

        On the first call, one identity matrix scaled by
        ``self._diag_init_coeff`` is allocated as a non-trainable
        ``tf.Variable`` for every forward / backward factor that needs its own
        statistics, and the result is stored in ``self.stats``. Later calls
        return the existing ``self.stats`` without touching the graph.

        Args:
            factors: mapping keyed by variable. Each entry is read for the
                keys ``'fpropFactors_concat'``, ``'bpropFactors_concat'``,
                ``'opName'``, ``'assnWeights'`` and ``'assnBias'``.
                NOTE: this mapping is mutated. For a ``'Conv2D'`` variable
                that is factorized along channels, the weight/bias
                association is cleared on both sides.
            varlist: iterable of the variables to build statistics for; each
                must be a key of ``factors``.

        Returns:
            ``self.stats``: a dict keyed by variable whose values hold
            ``'opName'``, ``'fprop_concat_stats'``, ``'bprop_concat_stats'``,
            ``'assnWeights'`` and ``'assnBias'``. Variables that share a
            factor tensor share the same list object of stats slots.
        """
        if self.stats:
            # Slots were already built by an earlier call.
            return self.stats

        def channel_factored(bprop_factor):
            # True when the backward factor has 1x1 spatial extent (dims 1 and
            # 2 of its shape) and channel factorization is switched on.
            out_h = bprop_factor.get_shape()[1]
            out_w = bprop_factor.get_shape()[2]
            return out_h == 1 and out_w == 1 and self._channel_fac

        def new_slot(size, factor):
            # size x size identity scaled by the diagonal init coefficient,
            # named after the op that produced the factor.
            return tf.Variable(
                tf.diag(tf.ones([size])) * self._diag_init_coeff,
                name='KFAC_STATS/' + factor.op.name,
                trainable=False)

        # initialize stats variables on CPU because eigen decomp is
        # computed on CPU
        with tf.device('/cpu'):
            # factor tensor -> list of slots already created for it
            slot_cache = {}

            # Pass 1: factorization along the channels does not support the
            # homogeneous coordinate, so detach the bias from such conv
            # weights before any slot is sized.
            for var in varlist:
                if factors[var]['opName'] != 'Conv2D':
                    continue
                if not channel_factored(factors[var]['bpropFactors_concat']):
                    continue
                bias_var = factors[var]['assnBias']
                if bias_var:
                    factors[var]['assnBias'] = None
                    factors[bias_var]['assnWeights'] = None

            # Pass 2: allocate the slots.
            for var in varlist:
                fprop = factors[var]['fpropFactors_concat']
                bprop = factors[var]['bpropFactors_concat']
                op_type = factors[var]['opName']
                var_stats = {'opName': op_type,
                             'fprop_concat_stats': [],
                             'bprop_concat_stats': [],
                             'assnWeights': factors[var]['assnWeights'],
                             'assnBias': factors[var]['assnBias'],
                             }
                self.stats[var] = var_stats

                if fprop is not None:
                    if fprop in slot_cache:
                        # Another variable already owns slots for this factor.
                        var_stats['fprop_concat_stats'] = slot_cache[fprop]
                    else:
                        if op_type == 'Conv2D':
                            kernel_h = var.get_shape()[0]
                            kernel_w = var.get_shape()[1]
                            in_channels = fprop.get_shape()[-1]
                            if channel_factored(bprop):
                                # Separate spatial (Kh*Kw) and channel (C)
                                # slots; the spatial one is appended first.
                                # assnBias is always None here (see pass 1).
                                var_stats['fprop_concat_stats'].append(
                                    new_slot(kernel_h * kernel_w, fprop))
                                fprop_size = in_channels
                            else:
                                # Single joint slot over Kh*Kw*C;
                                # assume BHWC
                                fprop_size = kernel_h * kernel_w * in_channels
                        else:
                            # D x D covariance matrix
                            fprop_size = fprop.get_shape()[-1]

                        # use homogeneous coordinate: one extra row/column
                        # when the bias is folded into the weight factor
                        if not self._blockdiag_bias and var_stats['assnBias']:
                            fprop_size += 1

                        var_stats['fprop_concat_stats'].append(
                            new_slot(fprop_size, fprop))
                        # Conv2D forward slots are never shared via the cache.
                        if op_type != 'Conv2D':
                            slot_cache[fprop] = var_stats['fprop_concat_stats']

                if bprop is not None:
                    # no need to collect backward stats for bias vectors if
                    # using homogeneous coordinates
                    if self._blockdiag_bias or not var_stats['assnWeights']:
                        if bprop in slot_cache:
                            var_stats['bprop_concat_stats'] = slot_cache[bprop]
                        else:
                            var_stats['bprop_concat_stats'].append(
                                new_slot(bprop.get_shape()[-1], bprop))
                            slot_cache[bprop] = var_stats['bprop_concat_stats']

        return self.stats