def getFactors()

in baselines/acktr/kfac.py


    def getFactors(self, g, varlist):
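        """Trace each sampled gradient back to the op that produced it and
        collect, per parameter, the forward-activation ('fpropFactors') and
        backward-gradient ('bpropFactors') tensors used for the K-FAC
        statistics. Returns (and stores in self.factors) a dict keyed by
        parameter variable."""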
        graph = tf.get_default_graph()
        factorTensors = {}
        fpropTensors = []
        bpropTensors = []
        opTypes = []
        fops = []

        def searchFactors(gradient, graph):
            # hard-coded search strategy
            bpropOp = gradient.op
            bpropOp_name = bpropOp.name

            bTensors = []
            fTensors = []

            # combine additive gradients, assuming they come from the same op
            # type and are independent
            if 'AddN' in bpropOp_name:
                factors = []
                for g in gradient.op.inputs:
                    factors.append(searchFactors(g, graph))
                op_names = [item['opName'] for item in factors]
                # TODO: need to check all the attributes of the ops as well
                if KFAC_DEBUG:
                    print(gradient.name)
                    print(op_names)
                    print(len(np.unique(op_names)))
                assert len(np.unique(op_names)) == 1, gradient.name + \
                    ' is shared among different computation OPs'

                bTensors = reduce(lambda x, y: x + y,
                                  [item['bpropFactors'] for item in factors])
                if len(factors[0]['fpropFactors']) > 0:
                    fTensors = reduce(
                        lambda x, y: x + y, [item['fpropFactors'] for item in factors])
                fpropOp_name = op_names[0]
                fpropOp = factors[0]['op']
            else:
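                # recover the forward op by stripping the
                # gradientsSampled*/..._grad wrapper from the gradient op name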
                fpropOp_name = re.search(
                    'gradientsSampled(_[0-9]+|)/(.+?)_grad', bpropOp_name).group(2)
                fpropOp = graph.get_operation_by_name(fpropOp_name)
                if fpropOp.op_def.name in KFAC_OPS:
                    # Known OPs
                    ###
                    bTensor = [
                        i for i in bpropOp.inputs if 'gradientsSampled' in i.name][-1]
                    bTensorShape = fpropOp.outputs[0].get_shape()
                    if bTensor.get_shape()[0].value is None:
                        bTensor.set_shape(bTensorShape)
                    bTensors.append(bTensor)
                    ###
                    if fpropOp.op_def.name == 'BiasAdd':
                        fTensors = []
                    else:
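                        # the forward factor is the fprop input that is not the
                        # parameter itself (`param` is bound by the enclosing
                        # loop over zip(g, varlist))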
                        fTensors.append(
                            [i for i in fpropOp.inputs if param.op.name not in i.name][0])
                    fpropOp_name = fpropOp.op_def.name
                else:
                    # unknown OPs, block approximation used
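                    # only the backward (gradient) factor is collected here;
                    # no forward input factor is recorded for unrecognised ops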
                    bInputsList = [i for i in bpropOp.inputs[0].op.inputs
                                   if 'gradientsSampled' in i.name and 'Shape' not in i.name]
                    if len(bInputsList) > 0:
                        bTensor = bInputsList[0]
                        bTensorShape = fpropOp.outputs[0].get_shape()
                        if len(bTensor.get_shape()) > 0 and bTensor.get_shape()[0].value is None:
                            bTensor.set_shape(bTensorShape)
                        bTensors.append(bTensor)
                    fpropOp_name = 'UNK-' + fpropOp.op_def.name
                    opTypes.append(fpropOp_name)

            return {'opName': fpropOp_name, 'op': fpropOp, 'fpropFactors': fTensors, 'bpropFactors': bTensors}

        for t, param in zip(g, varlist):
            if KFAC_DEBUG:
                print(('get factor for '+param.name))
            factors = searchFactors(t, graph)
            factorTensors[param] = factors

        ########
        # check the associated weights and bias for the homogeneous coordinate
        # representation and check for redundant factors
        # TODO: there may be a bug in detecting the associated bias and weights
        # for forking layers, e.g. in inception models.
        for param in varlist:
            factorTensors[param]['assnWeights'] = None
            factorTensors[param]['assnBias'] = None
        for param in varlist:
            if factorTensors[param]['opName'] == 'BiasAdd':
                factorTensors[param]['assnWeights'] = None
                for item in varlist:
                    if len(factorTensors[item]['bpropFactors']) > 0:
                        if (set(factorTensors[item]['bpropFactors']) == set(factorTensors[param]['bpropFactors'])) and (len(factorTensors[item]['fpropFactors']) > 0):
                            factorTensors[param]['assnWeights'] = item
                            factorTensors[item]['assnBias'] = param
                            factorTensors[param]['bpropFactors'] = factorTensors[
                                item]['bpropFactors']

        ########

        ########
        # concatenate the additive gradients along the batch dimension, i.e.
        # assuming independence structure
        for key in ['fpropFactors', 'bpropFactors']:
            for i, param in enumerate(varlist):
                if len(factorTensors[param][key]) > 0:
                    if (key + '_concat') not in factorTensors[param]:
                        name_scope = factorTensors[param][key][0].name.split(':')[0]
                        with tf.name_scope(name_scope):
                            factorTensors[param][key + '_concat'] = tf.concat(
                                factorTensors[param][key], 0)
                else:
                    factorTensors[param][key + '_concat'] = None
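                # reuse the concatenated factors for any later parameter whose
                # factor set is identical, avoiding duplicate concat ops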
                for j, param2 in enumerate(varlist[(i + 1):]):
                    if (len(factorTensors[param][key]) > 0) and (set(factorTensors[param2][key]) == set(factorTensors[param][key])):
                        factorTensors[param2][key] = factorTensors[param][key]
                        factorTensors[param2][key + '_concat'] = factorTensors[param][key + '_concat']
        ########

        if KFAC_DEBUG:
            for param in varlist:
                print((param.name, factorTensors[param]))
        self.factors = factorTensors
        return factorTensors
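
A minimal usage sketch (illustrative only; `loss_sampled` and the surrounding
optimizer state are assumptions, not taken from this listing): the sampled-loss
gradients must be built under the name `gradientsSampled`, since the regex in
`searchFactors` relies on that prefix to map each gradient op back to its
forward op.

    # hypothetical caller inside the optimizer, for illustration
    varlist = tf.trainable_variables()
    gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
    factors = self.getFactors(gs, varlist)
    # factors[param] is a dict with 'opName', 'op', 'fpropFactors',
    # 'bpropFactors', 'assnWeights', 'assnBias' and the '*_concat' entries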