in kfac/python/keras/optimizers.py [0:0]
def __init__(self, # pylint: disable=invalid-name
_sentinel=None,
learning_rate=None,
damping=None,
model=None,
loss=None,
loss_weights=None,
fisher_approx=None,
layer_collection=None,
adaptive=False,
train_batch=None,
name=None,
seed=None,
**kfac_kwargs):
"""Construct a new KFAC optimizer.
If you construct this Optimizer without (a) a model that already has a
loss, (b) a model and a loss, or (c) a layer_collection, you must call
register_layers before using the optimizer.
If you use adaptive, adapt_damping, or qmodel_momentum, this class will set
up the required loss functions and tensors. You must pass the train_batch
tensors as a tuple (x, y). If the batch_size cannot be inferred from the
train_batch[0] tensor, you must pass batch_size to the constructor. You
may not use numpy arrays as input when using the adaptive mode. If you do
not use minimize, you must also provide the loss_tensor.
When using Distribution Strategy, K-FAC expects a loss tensor that is
normalized only by the per-replica batch size, not the total batch size,
unlike what is commonly recommended. This means you cannot use K-FAC with
a Distribution Strategy and model.fit at the same time, since model.fit
does this scaling for you. Instead, use a custom training loop with
Distribution Strategy (there are examples in the GitHub repo).
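Example usage (an illustrative sketch; the model, data, and hyperparameter
values below are placeholders, not prescriptions):

optimizer = Kfac(learning_rate=0.001, damping=0.01, model=model,
                 loss='sparse_categorical_crossentropy')
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy')
model.fit(x_train, y_train)  # fine here: no Distribution Strategy in use

In adaptive mode, leave learning_rate unset, pass damping as the initial
value, and provide a train_batch of tensors:

optimizer = Kfac(damping=10.0, model=model, loss='mse', adaptive=True,
                 train_batch=(x_tensor, y_tensor))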
Args:
_sentinel: Used to prevent positional parameters. Internal, do not use.
learning_rate: float or 0D Tensor. Required if not using adapt_damping.
Refer to kfac.KfacOptimizer for a detailed description.
damping: Required. float or 0D Tensor. Refer to kfac.KfacOptimizer for a
detailed description.
model: Keras model which this class will optimize. Currently, Dense,
Conv1D/Conv2D, and Embedding layers are supported as trainable layers.
loss: Keras (normal or serialized) loss function. May be a list or a
dictionary mapping layer names to (normal or serialized) loss functions.
Currently, sparse/normal categorical/binary cross entropy and MSE are
supported.
loss_weights: An optional list of coefficients or a dictionary mapping
layer names to the coefficient for each loss function. If it is a list,
there must be the same number of coefficients as loss functions. If it
is a dictionary and a coefficient is not given for a loss function, a
coefficient of 1.0 will be used.
fisher_approx: An optional list of approximations or a dictionary mapping
layer names/classes to Fisher approximation types. If it is a list, there
must be the same number of approximations as there are layers with
trainable parameters. For each layer, the approximation is determined as
follows: if fisher_approx is a dictionary, the layer's name is looked up
first; if it is not found, the layer's class is checked next; if that is
not found either, the default is used. When fisher_approx is a list, the
order of the approximations must match the order of the layers with
trainable parameters given by model.layers. None is a valid dict/list
entry and indicates that the default approximation should be used for
that layer. See the example at the end of this docstring.
layer_collection: Only use this argument when you have an unsupported
model architecture and must manually register the layers. Refer to
kfac.KfacOptimizer for a detailed description.
adaptive: Whether this optimizer is in adaptive mode or not. In adaptive
mode, we set momentum_type='qmodel' and adapt_damping=True, so you must
provide the damping (used as the initial value). learning_rate and
momentum must be None. You must provide a train_batch and potentially
a batch_size if we cannot infer the batch_size from the train_batch.
train_batch: A tuple (input, label). The input must be a tensor or a list
of tensors that you can call the model on. The label must be a tensor
or list of tensors compatible with the loss_fn. See utils.get_loss_fn
for the standard loss_fn we create, or you can provide a custom loss_fn.
name: Optional name for operations created when applying gradients.
Defaults to "kfac".
seed: Optional integer specifying the TensorFlow random seed. To get
deterministic behavior, the seed must be set, because the targets are
sampled to approximate the Fisher.
**kfac_kwargs: Additional arguments to be passed to
kfac.PeriodicInvCovUpdateKfacOpt (and then to kfac.KfacOptimizer). Note
the "loss" argument for kfac.KfacOptimizer should be passed as
"loss_tensor".
Raises:
ValueError: If clipvalue or clipnorm arguments are used.
ValueError: If positional arguments are used (or _sentinel is used).
ValueError: If damping is not provided.
ValueError: If learning_rate or momentum are set with adaptive=True.
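Example of the dictionary forms of loss, loss_weights, and fisher_approx
(an illustrative sketch for a hypothetical two-output model; the layer
names and the 'kron'/'diagonal' approximation strings are placeholders,
refer to kfac's layer collection documentation for the supported names):

optimizer = Kfac(
    learning_rate=0.001, damping=0.01, model=model,
    loss={'logits': 'sparse_categorical_crossentropy', 'recon': 'mse'},
    loss_weights={'logits': 1.0},  # 'recon' falls back to 1.0
    fisher_approx={'conv_in': 'diagonal', 'dense_out': 'kron'})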
"""
if tf.executing_eagerly():
logging.warning('Eager mode appears to be enabled. K-FAC is untested in '
'eager mode.')
if _sentinel:
raise ValueError('Do not pass positional arguments, only use keyword '
'arguments.')
if damping is None:
raise ValueError('Please provide a value for damping.')
if 'clipvalue' in kfac_kwargs:
raise ValueError('Argument "clipvalue" is not support.')
if 'clipnorm' in kfac_kwargs:
raise ValueError('Argument "clipnorm" is not supported. Use '
'"norm_constraint" instead.')
super(Kfac, self).__init__(name=name)
kfac_kwargs.update({'name': self._name,
'learning_rate': learning_rate,
'damping': damping})
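# In adaptive mode this switches the underlying optimizer to
# momentum_type='qmodel' with adapt_damping=True, and rejects a user-set
# learning_rate or momentum (see the docstring above).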
_configure_kfac_kwargs_for_adaptive(kfac_kwargs, adaptive)
self._optimizer = None
self._layer_collection = None
self._model = model
self._loss = loss
self._have_tracked_vars = False
self._tf_var_scope = self._name + '/tf_vars'
# We use _kfac_kwargs and _config in various parts of the code below.
# _kfac_kwargs is checked when we want to know only what the user passed.
# _config is used when we want user selections with the default kwargs as a
# fallback.
self._kfac_kwargs = kfac_kwargs
self._layer_collection_kwargs = {
'loss_weights': loss_weights,
'fisher_approx': utils.serialize_fisher_approx(fisher_approx),
'seed': seed,
}
self._config = _DEFAULT_KWARGS.copy()
self._config.update(kfac_kwargs)
self._config.update(self._layer_collection_kwargs)
self._config['loss'] = utils.serialize_loss(loss)
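# kfac.KfacOptimizer takes the precomputed loss tensor under the keyword
# 'loss'. Because 'loss' here is reserved for the Keras loss function,
# callers pass the tensor as 'loss_tensor' and we rename it before
# forwarding.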
if 'loss_tensor' in self._kfac_kwargs:
self._kfac_kwargs['loss'] = self._kfac_kwargs.pop('loss_tensor')
self._mutable_hypers = _MUTABLE_HYPER_PARAMS.copy()
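# Hyperparameters that the underlying optimizer adapts on its own must
# not be exposed as mutable Keras hypers: adaptive damping owns
# 'damping', and qmodel momentum owns 'learning_rate' and 'momentum'.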
if self._config['adapt_damping']:
self._mutable_hypers.remove('damping')
if self._config['momentum_type'].lower().startswith('qmodel'):
self._mutable_hypers -= {'learning_rate', 'momentum'}
for hp in self._mutable_hypers.copy():
if self._config[hp] is None:
self._mutable_hypers.remove(hp)
else:
self._set_hyper(hp, self._config[hp])
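# Register a user-supplied LayerCollection (and, when damping adaptation
# is enabled, the training batch) up front so no separate register_*
# call is needed.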
if layer_collection:
self.register_layers(layer_collection=layer_collection)
if train_batch and self._kfac_kwargs.get('adapt_damping', False):
self.register_train_batch(train_batch=train_batch)