in kfac/python/ops/optimizer.py [0:0]
def __init__(self,
             learning_rate,
             damping,
             layer_collection,
             cov_ema_decay=0.95,
             var_list=None,
             momentum=0.9,
             momentum_type="adam",
             use_weight_decay=False,
             weight_decay_coeff=0.1,
             qmodel_update_rescale=None,
             norm_constraint=None,
             name="KFAC",
             estimation_mode="gradients",
             colocate_gradients_with_ops=True,
             batch_size=None,
             placement_strategy=None,
             compute_params_stats=False,
             adapt_damping=False,
             update_damping_immediately=True,
             is_chief=True,
             prev_train_batch=None,
             loss=None,
             loss_fn=None,
             min_damping=1e-8,  # this value is somewhat arbitrary
             l2_reg=0.0,
             damping_adaptation_decay=0.95,
             damping_adaptation_interval=5,
             use_passed_loss=True,
             train_batch=None,
             print_logs=False,
             tf_replicator=None,
             dtype="float32",
             **kwargs):
  """Initializes the K-FAC optimizer with the given settings.

  NOTE: this is a base class for K-FAC optimizers that offers full control
  over the execution of K-FAC's various ops. For a more fool-proof /
  automated version see for example PeriodicInvCovUpdateKfacOpt.

  Also, please keep in mind that while the K-FAC code loosely conforms to
  TensorFlow's Optimizer API it can't be used naively as a "drop in
  replacement" for basic classes like MomentumOptimizer. Using it
  properly with SyncReplicasOptimizer, for example, requires special care.

  When using it with Distribution Strategy, unlike common practice, K-FAC
  expects a loss tensor that is normalized by the per-replica batch size,
  and *not* by the total batch size (like you may see in TF Distribution
  Strategy tutorials). Regardless of whether you are using estimator,
  strategy, or a normal custom training loop, you should pass in the same
  loss.

  See the various examples in the "examples" directory for a guide about
  how to use K-FAC in various contexts and various systems, like
  TF-Estimator. See in particular the "convnet" example. google/examples
  also contains an example using TPUEstimator.

  Args:
    learning_rate: float or 0D Tensor. The base learning rate for the
      optimizer. Must be set to None if using one of the 'qmodel'
      momentum_type values, and must *not* be None otherwise.
    damping: float or 0D Tensor. This quantity times the identity matrix is
      (approximately) added to the curvature matrix (i.e. the Fisher or GGN)
      before it is inverted and multiplied by the gradient when computing
      the (raw) update. This quantity should match the scale of the
      objective, so that if you put a multiplier on your loss you should
      apply the same multiplier to the damping. Roughly speaking, larger
      values constrain the update vector to a smaller region around zero,
      which we want to do when our local quadratic model is a less
      trustworthy local approximation of the true objective. The damping
      value is closely related to the trust region radius and to the
      classical Tikhonov regularization method. If the `adapt_damping`
      argument is True then this value is used only as an initial value for
      the adaptation method.
    layer_collection: The layer collection object, which holds the Fisher
      blocks, Kronecker factors, and losses associated with the
      graph. The layer_collection cannot be modified after KfacOptimizer's
      initialization.
    cov_ema_decay: The decay factor used when calculating the
      covariance estimate moving averages. (Default: 0.95)
    var_list: Optional list or tuple of variables to train. Defaults to
      tf.trainable_variables.
    momentum: The momentum decay constant to use. Only applies when
      momentum_type is 'regular' or 'adam'. Must be explicitly set to None
      when momentum_type is 'qmodel'. (Default: 0.9)
    momentum_type: The type of momentum to use in this optimizer, one of
      'regular', 'adam', 'qmodel', or 'qmodel_fixedmu' (case-insensitive).
      'regular' gives standard momentum. 'adam' gives a style of momentum
      reminiscent of the Adam method, which seems to work better in
      practice. 'qmodel' makes the optimizer perform automatic control of
      both the learning rate and momentum using a quadratic model based
      method (see _compute_qmodel_hyperparams for more details).
      'qmodel_fixedmu' is similar to 'qmodel' but only controls the learning
      rate. (Default: 'adam')
    use_weight_decay: If True, explicit "weight decay" is performed by K-FAC.
      Note that this is distinct from L2 regularization, and corresponds to
      optimizing a regularized version of the loss passed to minimize(),
      where the regularization term added is related to the "Fisher-Rao
      norm". See https://openreview.net/pdf?id=B1lz-3Rct7 for more details.
      Note that using this feature won't change the loss function you pass
      to minimize(), and thus the loss you report will not correspond
      precisely to what K-FAC is optimizing. (Default: False)
    weight_decay_coeff: The coefficient to use for weight decay (see above).
      (Default: 0.1)
    qmodel_update_rescale: float or None. An additional multiplier to apply
      to the update computed by the quadratic model based adjustment
      methods. If None it will behave like a value of 1.0. Must be None
      unless using one of the 'qmodel' momentum types. (Default: None)
    norm_constraint: float or Tensor. If specified, the update is scaled down
      so that its approximate squared Fisher norm v^T F v is at most the
      specified value. May only be used with momentum type 'regular' or
      'adam'. See the docstring for the method _clip_updates() for a more
      detailed explanation of this feature. (Default: None)
    name: The name for this optimizer. (Default: 'KFAC')
    estimation_mode: The type of estimator to use for the Fishers/GGNs. Can be
      'gradients', 'empirical', 'curvature_prop', 'curvature_prop_GGN',
      'exact', or 'exact_GGN'. See the doc-string for FisherEstimator
      (in estimator.py) for a more detailed description of these
      options. (Default: 'gradients').
    colocate_gradients_with_ops: Whether we should request gradients we
      compute in the estimator be colocated with their respective ops.
      (Default: True)
    batch_size: The size of the mini-batch. Only needed when `momentum_type`
      == 'qmodel' or when `compute_params_stats` is True. Note that when
      using data parallelism where the model graph and optimizer are
      replicated across multiple devices, this should be the per-replica
      batch size. An example of this is sharded data on the TPU, where
      batch_size should be set to the total batch size divided by the
      number of shards. (Default: None)
    placement_strategy: string or None. Device placement strategy used when
      creating variables, and various ops. Can be None, 'round_robin', or
      'replica_round_robin'. 'round_robin' supports round-robin placement of
      various ops on lists of provided devices. 'replica_round_robin' does
      something similar but over shards/replicas instead, and only works
      in certain 'replicated' contexts (e.g. TPUEstimator). The details of
      the different placement strategies are controlled by additional
      keyword arguments that can be passed to this class, and which are
      described in the different placement mixin classes in placement.py.
      (Default: None)
    compute_params_stats: Bool. If True, we compute the first order version
      of the statistics computed to estimate the Fisher/GGN. These
      correspond to the `variables` method in a one-to-one fashion. They
      are available via the `params_stats` property. When estimation_mode
      is 'empirical', this will correspond to the standard parameter
      gradient on the loss. (Default: False)
    adapt_damping: `Boolean`. If True we adapt the damping according to the
      Levenberg-Marquardt rule described in Section 6.5 of the original
      K-FAC paper. The details of this scheme are controlled by various
      additional arguments below. Also some of these arguments are extra
      pieces of information, such as the loss, needed by the method. Note
      that unless using a convenience subclass like
      PeriodicInvCovUpdateKfacOpt the damping adaptation op must be
      executed by the user (like the cov and inv ops). This op is returned
      by the maybe_pre_update_adapt_damping() method. (Default: False)
    update_damping_immediately: Damping adjustment strategy. If True then the
      damping is updated in the same optimizer minimize call as
      `(step+1) % damping_adaptation_interval == 0`, immediately after the
      parameter update is performed. If False then the damping is updated
      in the next step. If True then it is assumed that the apply_gradients
      op will safely update the model before returning; it is recommended
      to only use resource variables in this case. (Default: True)
    is_chief: `Boolean`, `True` if the worker is chief. (Default: True)
    prev_train_batch: Training mini-batch used in the previous step. This
      will be used to evaluate loss by calling `loss_fn(prev_train_batch)`
      when damping adaptation is used. (Default: None)
    loss: `Tensor` the model loss, used as the pre-update loss in adaptive
      damping. Also used for the built-in log printing. When using
      Distribution Strategy, unlike common Distribution Strategy practice,
      this loss tensor should be normalized by the per-replica batch size
      and NOT the total batch size. (Default: None)
    loss_fn: `function` that takes as input training data tensor and returns
      a scalar loss. Only needed if using damping adaptation. When using
      Distribution Strategy, unlike common Distribution Strategy practice,
      the loss should be normalized by the per-replica batch size and NOT
      the total batch size. (Default: None)
    min_damping: `float`, Minimum value the damping parameter can take. Note
      that the default value of 1e-8 is quite arbitrary, and you may have
      to adjust this up or down for your particular problem. If you are
      using a non-zero value of l2_reg you *may* be able to set this to
      zero. (Default: 1e-8)
    l2_reg: `float` or 0D Tensor. Set this value to tell the optimizer what L2
      regularization coefficient you are using (if any). Note the
      coefficient appears in the regularizer as coeff / 2 * sum(param**2),
      as the thing you multiply tf.nn.l2(param) by. This will be essentially
      added to the minimum damping, but also included in the qmodel change
      computations (used for adjusting the damping) even when
      _INCLUDE_DAMPING_IN_QMODEL_CHANGE is False. Note that the user is
      still responsible for adding regularization to the loss.
      (Default: 0.0)
    damping_adaptation_decay: `float`, The `damping` parameter is
      multiplied by the `damping_adaptation_decay` every
      `damping_adaptation_interval` number of iterations. (Default: 0.95)
    damping_adaptation_interval: `int`, Number of steps in between
      updating the `damping` parameter. Note that damping is adapted at
      the step where (step+1) % damping_adaptation_interval == 0,
      (or immediately at the start of the next step by
      maybe_pre_update_adapt_damping() if update_damping_immediately is
      False). (Default: 5)
    use_passed_loss: `Boolean`. If True we use the loss tensor passed in by
      the user (via minimize() or compute_gradients() or the set_loss()
      method) in damping adaptation scheme, instead of calling loss_fn()
      a second time for this. This is more efficient but may not always be
      desired. If False, `train_batch` must be provided. (Default: True)
    train_batch: Training mini-batch used in the current step. This
      will be used to evaluate loss by calling `loss_fn(train_batch)`
      when damping adaptation is used and `use_passed_loss` is False.
      (Default: None)
    print_logs: `Boolean`. If True, we print some logging info using
      tf.print after each iteration. This is done in the method
      _maybe_print_logging_info, which we encourage you to modify in order
      to add whatever you want. (Default: False)
    tf_replicator: A Replicator object or None. If not None, K-FAC will set
      itself up to work inside of the provided TF-Replicator object.
      (Default: None)
    dtype: TF dtype or string representing one. dtype used for scalar
      properties (rho, etc). (Default: "float32")
    **kwargs: Arguments to be passed to specific placement strategy mixin.
      Check `placement.RoundRobinPlacementMixin` for example.

  Raises:
    ValueError: If the momentum type is unsupported.
    ValueError: If norm_constraint is given with a momentum type other than
      'regular' or 'adam'.
    ValueError: If momentum is not None and momentum_type is 'qmodel'.
    ValueError: If use_passed_loss is False and train_batch is None.
    ValueError: If learning_rate is not None while using one of the 'qmodel'
      momentum types, or is None while using any other momentum type.
    ValueError: If qmodel_update_rescale is not None without using one of
      the 'qmodel' momentum types.
    ValueError: If no losses have been registered with layer_collection.
      NOTE(review): not checked directly in this method; presumably raised
      while the Fisher estimator is built -- confirm.
  """
  dtype = tf.dtypes.as_dtype(dtype)
  self._layers = layer_collection
  self._colocate_gradients_with_ops = colocate_gradients_with_ops

  # Normalize the case so that e.g. "Adam" and "adam" are treated alike.
  momentum_type = momentum_type.lower()
  legal_momentum_types = ["regular", "adam", "qmodel", "qmodel_fixedmu"]
  if momentum_type not in legal_momentum_types:
    raise ValueError("Unsupported momentum type {}. Must be one of {}."
                     .format(momentum_type, legal_momentum_types))
  if momentum_type not in ["regular", "adam"] and norm_constraint is not None:
    raise ValueError("Update clipping is only supported with momentum "
                     "type 'regular' and 'adam'.")
  if momentum_type == "qmodel" and momentum is not None:
    raise ValueError("Momentum must be None if using a momentum_type "
                     "'qmodel'.")
  self._momentum_type = momentum_type
  self._momentum = momentum
  self._use_weight_decay = use_weight_decay
  self._weight_decay_coeff = weight_decay_coeff
  self._norm_constraint = norm_constraint
  self._batch_size = batch_size
  self._placement_strategy = placement_strategy

  # Damping adaptation parameters
  self._adapt_damping = adapt_damping
  if self._adapt_damping:
    # When adapting, the damping lives in a non-trainable variable that is
    # seeded with the passed-in value.
    with tf.variable_scope(name):
      self._damping = tf.get_variable(
          "damping", initializer=damping, trainable=False,
          use_resource=True)
  else:
    # Otherwise the passed-in value is stored and used as is.
    self._damping = damping
  self._update_damping_immediately = update_damping_immediately
  self._is_chief = is_chief
  self._prev_train_batch = prev_train_batch
  self._loss_tensor = loss
  self._loss_fn = loss_fn
  self._damping_adaptation_decay = damping_adaptation_decay
  self._damping_adaptation_interval = damping_adaptation_interval
  # The per-step decay compounded over one full adaptation interval.
  self._omega = (
      self._damping_adaptation_decay**self._damping_adaptation_interval)
  self._min_damping = min_damping
  self._use_passed_loss = use_passed_loss
  if not use_passed_loss and train_batch is None:
    raise ValueError("Must pass in train_batch if use_passed_loss is false.")
  self._train_batch = train_batch
  self._print_logs = print_logs
  self._l2_reg = l2_reg

  if self._momentum_type.startswith("qmodel"):
    if learning_rate is not None:
      raise ValueError("'learning_rate' must be set to None if using one of "
                       "the 'qmodel' momentum types.")
    # For the 'qmodel' types the value handed to the base class as the
    # learning rate is just the optional rescaling factor (1.0 if unset).
    if qmodel_update_rescale is not None:
      learning_rate = qmodel_update_rescale
    else:
      learning_rate = 1.0
  else:
    if learning_rate is None:
      raise ValueError("'learning_rate' must *not* be set to None unless "
                       "using one of the 'qmodel' momentum types.")
    if qmodel_update_rescale is not None:
      raise ValueError("'qmodel_update_rescale' must be None unless using "
                       "one of the 'qmodel' momentum types.")
  self._qmodel_update_rescale = qmodel_update_rescale

  # Non-trainable scalar bookkeeping variables. All but the counter start out
  # as NaN.
  with tf.variable_scope(name):
    nan_init = lambda: tf.constant(float("nan"), dtype=dtype)
    # We store rho only for possible logging purposes.
    self._rho = tf.get_variable(
        "rho", initializer=nan_init, dtype=dtype,
        trainable=False, use_resource=True)
    self._prev_loss = tf.get_variable(
        "prev_loss", initializer=nan_init, dtype=dtype,
        trainable=False, use_resource=True)
    self._qmodel_learning_rate = tf.get_variable(
        "qmodel_learning_rate", initializer=nan_init, dtype=dtype,
        trainable=False, use_resource=True)
    self._qmodel_momentum = tf.get_variable(
        "qmodel_momentum", initializer=nan_init, dtype=dtype,
        trainable=False, use_resource=True)
    self._qmodel_change = tf.get_variable(
        "qmodel_change", initializer=nan_init, dtype=dtype,
        trainable=False, use_resource=True)
    # Scalar int64 counter, initialized to zero.
    self._counter = tf.get_variable(
        "counter", dtype=tf.int64, shape=(), trainable=False,
        initializer=tf.zeros_initializer, use_resource=True,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

  # A None (or empty) var_list falls back to all trainable variables.
  variables = var_list or tf.trainable_variables()

  if tf_replicator is not None or tf.distribute.has_strategy():
    def _get_sanitized_name(var_name):
      # Strips every "replica_<N>_" fragment from the variable name.
      return re.sub(r"replica_\d+_", "", var_name)
    # This tells K-FAC's libraries that we are using TF-Replicator with this
    # particular Replicator object.
    utils.set_global_constants(tf_replicator=tf_replicator)
    # We need to sanitize the names of the variables that K-FAC creates
    # so they are the same between replicas.
    ff.set_global_constants(get_sanitized_name_fn=_get_sanitized_name)

  self._fisher_est = est.make_fisher_estimator(
      placement_strategy=placement_strategy,
      variables=variables,
      cov_ema_decay=cov_ema_decay,
      damping=self._damping,
      layer_collection=self.layers,
      exps=(-1,),
      estimation_mode=estimation_mode,
      colocate_gradients_with_ops=self._colocate_gradients_with_ops,
      compute_params_stats=compute_params_stats,
      batch_size=batch_size,
      **kwargs)

  super(KfacOptimizer, self).__init__(learning_rate, name=name)