def __init__()

in kfac/python/ops/fisher_factors.py [0:0]


  def __init__(self,
               inputs,
               outputs_grads,
               filter_shape,
               strides,
               padding,
               data_format=None,
               dilations=None,
               has_bias=False,
               patch_mask=None):
    """Creates a ConvDiagonalFactor object.

    Args:
      inputs: List of Tensors of shape [batch_size, height, width, in_channels].
        Input activations to this layer.  List index is the tower.
      outputs_grads: List of Tensors, each of shape [batch_size,
        height, width, out_channels], which are the gradients of the loss
        with respect to the layer's outputs.  First index is source, second
        index is tower.
      filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels,
        out_channels). Represents shape of kernel used in this layer.
      strides: The stride size in this layer (1-D Tensor of length 4).
      padding: The padding used in this layer (e.g. "SAME" or "VALID").
      data_format: None or str. Format of conv2d inputs.
      dilations: None or tuple of 4 ints.
      has_bias: Python bool. If True, the layer is assumed to have a bias
        parameter in addition to its filter parameter.
      patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]
        or None. If not None this is multiplied against the extracted patches
        Tensor (broadcasting along the batch dimension) before statistics are
        computed. (Default: None)

    Raises:
      ValueError: If inputs, outputs_grads, and filter_shape do not agree on
        in_channels or out_channels.
      ValueError: If strides or dilations are not length-4 lists of ints.
      ValueError: If data_format does not put the channel dimension last.
    """
    if not utils.is_data_format_channel_last(data_format):
      raise ValueError("Channel must be last.")
    if any(input_.shape.ndims != 4 for input_ in inputs):
      raise ValueError("inputs must be a list of 4-D Tensors.")
    if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs):
      raise ValueError("inputs and filter_shape must agree on in_channels.")
    for i, outputs_grad in enumerate(outputs_grads):
      if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad):
        raise ValueError("outputs[%d] must be 4-D Tensor." % i)
      if any(output_grad.shape.as_list()[-1] != filter_shape[-1]
             for output_grad in outputs_grad):
        raise ValueError(
            "outputs[%d] and filter_shape must agree on out_channels." % i)
    if len(strides) != 4:
      raise ValueError("strides must be length-4 list of ints.")
    if dilations is not None and len(dilations) != 4:
      raise ValueError("dilations must be length-4 list of ints.")

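    # Store the constructor arguments.  Extracted image patches are not
    # computed here; they are built lazily when covariance statistics are
    # needed, so self._patches starts as None.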
    self._inputs = inputs
    self._outputs_grads = outputs_grads
    self._filter_shape = filter_shape
    self._strides = strides
    self._padding = padding
    self._data_format = data_format
    self._dilations = dilations
    self._has_bias = has_bias
    self._patches = None

    self._patch_mask = patch_mask

    super(ConvDiagonalFactor, self).__init__()
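A minimal construction sketch follows, assuming TF1-style graph-mode tensors. The shapes, the tf.random_normal inputs, and the "SAME" padding value are illustrative assumptions, not values taken from the library; in practice this factor is normally created indirectly through KFAC's layer-registration machinery rather than instantiated by hand.

import tensorflow as tf
from kfac.python.ops import fisher_factors as ff

# Illustrative shapes only; a real model supplies its own activations and
# loss gradients.
batch_size, height, width = 8, 32, 32
in_channels, out_channels = 3, 16
filter_shape = (5, 5, in_channels, out_channels)

# One tower of inputs; one source of output gradients (note the nesting:
# outputs_grads is indexed by source first, then tower).
inputs = [tf.random_normal([batch_size, height, width, in_channels])]
outputs_grads = [[tf.random_normal([batch_size, height, width, out_channels])]]

factor = ff.ConvDiagonalFactor(
    inputs,
    outputs_grads,
    filter_shape,
    strides=[1, 1, 1, 1],
    padding="SAME",  # assumed string padding, as used by patch extraction
    has_bias=True)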