def __init__()

in kfac/python/keras/optimizers.py


  def __init__(self,  # pylint: disable=invalid-name
               _sentinel=None,
               learning_rate=None,
               damping=None,
               model=None,
               loss=None,
               loss_weights=None,
               fisher_approx=None,
               layer_collection=None,
               adaptive=False,
               train_batch=None,
               name=None,
               seed=None,
               **kfac_kwargs):
    """Construct a new KFAC optimizer.

    If you construct this optimizer without a model that has a loss, a model
    and a loss, or a layer_collection, you must call register_layers before
    using the optimizer.
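
    For example, a minimal construction might look like this (a sketch with
    hypothetical model and hyperparameter values):

      model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
      optimizer = Kfac(learning_rate=0.01, damping=0.001, model=model,
                       loss='sparse_categorical_crossentropy')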

    If you use adaptive, adapt_damping, or qmodel_momentum, this class will set
    up the required loss functions and tensors. You must pass the train_batch
    tensors as a tuple (x, y). If the batch_size cannot be inferred from the
    train_batch[0] tensor, you must pass the batch_size to the constructor. You
    may not use numpy arrays as input when using the adaptive mode. If you do
    not use minimize, you must also provide the loss_tensor.
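
    For example, a hypothetical adaptive setup (the dataset, model, and loss
    are illustrative):

      # train_batch must be tensors, not numpy arrays, in adaptive mode.
      x, y = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
      optimizer = Kfac(damping=0.001, adaptive=True, model=model, loss=loss,
                       train_batch=(x, y))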

    When using Distribution Strategy, K-FAC expects a loss tensor that is
    normalized only by the per-replica batch size, and not the total batch size,
    unlike what is commonly recommended. This means you cannot use K-FAC with
    a Distribution Strategy and model.fit at the same time, since model.fit
    does this scaling for you. Instead, use a custom training loop with
    Distribution Strategy (there are examples in the GitHub repo).
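
    For example, the per-replica step of a custom training loop might compute
    the loss as follows (a sketch; the names are illustrative):

      # reduce_mean normalizes by the per-replica batch size, which is what
      # K-FAC expects here, unlike tf.nn.compute_average_loss, which divides
      # by the global batch size.
      loss = tf.reduce_mean(loss_fn(labels, model(inputs)))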

    Args:
      _sentinel: Used to prevent positional parameters. Internal, do not use.
      learning_rate: float or 0D Tensor. Required if not using adapt_damping.
        Refer to kfac.KfacOptimizer for a detailed description.
      damping: Required. float or 0D Tensor. Refer to kfac.KfacOptimizer for a
        detailed description.
      model: Keras model which this class will optimize. Currently, Dense,
        Conv1D/Conv2D, and Embedding are supported as trainable layers.
      loss: Keras (normal or serialized) loss function. Could be a list or a
        dictionary mapping layer names to (normal or serialized) loss functions.
        Currently, sparse/normal categorical/binary cross entropy and MSE are
        supported.
      loss_weights: An optional list of coefficients or a dictionary mapping
        layer names to the coefficient for each loss function. If it is a list,
        there must be the same number of coefficients as loss functions. If
        it is a dictionary and a coefficient is not given for a loss function,
        a coefficient of 1.0 will be used.
      fisher_approx: An optional list of approximations or a dictionary mapping
        layer name/class to fisher approximation type. If it is a list, there
        must be the same number of approximations as there are layers with
        trainable parameters. For each layer, the approximation is determined
        as follows. If fisher_approx is a dictionary, first we check if the
        layer name is in the dict; if it isn't found, the layer class is
        checked; if that isn't found either, the default is used. When
        fisher_approx is a list, the order of the approximations must match the
        order of the layers with trainable parameters given by model.layers.
        None is a valid dict/list entry and indicates to use the default
        approximation for that layer. See the example near the end of this
        docstring.
      layer_collection: Only use this argument when you have an unsupported
        model architecture and so must manually register the layers. Refer to
        kfac.KfacOptimizer for a detailed description.
      adaptive: Whether this optimizer is in adaptive mode or not. In adaptive
        mode, we set momentum_type='qmodel' and adapt_damping=True, so you must
        provide the damping (used as the initial value). learning_rate and
        momentum must be None. You must provide a train_batch and potentially
        a batch_size if we cannot infer the batch_size from the train_batch.
      train_batch: A tuple (input, label). The input must be a tensor or a list
        of tensors that you can call the model on. The label must be a tensor
        or list of tensors compatible with the loss_fn. See utils.get_loss_fn
        for the standard loss_fn we create, or you can provide a custom loss_fn.
      name: Optional name for operations created when applying gradients.
        Defaults to "kfac".
      seed: Optional integer specifying the TensorFlow random seed. To get
        deterministic behaviour, the seed needs to be set because the targets
        are sampled to approximate the Fisher.
      **kfac_kwargs: Additional arguments to be passed to
        kfac.PeriodicInvCovUpdateKfacOpt (and then to kfac.KfacOptimizer). Note
        the "loss" argument for kfac.KfacOptimizer should be passed as
        "loss_tensor".

    Raises:
      ValueError: If clipvalue or clipnorm arguments are used.
      ValueError: If positional arguments are used (or _sentinel is used).
      ValueError: If damping is not provided.
      ValueError: If learning_rate or momentum are set with adaptive=True.
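
    Example of a fisher_approx mapping (a hypothetical sketch; see
    kfac.LayerCollection for the approximation names your version supports):

      fisher_approx = {
          'conv2d': 'diagonal',      # keyed by layer class
          'my_dense_layer': 'kron',  # keyed by layer name
          'embedding': None,         # None means use the default approximation
      }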
    """
    if tf.executing_eagerly():
      logging.warning('Eager mode appears to be enabled. K-FAC is untested '
                      'in eager mode.')
    if _sentinel:
      raise ValueError('Do not pass positional arguments, only use keyword '
                       'arguments.')
    if damping is None:
      raise ValueError('Please provide a value for damping.')

    if 'clipvalue' in kfac_kwargs:
      raise ValueError('Argument "clipvalue" is not supported.')
    if 'clipnorm' in kfac_kwargs:
      raise ValueError('Argument "clipnorm" is not supported. Use '
                       '"norm_constraint" instead.')

    super(Kfac, self).__init__(name=name)

    kfac_kwargs.update({'name': self._name,
                        'learning_rate': learning_rate,
                        'damping': damping})

    _configure_kfac_kwargs_for_adaptive(kfac_kwargs, adaptive)

    self._optimizer = None
    self._layer_collection = None
    self._model = model
    self._loss = loss
    self._have_tracked_vars = False
    self._tf_var_scope = self._name + '/tf_vars'
    # We use _kfac_kwargs and _config in various parts of the code below.
    # _kfac_kwargs is checked when we want to know only what the user passed.
    # _config is used when we want user selections with the default kwargs as
    # a fallback.
    self._kfac_kwargs = kfac_kwargs
    self._layer_collection_kwargs = {
        'loss_weights': loss_weights,
        'fisher_approx': utils.serialize_fisher_approx(fisher_approx),
        'seed': seed,
    }
    self._config = _DEFAULT_KWARGS.copy()
    self._config.update(kfac_kwargs)
    self._config.update(self._layer_collection_kwargs)
    self._config['loss'] = utils.serialize_loss(loss)

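    # kfac.KfacOptimizer takes the loss tensor via its 'loss' argument; this
    # wrapper accepts it as 'loss_tensor' to avoid clashing with the Keras
    # style 'loss' function above, so remap the name here.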
    if 'loss_tensor' in self._kfac_kwargs:
      self._kfac_kwargs['loss'] = self._kfac_kwargs.pop('loss_tensor')

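    # Hyperparameters that K-FAC manages internally must not be exposed as
    # mutable Keras hyperparameters: damping is controlled by the adaptation
    # loop when adapt_damping is on, and learning_rate/momentum are set by the
    # quadratic model when momentum_type is a 'qmodel' variant.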
    self._mutable_hypers = _MUTABLE_HYPER_PARAMS.copy()
    if self._config['adapt_damping']:
      self._mutable_hypers.remove('damping')
    if self._config['momentum_type'].lower().startswith('qmodel'):
      self._mutable_hypers -= {'learning_rate', 'momentum'}
    for hp in self._mutable_hypers.copy():
      if self._config[hp] is None:
        self._mutable_hypers.remove(hp)
      else:
        self._set_hyper(hp, self._config[hp])

    if layer_collection:
      self.register_layers(layer_collection=layer_collection)
    if train_batch and self._kfac_kwargs.get('adapt_damping', False):
      self.register_train_batch(train_batch=train_batch)