def _compute_new_cov()

in kfac/python/ops/fisher_factors.py [0:0]


  def _compute_new_cov(self, source, tower):
    """Builds the new input-covariance estimate for the given tower.

    Extracts convolution patches from the tower's inputs (optionally
    sub-sampling inputs and/or patches first), flattens them into the
    matrix [[A_l]] from the KFC paper, and returns its (normalized)
    second moment.

    Args:
      source: int. Must be 0; only a single source is supported.
      tower: int. Index into `self._inputs` selecting the tower's inputs.

    Returns:
      A Tensor of shape J|Delta| x J|Delta| (plus one row/column if the
      layer has a bias), the normalized second moment of the patches.
    """
    assert source == 0

    inputs = self._inputs[tower]

    if self._sub_sample_inputs:
      # Randomly keep roughly a _INPUTS_TO_EXTRACT_PATCHES_FACTOR fraction
      # of the batch before extracting patches.
      static_batch_size = inputs.shape.as_list()[0]
      if static_batch_size is not None:
        # Batch size is known at graph-construction time.
        sample_size = int(
            math.ceil(static_batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR))
      else:
        # Batch size is only known at run time; perform the equivalent of
        # int(math.ceil(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR))
        # with in-graph ops.
        dynamic_batch_size = utils.get_shape(inputs)[0]
        sample_size = tf.cast(
            tf.ceil(
                tf.multiply(
                    tf.cast(dynamic_batch_size, dtype=tf.float32),
                    _INPUTS_TO_EXTRACT_PATCHES_FACTOR)),
            dtype=utils.preferred_int_dtype())

      inputs = _random_tensor_gather(inputs, sample_size)

    # TODO(b/64144716): there is potential here for a big savings in terms of
    # memory use.
    if _USE_PATCHES_SECOND_MOMENT_OP:
      raise NotImplementedError  # patches op is not available outside of Google,
                                 # sorry! You'll need to turn it off to proceed.

    extract_fn = self._extract_patches_fn
    if extract_fn in (None, "extract_convolution_patches"):
      patches = utils.extract_convolution_patches(
          inputs,
          self._filter_shape,
          padding=self._padding,
          strides=self._strides,
          dilation_rate=self._dilation_rate,
          data_format=self._data_format)

    elif extract_fn == "extract_image_patches":
      # tf.extract_image_patches only handles rank-4 NHWC inputs with
      # length-4 strides/rates whose batch and depth entries are 1.
      assert inputs.shape.ndims == 4
      assert len(self._filter_shape) == 4
      assert len(self._strides) == 4, self._strides
      if self._dilation_rate is None:
        rates = [1, 1, 1, 1]
      else:
        rates = self._dilation_rate
        assert len(rates) == 4
        assert rates[0] == rates[-1] == 1
      patches = tf.extract_image_patches(
          inputs,
          ksizes=[1] + list(self._filter_shape[0:-2]) + [1],
          strides=self._strides,
          rates=rates,
          padding=self._padding)

    elif extract_fn == "extract_pointwise_conv2d_patches":
      # 1x1 convolutions with unit strides only.
      assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)]
      assert self._filter_shape[0] == self._filter_shape[1] == 1
      patches = utils.extract_pointwise_conv2d_patches(
          inputs, self._filter_shape, data_format=None)

    else:
      raise NotImplementedError(extract_fn)

    if self._patch_mask is not None:
      assert self._patch_mask.shape == self._filter_shape[0:-1]
      # Broadcasting applies the flattened mask across every patch.
      patches *= tf.reshape(self._patch_mask, [-1])

    # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
    # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
    # where M = minibatch size, |T| = number of spatial locations,
    # |Delta| = number of spatial offsets, and J = number of input maps
    # for convolutional layer l.
    patches_flat = tf.reshape(
        patches, [-1, np.prod(self._filter_shape[0:-1])])

    if self._sub_sample_patches:
      patches_flat = _subsample_patches(patches_flat)

    # If the layer has bias parameters, append a homogeneous coordinate to
    # patches_flat. This gives us [[A_l]]_H from the paper.
    if self._has_bias:
      patches_flat = append_homog(patches_flat)

    # We call compute_cov without passing in a normalizer. compute_cov uses
    # the first dimension of patches_flat i.e. M|T| as the normalizer by
    # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
    # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
    # the paper but has a different scale here for consistency with
    # ConvOutputKroneckerFactor.
    # (Tilde omitted over A for clarity.)
    return compute_cov(patches_flat)