def __init__()

in kfac/python/ops/optimizer.py [0:0]


  def __init__(self,
               learning_rate,
               damping,
               layer_collection,
               cov_ema_decay=0.95,
               var_list=None,
               momentum=0.9,
               momentum_type="adam",
               use_weight_decay=False,
               weight_decay_coeff=0.1,
               qmodel_update_rescale=None,
               norm_constraint=None,
               name="KFAC",
               estimation_mode="gradients",
               colocate_gradients_with_ops=True,
               batch_size=None,
               placement_strategy=None,
               compute_params_stats=False,
               adapt_damping=False,
               update_damping_immediately=True,
               is_chief=True,
               prev_train_batch=None,
               loss=None,
               loss_fn=None,
               min_damping=1e-8,  # this value is somewhat arbitrary
               l2_reg=0.0,
               damping_adaptation_decay=0.95,
               damping_adaptation_interval=5,
               use_passed_loss=True,
               train_batch=None,
               print_logs=False,
               tf_replicator=None,
               dtype="float32",
               **kwargs):
    """Initializes the K-FAC optimizer with the given settings.

      NOTE: this is a base class for K-FAC optimizers that offers full control
      over the execution of K-FAC's various ops.  For a more fool-proof /
      automated version see for example PeriodicInvCovUpdateKfacOpt.

      Also, please keep in mind that while the K-FAC code loosely conforms to
      TensorFlow's Optimizer API it can't be used naively as a "drop in
      replacement" for basic classes like MomentumOptimizer.  Using it
      properly with SyncReplicasOptimizer, for example, requires special care.
      When using it with Distribution Strategy, unlike common practice, K-FAC
      expects a loss tensor that is normalized by the per-replica batch size,
      and *not* by the total batch size (like you may see in TF Distribution
      Strategy tutorials). Regardless of whether you are using estimator,
      strategy, or a normal custom training loop, you should pass in the same
      loss.

      See the various examples in the "examples" directory for a guide about
      how to use K-FAC in various contexts and various systems, like
      TF-Estimator. See in particular the "convnet" example.  google/examples
      also contains an example using TPUEstimator.

    Args:
      learning_rate: float or 0D Tensor. The base learning rate for the
          optimizer. Must be set to None if using one of the 'qmodel'
          momentum_type values, and must not be None otherwise.
      damping: float or 0D Tensor. This quantity times the identity matrix is
          (approximately) added to the curvature matrix (i.e. the Fisher or GGN)
          before it is inverted and multiplied by the gradient when computing
          the (raw) update. This quantity should match the scale of the
          objective, so that if you put a multiplier on your loss you should
          apply the same multiplier to the damping. Roughly speaking, larger
          values constrain the update vector to a smaller region around zero,
          which we want to do when our local quadratic model is a less
          trustworthy local approximation of the true objective.  The damping
          value is closely related to the trust region radius and to the
          classical Tikhonov regularization method. If the `adapt_damping`
          argument is True then this value is used only as an initial value for
          the adaptation method.
      layer_collection: The layer collection object, which holds the Fisher
          blocks, Kronecker factors, and losses associated with the
          graph.  The layer_collection cannot be modified after KfacOptimizer's
          initialization.
      cov_ema_decay: The decay factor used when calculating the
          covariance estimate moving averages. (Default: 0.95)
      var_list: Optional list or tuple of variables to train. Defaults to
          tf.trainable_variables. Note that an empty list/tuple is treated the
          same as None.
      momentum: The momentum decay constant to use. Only applies when
          momentum_type is 'regular' or 'adam'. Must be None when
          momentum_type is 'qmodel'. (Default: 0.9)
      momentum_type: The type of momentum to use in this optimizer, one of
          'regular', 'adam', 'qmodel', or 'qmodel_fixedmu'. The value is
          matched case-insensitively. 'regular' gives standard momentum.
          'adam' gives a style of momentum reminiscent of the Adam method,
          which seems to work better in practice. 'qmodel' makes the optimizer
          perform automatic control of both the learning rate and momentum
          using a quadratic model based method (see
          _compute_qmodel_hyperparams for more details). 'qmodel_fixedmu'
          is similar to 'qmodel' but only controls the learning rate.
          (Default: 'adam')
      use_weight_decay: If True, explicit "weight decay" is performed by K-FAC.
          Note that this is distinct from L2 regularization, and corresponds to
          optimizing a regularized version of the loss passed to minimize(),
          where the regularization term added is related to the "Fisher-Rao
          norm". See https://openreview.net/pdf?id=B1lz-3Rct7 for more details.
          Note that using this feature won't change the loss function you pass
          to minimize(), and thus the loss you report will not correspond
          precisely to what K-FAC is optimizing. (Default: False)
      weight_decay_coeff: The coefficient to use for weight decay (see above).
          (Default: 0.1)
      qmodel_update_rescale: float or None.  An additional multiplier to apply
          to the update computed by the quadratic model based adjustment
          methods. If None it will behave like a value of 1.0. Must be None
          unless using one of the 'qmodel' momentum types. (Default: None)
      norm_constraint: float or Tensor. If specified, the update is scaled down
          so that its approximate squared Fisher norm v^T F v is at most the
          specified value. May only be used with momentum type 'regular' or
          'adam'.  See the docstring for the method _clip_updates() for a more
          detailed explanation of this feature. (Default: None)
      name: The name for this optimizer. (Default: 'KFAC')
      estimation_mode: The type of estimator to use for the Fishers/GGNs. Can be
          'gradients', 'empirical', 'curvature_prop', 'curvature_prop_GGN',
          'exact', or 'exact_GGN'. See the doc-string for FisherEstimator
          (in estimator.py) for a more detailed description of these
          options. (Default: 'gradients').
      colocate_gradients_with_ops: Whether we should request gradients we
          compute in the estimator be colocated with their respective ops.
          (Default: True)
      batch_size: The size of the mini-batch. Only needed when `momentum_type`
          == 'qmodel' or when `compute_params_stats` is True. Note that when
          using data parallelism where the model graph and optimizer are
          replicated across multiple devices, this should be the per-replica
          batch size. An example of this is sharded data on the TPU, where
          batch_size should be set to the total batch size divided by the
          number of shards. (Default: None)
      placement_strategy: string or None. Device placement strategy used when
          creating variables, and various ops. Can be None, 'round_robin', or
          'replica_round_robin'. 'round_robin' supports round-robin placement of
          various ops on lists of provided devices. 'replica_round_robin' does
          something similar but over shards/replicas instead, and only works
          in certain 'replicated' contexts (e.g. TPUEstimator).  The details of
          the different placement strategies are controlled by additional
          keyword arguments that can be passed to this class, and which are
          described in the different placement mixin classes in placement.py.
          (Default: None)
      compute_params_stats: Bool. If True, we compute the first order version
          of the statistics computed to estimate the Fisher/GGN. These
          correspond to the `variables` method in a one-to-one fashion.  They
          are available via the `params_stats` property.  When estimation_mode
          is 'empirical', this will correspond to the standard parameter
          gradient on the loss. (Default: False)
      adapt_damping: `Boolean`. If True we adapt the damping according to the
          Levenberg-Marquardt rule described in Section 6.5 of the original
          K-FAC paper. The details of this scheme are controlled by various
          additional arguments below. Also some of these arguments are extra
          pieces of information, such as the loss, needed by the method. Note
          that unless using a convenience subclass like
          PeriodicInvCovUpdateKfacOpt the damping adaptation op must be
          executed by the user (like the cov and inv ops). This op is returned
          by the maybe_pre_update_adapt_damping() method. (Default: False)
      update_damping_immediately: Damping adjustment strategy. If True then the
          damping is updated in the same optimizer minimize call as
          `(step+1) % damping_adaptation_interval == 0`, immediately after the
          parameter update is performed. If False then the damping is updated
          in the next step. If True then it is assumed that the apply_gradients
          op will safely update the model before returning; it is recommended
          to only use resource variables in this case. (Default: True)
      is_chief: `Boolean`, `True` if the worker is chief. (Default: True)
      prev_train_batch: Training mini-batch used in the previous step. This
          will be used to evaluate loss by calling `loss_fn(prev_train_batch)`
          when damping adaptation is used. (Default: None)
      loss: `Tensor` the model loss, used as the pre-update loss in adaptive
          damping. Also used for the built-in log printing. When using
          Distribution Strategy, unlike common Distribution Strategy practice,
          this loss tensor should be normalized by the per-replica batch size
          and NOT the total batch size. (Default: None)
      loss_fn: `function` that takes as input training data tensor and returns
          a scalar loss. Only needed if using damping adaptation. When using
          Distribution Strategy, unlike common Distribution Strategy practice,
          the loss should be normalized by the per-replica batch size and NOT
          the total batch size. (Default: None)
      min_damping: `float`, Minimum value the damping parameter can take. Note
          that the default value of 1e-8 is quite arbitrary, and you may have
          to adjust this up or down for your particular problem. If you are
          using a non-zero value of l2_reg you *may* be able to set this to
          zero. (Default: 1e-8)
      l2_reg: `float` or 0D Tensor. Set this value to tell the optimizer what L2
          regularization coefficient you are using (if any). Note the
          coefficient appears in the regularizer as coeff / 2 * sum(param**2),
          as the thing you multiply tf.nn.l2_loss(param) by. This will be
          essentially added to the minimum damping, but also included in the
          qmodel change computations (used for adjusting the damping) even when
          _INCLUDE_DAMPING_IN_QMODEL_CHANGE is False. Note that the user is
          still responsible for adding regularization to the loss.
          (Default: 0.0)
      damping_adaptation_decay: `float`, The `damping` parameter is
          multiplied by the `damping_adaptation_decay` every
          `damping_adaptation_interval` number of iterations. (Default: 0.95)
      damping_adaptation_interval: `int`, Number of steps in between
          updating the `damping` parameter. Note that damping is adapted at
          the step where (step+1) % damping_adaptation_interval == 0,
          (or immediately at the start of the next step by
          maybe_pre_update_adapt_damping() if update_damping_immediately is
          False). (Default: 5)
      use_passed_loss: `Boolean`. If True we use the loss tensor passed in by
          the user (via minimize() or compute_gradients() or the set_loss()
          method) in the damping adaptation scheme, instead of calling
          loss_fn() a second time for this. This is more efficient but may not
          always be desired. If False, `train_batch` must be provided.
          (Default: True)
      train_batch: Training mini-batch used in the current step. This
          will be used to evaluate loss by calling `loss_fn(train_batch)`
          when damping adaptation is used and `use_passed_loss` is False.
          (Default: None)
      print_logs: `Boolean`. If True, we print some logging info using
          tf.print after each iteration. This is done in the method
          _maybe_print_logging_info, which we encourage you to modify in order
          to add whatever you want. (Default: False)
      tf_replicator: A Replicator object or None. If not None, K-FAC will set
          itself up to work inside of the provided TF-Replicator object.
          (Default: None)
      dtype: TF dtype or string representing one. dtype used for scalar
          properties (rho, etc). (Default: "float32")
      **kwargs: Arguments to be passed to specific placement strategy mixin.
          Check `placement.RoundRobinPlacementMixin` for example.

    Raises:
      ValueError: If the momentum type is unsupported.
      ValueError: If norm_constraint is given with a momentum type other than
          'regular' or 'adam'.
      ValueError: If momentum is not None and momentum_type is 'qmodel'.
      ValueError: If use_passed_loss is False and train_batch is None.
      ValueError: If learning_rate is not None while using one of the 'qmodel'
          momentum types, or is None otherwise.
      ValueError: If qmodel_update_rescale is not None without using one of
          the 'qmodel' momentum types.
      ValueError: If no losses have been registered with layer_collection.
          (Not checked directly in this method; presumably raised while the
          Fisher estimator is constructed.)
    """
    dtype = tf.dtypes.as_dtype(dtype)

    self._layers = layer_collection

    self._colocate_gradients_with_ops = colocate_gradients_with_ops

    # Momentum types are matched case-insensitively.
    momentum_type = momentum_type.lower()
    legal_momentum_types = ["regular", "adam", "qmodel", "qmodel_fixedmu"]

    if momentum_type not in legal_momentum_types:
      raise ValueError("Unsupported momentum type {}. Must be one of {}."
                       .format(momentum_type, legal_momentum_types))
    if momentum_type not in ["regular", "adam"] and norm_constraint is not None:
      raise ValueError("Update clipping is only supported with momentum "
                       "type 'regular' and 'adam'.")
    # Only the plain 'qmodel' type rejects a momentum value; 'qmodel_fixedmu'
    # is deliberately not covered by this check.
    if momentum_type == "qmodel" and momentum is not None:
      raise ValueError("Momentum must be None if using a momentum_type "
                       "'qmodel'.")
    self._momentum_type = momentum_type
    self._momentum = momentum

    self._use_weight_decay = use_weight_decay
    self._weight_decay_coeff = weight_decay_coeff

    self._norm_constraint = norm_constraint
    self._batch_size = batch_size
    self._placement_strategy = placement_strategy

    # Damping adaptation parameters
    self._adapt_damping = adapt_damping

    if self._adapt_damping:
      # When adapting, the damping lives in a (non-trainable) variable that is
      # initialized from the passed-in value; otherwise the value is used
      # as-is.
      with tf.variable_scope(name):
        self._damping = tf.get_variable(
            "damping", initializer=damping, trainable=False,
            use_resource=True)
    else:
      self._damping = damping

    self._update_damping_immediately = update_damping_immediately
    self._is_chief = is_chief
    self._prev_train_batch = prev_train_batch
    self._loss_tensor = loss
    self._loss_fn = loss_fn
    self._damping_adaptation_decay = damping_adaptation_decay
    self._damping_adaptation_interval = damping_adaptation_interval
    # The decay compounded over one adaptation interval.
    self._omega = (
        self._damping_adaptation_decay**self._damping_adaptation_interval)
    self._min_damping = min_damping
    self._use_passed_loss = use_passed_loss
    if not use_passed_loss and train_batch is None:
      raise ValueError("Must pass in train_batch if use_passed_loss is false.")

    self._train_batch = train_batch

    self._print_logs = print_logs

    self._l2_reg = l2_reg

    if self._momentum_type.startswith("qmodel"):
      if learning_rate is not None:
        raise ValueError("'learning_rate' must be set to None if using one of "
                         "the 'qmodel' momentum types.")
      # With the 'qmodel' types, the learning rate handed to the base
      # optimizer is the rescale factor (1.0 when none was given).
      if qmodel_update_rescale is not None:
        learning_rate = qmodel_update_rescale
      else:
        learning_rate = 1.0
    else:
      if learning_rate is None:
        raise ValueError("'learning_rate' must *not* be set to None unless "
                         "using one of the 'qmodel' momentum types.")
      if qmodel_update_rescale is not None:
        raise ValueError("'qmodel_update_rescale' must be None unless using "
                         "one of the 'qmodel' momentum types.")
    self._qmodel_update_rescale = qmodel_update_rescale

    with tf.variable_scope(name):
      # Scalar bookkeeping variables start out as NaN until they are first
      # written.
      nan_init = lambda: tf.constant(float("nan"), dtype=dtype)
      # We store rho only for possible logging purposes.
      self._rho = tf.get_variable(
          "rho", initializer=nan_init, dtype=dtype,
          trainable=False, use_resource=True)
      self._prev_loss = tf.get_variable(
          "prev_loss", initializer=nan_init, dtype=dtype,
          trainable=False, use_resource=True)
      self._qmodel_learning_rate = tf.get_variable(
          "qmodel_learning_rate", initializer=nan_init, dtype=dtype,
          trainable=False, use_resource=True)
      self._qmodel_momentum = tf.get_variable(
          "qmodel_momentum", initializer=nan_init, dtype=dtype,
          trainable=False, use_resource=True)
      self._qmodel_change = tf.get_variable(
          "qmodel_change", initializer=nan_init, dtype=dtype,
          trainable=False, use_resource=True)

      self._counter = tf.get_variable(
          "counter", dtype=tf.int64, shape=(), trainable=False,
          initializer=tf.zeros_initializer, use_resource=True,
          aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

      # NOTE: because of the `or`, an empty var_list also falls back to all
      # trainable variables.
      variables = var_list or tf.trainable_variables()

      if tf_replicator is not None or tf.distribute.has_strategy():
        def _get_sanitized_name(var_name):
          return re.sub(r"replica_\d+_", "", var_name)

        # This tells K-FAC's libraries that we are using TF-Replicator with this
        # particular Replicator object.
        utils.set_global_constants(tf_replicator=tf_replicator)

        # We need to sanitize the names of the variables that K-FAC creates
        # so they are the same between replicas.
        ff.set_global_constants(get_sanitized_name_fn=_get_sanitized_name)

      self._fisher_est = est.make_fisher_estimator(
          placement_strategy=placement_strategy,
          variables=variables,
          cov_ema_decay=cov_ema_decay,
          damping=self._damping,
          layer_collection=self.layers,
          exps=(-1,),
          estimation_mode=estimation_mode,
          colocate_gradients_with_ops=self._colocate_gradients_with_ops,
          compute_params_stats=compute_params_stats,
          batch_size=batch_size,
          **kwargs)

    super(KfacOptimizer, self).__init__(learning_rate, name=name)