in kfac/python/ops/layer_collection.py [0:0]
def register_conv2d(self,
                    params,
                    strides,
                    padding,
                    inputs,
                    outputs,
                    data_format=None,
                    dilations=None,
                    approx=None,
                    reuse=VARIABLE_SCOPE,
                    sub_sample_inputs=None,
                    sub_sample_patches=None,
                    patch_mask=None):
  """Registers a call to tf.nn.conv2d().

  Args:
    params: Variable or 2-tuple of variables corresponding to weight and
      bias parameters of this layer. Weight matrix should have shape
      [kernel_height, kernel_width, in_channels, out_channels]. Bias should
      have shape [out_channels].
    strides: List of 4 ints. Strides for convolution kernel.
    padding: string. see tf.nn.conv2d for valid values.
    inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
      to layer.
    outputs: Tensor of shape [batch_size, height, width, out_channels].
      Output produced by layer.
    data_format: str or None. Format of data. Only None and 'NHWC' are
      accepted (NCHW is not supported right now); None should be treated as
      'NHWC'. (Default: None)
    dilations: List of 4 ints. Dilations along each dimension. Forwarded
      only for the "kron" and "diagonal" approximations.
    approx: str or None. If not None must be one of "kron", "diagonal", or
      the Kronecker-SUA approximation name (APPROX_KRONECKER_SUA_NAME). The
      Fisher approximation to use. If None the default value
      (self.default_conv2d_approximation) is used. (Default: None)
    reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
      additional mini-batch/tower of data to use when estimating the Fisher
      block for this layer (which must have already been registered). If
      "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
      (Default: "VARIABLE_SCOPE")
    sub_sample_inputs: `bool`. If True, then subsample the inputs from which
      the image patches are extracted. Only used by the "kron"
      approximation. (Default: None)
    sub_sample_patches: `bool`, If `True` then subsample the extracted
      patches. Only used by the "kron" approximation. (Default: None)
    patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]
      or None. If not None this is multiplied against the extracted patches
      Tensor (broadcasting along the batch dimension) before statistics are
      computed. This can (and probably should) be used if the filter bank
      matrix is masked in a way that is homogenous across the output channels.
      (Other masking patterns have no direct support.) Currently only works
      with the approx="kron" or "diagonal"; it is not forwarded for the
      Kronecker-SUA approximation. (Default: None)

  Raises:
    ValueError: For improper value to 'approx'.
    ValueError: If 'data_format' is not None or 'NHWC'.
    ValueError: If the diagonal approximation is used with strides[0] != 1 or
      strides[-1] != 1.
    NotImplementedError: If the resolved approximation has no registration
      branch in this method.
    KeyError: If reuse == True but no FisherBlock found for 'params'.
    ValueError: If reuse == True and FisherBlock found but of the wrong type.
  """
  # NCHW is not supported right now. Raise rather than assert so the check
  # still runs under `python -O`.
  if data_format not in (None, "NHWC"):
    raise ValueError(
        "register_conv2d only supports data_format None or 'NHWC' "
        "(NCHW is not supported right now); got {!r}.".format(data_format))

  block_type, approx = self._get_block_type(
      params, approx, self.default_conv2d_approximation,
      self._conv2d_approx_to_block_types)

  # It feels bad to pass in configuration that has to do with the internal
  # implementation. And then we can't use the same constructor for both
  # anymore and are thus forced to use this ugly if-statement.
  # TODO(b/74793309): Clean this up?
  if approx == APPROX_KRONECKER_NAME:
    block = self._register_block(
        params,
        block_type(
            layer_collection=self,
            params=params,
            padding=padding,
            strides=strides,
            data_format=data_format,
            dilation_rate=dilations,
            extract_patches_fn="extract_image_patches",
            sub_sample_inputs=sub_sample_inputs,
            sub_sample_patches=sub_sample_patches,
            use_sua_approx_for_input_factor=False,
            patch_mask=patch_mask),
        reuse=reuse)
  elif approx == APPROX_DIAGONAL_NAME:
    # With NHWC data, index 0 is the batch dimension and index -1 the channel
    # dimension; the diagonal block requires no striding over either.
    if strides[0] != 1 or strides[-1] != 1:
      raise ValueError(
          "The diagonal conv2d approximation requires strides[0] == "
          "strides[-1] == 1; got strides={!r}.".format(strides))
    block = self._register_block(
        params,
        block_type(
            layer_collection=self,
            params=params,
            padding=padding,
            strides=strides,
            dilations=dilations,
            data_format=data_format,
            patch_mask=patch_mask),
        reuse=reuse)
  elif approx == APPROX_KRONECKER_SUA_NAME:
    # NOTE(review): only `padding` is forwarded here; strides, dilations,
    # data_format, sub_sample_* and patch_mask are silently dropped. Confirm
    # that the SUA block does not need them (e.g. for counting conv
    # locations when strides != 1).
    block = self._register_block(
        params,
        block_type(
            layer_collection=self,
            params=params,
            padding=padding,
            use_sua_approx_for_input_factor=True),
        reuse=reuse)
  else:
    raise NotImplementedError(approx)

  # Attach this call's (inputs, outputs) to the block as another tower of
  # data, then record one more use of `params` with the collection.
  block.register_additional_tower(inputs, outputs)
  self._add_uses(params, 1)