def instantiate_inv_variables()

in kfac/python/ops/fisher_factors.py [0:0]


  def instantiate_inv_variables(self):
    """Creates the variables for the registered option-1/option-2 quantities.

    After delegating to the parent class, this creates:

    * for every damping id in `self._option1quants_registrations`, an
      `(Lmat, psi)` pair stored in `self._option1quants_by_damping`;
    * for every damping id in `self._option2quants_registrations`, a
      `(Pmat, Kmat, mu)` triple stored in `self._option2quants_by_damping`.

    Matrix variables have shape `self._cov_shape` and are initialized with
    `inverse_initializer`; vector variables have shape `self._vec_shape` and
    are initialized to ones.  All variables are non-trainable resource
    variables of dtype `self._dtype`, created under `self._var_scope`, and
    their names are suffixed with a string derived from the damping function.

    Raises:
      AssertionError: If a damping id already has an entry in the
        corresponding `*_by_damping` dict.
    """
    super(FullyConnectedMultiKF, self).instantiate_inv_variables()

    def make_variable(prefix, damping_string, initializer, shape):
      # Shared recipe for every variable created below.  Callers are expected
      # to already be inside `tf.variable_scope(self._var_scope)`.
      return tf.get_variable(
          "{}_damp{}".format(prefix, damping_string),
          initializer=initializer,
          shape=shape,
          trainable=False,
          dtype=self._dtype,
          use_resource=True)

    for damping_id in self._option1quants_registrations:
      damping_func = self._damping_funcs_by_id[damping_id]
      damping_string = graph_func_to_string(damping_func)
      # It's questionable as to whether we should initialize with stuff like
      # this at all.  Ideally these values should never be used until they are
      # updated at least once.
      with tf.variable_scope(self._var_scope):
        Lmat = make_variable(  # pylint: disable=invalid-name
            "Lmat", damping_string, inverse_initializer, self._cov_shape)
        psi = make_variable(
            "psi", damping_string, tf.ones_initializer(), self._vec_shape)

      assert damping_id not in self._option1quants_by_damping
      self._option1quants_by_damping[damping_id] = (Lmat, psi)

    for damping_id in self._option2quants_registrations:
      damping_func = self._damping_funcs_by_id[damping_id]
      damping_string = graph_func_to_string(damping_func)
      # It's questionable as to whether we should initialize with stuff like
      # this at all.  Ideally these values should never be used until they are
      # updated at least once.
      with tf.variable_scope(self._var_scope):
        # NOTE: Pmat gets its own "Pmat_damp..." name.  It was previously
        # created as "Lmat_damp...", which is the name used for the option-1
        # Lmat variable in the same scope; with a damping id registered for
        # both options that would request the same variable name twice
        # (an error, or silent sharing of one variable, depending on the
        # scope's reuse mode).  Checkpoints written with the old name need a
        # name remap for this variable.
        Pmat = make_variable(  # pylint: disable=invalid-name
            "Pmat", damping_string, inverse_initializer, self._cov_shape)
        Kmat = make_variable(  # pylint: disable=invalid-name
            "Kmat", damping_string, inverse_initializer, self._cov_shape)
        mu = make_variable(
            "mu", damping_string, tf.ones_initializer(), self._vec_shape)

      assert damping_id not in self._option2quants_by_damping
      self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu)