in kfac/python/ops/fisher_factors.py [0:0]
def __init__(self,
             inputs,
             outputs_grads,
             filter_shape,
             strides,
             padding,
             data_format=None,
             dilations=None,
             has_bias=False,
             patch_mask=None):
  """Creates a ConvDiagonalFactor object.

  Validates the shapes of the activations and output gradients against
  `filter_shape`, then stores all arguments for later use.

  Args:
    inputs: List of Tensors of shape [batch_size, height, width, in_channels].
      Input activations to this layer. List index is tower. Each Tensor must
      be 4-D and its last dimension must equal `filter_shape[-2]`.
    outputs_grads: List of lists of Tensors, each Tensor of shape
      [batch_size, height, width, out_channels]. These are the gradients of
      the loss with respect to the layer's outputs. The first (outer) index is
      source and the second (inner) index is tower. Each Tensor must be 4-D
      and its last dimension must equal `filter_shape[-1]`.
    filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels,
      out_channels). Represents the shape of the kernel used in this layer.
    strides: The stride sizes in this layer. Only its length is checked; it
      must have length 4.
    padding: The padding used in this layer. It is not validated here and is
      stored as-is. (Presumably a conv padding spec such as "SAME"/"VALID";
      confirm against the code that consumes it.)
    data_format: None or str. Format of conv2d inputs. Must put channels last.
    dilations: None or a length-4 sequence (e.g. a tuple of 4 ints). When not
      None, only its length is checked.
    has_bias: Python bool. If True, the layer is assumed to have a bias
      parameter in addition to its filter parameter.
    patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]
      or None. If not None, this is multiplied against the extracted patches
      Tensor (broadcasting along the batch dimension) before statistics are
      computed. (Default: None)

  Raises:
    ValueError: If data_format does not put channels last.
    ValueError: If any Tensor in inputs or outputs_grads is not 4-D.
    ValueError: If inputs, outputs_grads, and filter_shape do not agree on
      in_channels or out_channels.
    ValueError: If strides, or dilations when given, do not have length 4.
  """
  if not utils.is_data_format_channel_last(data_format):
    raise ValueError("Channel must be last.")

  # Activations: one 4-D Tensor per tower, with in_channels last.
  if any(input_.shape.ndims != 4 for input_ in inputs):
    raise ValueError("inputs must be a list of 4-D Tensors.")
  if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs):
    raise ValueError("inputs and filter_shape must agree on in_channels.")

  # Output gradients: outer list over sources, inner list over towers.
  for i, outputs_grad in enumerate(outputs_grads):
    if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad):
      raise ValueError("outputs[%d] must be 4-D Tensor." % i)
    if any(output_grad.shape.as_list()[-1] != filter_shape[-1]
           for output_grad in outputs_grad):
      raise ValueError(
          "outputs[%d] and filter_shape must agree on out_channels." % i)

  # Only the lengths are validated, not the element types.
  if len(strides) != 4:
    raise ValueError("strides must be length-4 list of ints.")
  if dilations is not None and len(dilations) != 4:
    raise ValueError("dilations must be length-4 list of ints.")

  self._inputs = inputs
  self._outputs_grads = outputs_grads
  self._filter_shape = filter_shape
  self._strides = strides
  self._padding = padding
  self._data_format = data_format
  self._dilations = dilations
  self._has_bias = has_bias
  # Placeholder: the patches are not computed by this constructor.
  self._patches = None
  self._patch_mask = patch_mask

  # NOTE(review): the base constructor runs only after every attribute above
  # is set; presumably it may depend on them, so confirm before reordering.
  super(ConvDiagonalFactor, self).__init__()