def __init__()

in kfac/python/ops/estimator.py


  def __init__(self,
               variables,
               cov_ema_decay,
               damping,
               layer_collection,
               exps=(-1,),
               estimation_mode="gradients",
               colocate_gradients_with_ops=True,
               name="FisherEstimator",
               compute_cholesky=False,
               compute_cholesky_inverse=False,
               compute_params_stats=False,
               batch_size=None):
    """Create a FisherEstimator object.

    Args:
      variables: A `list` of variables for which to estimate the Fisher. This
        must match the variables registered in layer_collection (if it is not
        None).
      cov_ema_decay: The decay factor used when calculating the covariance
        estimate moving averages.
      damping: float or 0D Tensor. This quantity times the identity matrix is
        (approximately) added to the matrix being estimated.
      layer_collection: A LayerCollection object which holds the Fisher
        blocks, Kronecker factors, and losses associated with the graph.
      exps: List of floats or ints. These represent the different matrix
        powers of the approximate Fisher that the FisherEstimator will be able
        to multiply vectors by. If the user asks for a matrix power other
        than one of these (or 1, which is always supported), there will be a
        failure. (Default: (-1,))
      estimation_mode: The type of estimator to use for the Fishers/GGNs. Can
        be 'gradients', 'empirical', 'curvature_prop', 'curvature_prop_GGN',
        'exact', or 'exact_GGN'. (Default: 'gradients'). 'gradients' is the
        basic estimation approach from the original K-FAC paper. 'empirical'
        computes the 'empirical' Fisher information matrix (which uses the
        data's distribution for the targets, as opposed to the true Fisher,
        which uses the model's distribution) and requires that each
        registered loss have specified targets. 'curvature_prop' estimates
        the Fisher using self-products of random +1/-1 vectors times
        "half-factors" of the Fisher, as described here:
        https://arxiv.org/abs/1206.6464 . 'exact' is the obvious
        generalization of curvature propagation to compute the exact Fisher
        (modulo any additional diagonal or Kronecker approximations) by
        looping over one-hot vectors for each coordinate of the output
        instead of using random +1/-1 vectors. It is more expensive than
        'gradients', 'empirical', or 'curvature_prop' by a factor roughly
        equal to the output dimension. Finally, 'curvature_prop_GGN' and
        'exact_GGN' are analogous to 'curvature_prop' and 'exact', but
        estimate the Generalized Gauss-Newton matrix (GGN).
      colocate_gradients_with_ops: Whether we should request gradients be
        colocated with their respective ops. (Default: True)
      name: A string. A name given to this estimator, which is added to the
        variable scope when constructing variables and ops.
        (Default: "FisherEstimator")
      compute_cholesky: Bool. Whether or not the FisherEstimator will be
        able to multiply vectors by the Cholesky factor.
        (Default: False)
      compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
        will be able to multiply vectors by the Cholesky factor inverse.
        (Default: False)
      compute_params_stats: Bool. If True, we compute the first order version
        of the statistics computed to estimate the Fisher/GGN. These correspond
        to the `variables` method in a one-to-one fashion.  They are available
        via the `params_stats` property.  When estimation_mode is 'empirical',
        this will correspond to the standard parameter gradient on the loss.
        (Default: False)
      batch_size: The size of the mini-batch. Only needed when
        `compute_params_stats` is True. Note that when using data parallelism
        where the model graph and optimizer are replicated across multiple
        devices, this should be the per-replica batch size. An example of
        this is sharded data on the TPU, where batch_size should be set to
        the total batch size divided by the number of shards. (Default: None)

    Raises:
      ValueError: If no losses have been registered with layer_collection.
      ValueError: If `compute_params_stats` is True but `batch_size` is None.
    """
    self._variables = variables
    self._cov_ema_decay = cov_ema_decay
    self._damping = damping
    self._estimation_mode = estimation_mode
    self._layer_collection = layer_collection
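    # Each estimation mode selects a different routine for producing the
    # gradient lists used to estimate the curvature matrix.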
    self._gradient_fns = {
        "gradients": self._get_grads_lists_gradients,
        "empirical": self._get_grads_lists_empirical,
        "curvature_prop": self._get_grads_lists_curvature_prop,
        "curvature_prop_GGN": self._get_grads_lists_curvature_prop,
        "exact": self._get_grads_lists_exact,
        "exact_GGN": self._get_grads_lists_exact
    }
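    # Name of the matrix that each estimation mode actually estimates.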
    self._mat_type_table = {
        "gradients": "Fisher",
        "empirical": "Empirical_Fisher",
        "curvature_prop": "Fisher",
        "curvature_prop_GGN": "GGN",
        "exact": "Fisher",
        "exact_GGN": "GGN",
    }

    self._colocate_gradients_with_ops = colocate_gradients_with_ops

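    # Matrix powers (and, optionally, Cholesky factors and their inverses)
    # that the estimator will be able to multiply vectors by.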
    self._exps = exps
    self._compute_cholesky = compute_cholesky
    self._compute_cholesky_inverse = compute_cholesky_inverse

    self._name = name

    self._compute_params_stats = compute_params_stats
    self._batch_size = batch_size

    if compute_params_stats and batch_size is None:
      raise ValueError("Batch size needs to be passed in when "
                       "compute_params_stats is True.")

    self._finalized = False
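
For orientation, a minimal construction sketch follows. It is an illustration
under stated assumptions, not code from this file: `params` (a list of
tf.Variables) and `logits` (the model's output tensor) are hypothetical
placeholders, and it assumes a concrete FisherEstimator class is importable
from this module.

    # Hypothetical usage sketch; `params` and `logits` are placeholders.
    from kfac.python.ops import estimator
    from kfac.python.ops import layer_collection as lc

    layer_collection = lc.LayerCollection()
    # Register the model's predictive distribution (defines the loss whose
    # Fisher is estimated).
    layer_collection.register_categorical_predictive_distribution(logits)
    # ... register each layer's parameters with layer_collection ...

    fisher_est = estimator.FisherEstimator(
        variables=params,
        cov_ema_decay=0.95,
        damping=1e-3,
        layer_collection=layer_collection,
        exps=(-1,),                   # support multiplying by the inverse
        estimation_mode="gradients",  # basic K-FAC estimator
    )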
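
The 'curvature_prop' mode described in the docstring rests on a simple
identity: if F = B B^T for some "half-factor" B, and v has i.i.d. +1/-1
entries, then E[(B v)(B v)^T] = F. The following standalone numpy sketch
(a toy check of the identity, not this module's implementation) shows the
Monte Carlo estimate converging to the exact matrix:

    # Toy Monte Carlo check of the curvature propagation identity
    # (https://arxiv.org/abs/1206.6464); standalone, not this module's code.
    import numpy as np

    rng = np.random.default_rng(0)
    B = rng.normal(size=(4, 3))        # stand-in "half-factor": F = B @ B.T
    F_exact = B @ B.T

    num_samples = 10000
    V = rng.choice([-1.0, 1.0], size=(3, num_samples))  # random +1/-1 vectors
    G = B @ V                          # each column is B @ v_i
    F_est = (G @ G.T) / num_samples    # averaged self-products

    print(np.max(np.abs(F_est - F_exact)))  # shrinks as num_samples grows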