in kfac/python/ops/layer_collection.py [0:0]
def __init__(self,
             graph=None,
             name="LayerCollection"):
    """Set up empty registries, default approximations and block tables.

    Args:
      graph: Graph to associate with this collection. When falsy, the
        result of `tf.get_default_graph()` is used instead.
      name: Default name handed to `tf.variable_scope`; the name of the
        scope that gets opened is stored in `self._var_scope`.
    """
    # Graph this collection is tied to; no subgraph is set yet.
    self._graph = graph if graph else tf.get_default_graph()
    self._subgraph = None

    # Containers, all empty at construction time.
    self.fisher_blocks = LayerParametersDict()
    self.fisher_factors = OrderedDict()
    # Maps sets of variables to optionally specified approximations.
    self._linked_parameters = {}
    self._loss_dict = OrderedDict()  # {str: LossFunction}
    self.loss_colocation_ops = {}
    self.loss_coeffs = {}
    self._vars_to_uses = defaultdict(lambda: 0)  # unseen keys read as 0
    self._finalized = False

    # Only the name of the opened variable scope is retained.
    with tf.variable_scope(None, default_name=name) as var_scope:
        self._var_scope = var_scope.name

    # Approximation name used by default for each kind of layer.
    self._default_generic_approximation = APPROX_DIAGONAL_NAME
    self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
    self._default_conv2d_approximation = APPROX_KRONECKER_NAME
    self._default_fully_connected_multi_approximation = (
        APPROX_KRONECKER_INDEP_NAME)
    self._default_conv2d_multi_approximation = APPROX_KRONECKER_INDEP_NAME
    self._default_scale_and_shift_approximation = APPROX_FULL_NAME

    # Block classes referenced by several table entries below.
    fc_basic = fb.FullyConnectedKFACBasicFB
    fc_multi_indep = fb.FullyConnectedMultiIndepFB
    fc_series = fb.FullyConnectedSeriesFB
    conv_basic = fb.ConvKFCBasicFB

    # Tables from approximation name to the callable that builds the block.
    # `partial` pre-binds the keyword arguments selecting a variant.
    self._generic_approx_to_block_types = {
        APPROX_FULL_NAME: fb.NaiveFullFB,
        APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
    }

    self._fully_connected_approx_to_block_types = {
        APPROX_KRONECKER_NAME: fc_basic,
        APPROX_KRONECKER_IN_DIAG_NAME: partial(
            fc_basic, diagonal_approx_for_input=True),
        APPROX_KRONECKER_OUT_DIAG_NAME: partial(
            fc_basic, diagonal_approx_for_output=True),
        APPROX_KRONECKER_BOTH_DIAG_NAME: partial(
            fc_basic,
            diagonal_approx_for_input=True,
            diagonal_approx_for_output=True),
        APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
    }

    self._conv2d_approx_to_block_types = {
        APPROX_KRONECKER_NAME: conv_basic,
        APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
        # The SUA name resolves to the same block class as the plain
        # Kronecker name in this table.
        APPROX_KRONECKER_SUA_NAME: conv_basic,
    }

    self._fully_connected_multi_approx_to_block_types = {
        APPROX_KRONECKER_INDEP_NAME: fc_multi_indep,
        APPROX_KRONECKER_INDEP_IN_DIAG_NAME: partial(
            fc_multi_indep, diagonal_approx_for_input=True),
        APPROX_KRONECKER_INDEP_OUT_DIAG_NAME: partial(
            fc_multi_indep, diagonal_approx_for_output=True),
        APPROX_KRONECKER_INDEP_BOTH_DIAG_NAME: partial(
            fc_multi_indep,
            diagonal_approx_for_input=True,
            diagonal_approx_for_output=True),
        APPROX_KRONECKER_SERIES_1_NAME: partial(fc_series, option=1),
        APPROX_KRONECKER_SERIES_2_NAME: partial(fc_series, option=2),
    }

    self._conv2d_multi_approx_to_block_types = {
        APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB,
    }

    self._scale_and_shift_approx_to_block_types = {
        APPROX_FULL_NAME: fb.ScaleAndShiftFullFB,
        APPROX_DIAGONAL_NAME: fb.ScaleAndShiftDiagonalFB,
    }