in kfac/python/ops/estimator.py
def __init__(self,
variables,
cov_ema_decay,
damping,
layer_collection,
exps=(-1,),
estimation_mode="gradients",
colocate_gradients_with_ops=True,
name="FisherEstimator",
compute_cholesky=False,
compute_cholesky_inverse=False,
compute_params_stats=False,
batch_size=None):
"""Create a FisherEstimator object.
Args:
variables: A `list` of variables for which to estimate the Fisher. This
must match the variables registered in layer_collection (if it is not
None).
cov_ema_decay: The decay factor used when calculating the covariance
estimate moving averages.
damping: float or 0D Tensor. This quantity times the identity matrix is
(approximately) added to the matrix being estimated.
layer_collection: A LayerCollection object which holds the Fisher
blocks, Kronecker factors, and losses associated with the graph.
exps: List of floats or ints. These represent the different matrix
powers of the approximate Fisher that the FisherEstimator will be able
to multiply vectors by. For example, `exps=(-1,)` enables
multiplication by the (approximate) inverse Fisher. If the user asks
for a matrix power other than one of these (or 1, which is always
supported), there will be a failure. (Default: (-1,))
estimation_mode: The type of estimator to use for the Fishers/GGNs. Can
be 'gradients', 'empirical', 'curvature_prop', 'curvature_prop_GGN',
'exact', or 'exact_GGN'. (Default: 'gradients'). 'gradients' is the
basic estimation approach from the original K-FAC paper. 'empirical'
computes the 'empirical' Fisher information matrix (which uses the
data's distribution for the targets, as opposed to the true Fisher,
which uses the model's distribution) and requires that each
registered loss have specified targets. 'curvature_prop' estimates
the Fisher using self-products of random +1/-1 vectors times
"half-factors" of the Fisher, as described here:
https://arxiv.org/abs/1206.6464 . 'exact' is the obvious
generalization of Curvature Propagation to compute the exact Fisher
(modulo any additional diagonal or Kronecker approximations) by
looping over one-hot vectors for each coordinate of the output
instead of using random +1/-1 vectors. It is more expensive to
compute than the other options by a factor equal to the output
dimension, roughly speaking. Finally, 'curvature_prop_GGN' and
'exact_GGN' are analogous to 'curvature_prop' and 'exact', but
estimate the Generalized Gauss-Newton matrix (GGN).
colocate_gradients_with_ops: Whether we should request gradients be
colocated with their respective ops. (Default: True)
name: A string. A name given to this estimator, which is added to the
variable scope when constructing variables and ops.
(Default: "FisherEstimator")
compute_cholesky: Bool. Whether or not the FisherEstimator will be
able to multiply vectors by the Cholesky factor.
(Default: False)
compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
will be able to multiply vectors by the Cholesky factor inverse.
(Default: False)
compute_params_stats: Bool. If True, we compute the first-order
analogues of the statistics used to estimate the Fisher/GGN. They
correspond one-to-one with the variables returned by the `variables`
method, and are available via the `params_stats` property. When
estimation_mode is 'empirical', they correspond to the standard
parameter gradient of the loss. (Default: False)
batch_size: The size of the mini-batch. Only needed when
`compute_params_stats` is True. Note that when using data parallelism
where the model graph and optimizer are replicated across multiple
devices, this should be the per-replica batch size. An example of
this is sharded data on the TPU, where batch_size should be set to
the total batch size divided by the number of shards. (Default: None)
Raises:
ValueError: If no losses have been registered with layer_collection.
ValueError: If `compute_params_stats` is True but `batch_size` is None.
"""
self._variables = variables
self._cov_ema_decay = cov_ema_decay
self._damping = damping
self._estimation_mode = estimation_mode
self._layer_collection = layer_collection
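# Maps each supported estimation_mode to the method that produces the
# lists of gradient-like quantities used to form the covariance
# statistics. Note that the GGN variants map to the same methods as
# their Fisher counterparts.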
self._gradient_fns = {
"gradients": self._get_grads_lists_gradients,
"empirical": self._get_grads_lists_empirical,
"curvature_prop": self._get_grads_lists_curvature_prop,
"curvature_prop_GGN": self._get_grads_lists_curvature_prop,
"exact": self._get_grads_lists_exact,
"exact_GGN": self._get_grads_lists_exact
}
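# Records which curvature matrix each estimation_mode actually
# estimates (Fisher, empirical Fisher, or GGN).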
self._mat_type_table = {
"gradients": "Fisher",
"empirical": "Empirical_Fisher",
"curvature_prop": "Fisher",
"curvature_prop_GGN": "GGN",
"exact": "Fisher",
"exact_GGN": "GGN",
}
self._colocate_gradients_with_ops = colocate_gradients_with_ops
self._exps = exps
self._compute_cholesky = compute_cholesky
self._compute_cholesky_inverse = compute_cholesky_inverse
self._name = name
self._compute_params_stats = compute_params_stats
self._batch_size = batch_size
if compute_params_stats and batch_size is None:
raise ValueError("Batch size needs to be passed in when "
"compute_params_stats is True.")
self._finalized = False