def num_conv_locations()

in kfac/python/ops/utils.py [0:0]


def num_conv_locations(input_shape, filter_shape, strides, padding):
  """Returns the number of spatial locations a conv kernel is applied to.

  The result is the product of the output spatial dimensions, each computed
  with the output-shape formulas documented for tf.nn.convolution():

    'VALID': out = ceil((in - filter + 1) / stride)
    'SAME':  out = ceil(in / stride)

  Args:
    input_shape: List of ints representing shape of inputs to
      tf.nn.convolution(): [batch, width, channels] for a Conv1D or
      [batch, height, width, channels] for a Conv2D.
    filter_shape: List of ints representing shape of filter to
      tf.nn.convolution(). Its leading entries are the spatial filter sizes.
      It must have the same length as input_shape.
    strides: None (every stride is 1), or a list of ints in one of two forms.
      The spatial-only form has one entry per spatial dimension, e.g.
      [stride_h, stride_w], as passed to tf.nn.convolution(). The full-rank
      form has the batch entry first, e.g. [1, stride_h, stride_w, 1]; only
      its spatial entries are used.
    padding: string representing the padding method, either 'VALID' or 'SAME'
      (case-insensitive). Any value other than 'VALID', including None, is
      treated as 'SAME'.

  Returns:
    A scalar |T| denoting the number of spatial locations for the Conv layer.
    This is 0 for a 'VALID' convolution whose filter does not fit inside the
    input along some spatial dimension.

  Raises:
    ValueError: If input_shape, filter_shape don't represent a 1-D or 2-D
      convolution, or if strides has too few entries for the convolution.
  """
  if len(input_shape) != 4 and len(input_shape) != 3:
    raise ValueError("input_shape must be length 4, corresponding to a Conv2D,"
                     " or length 3, corresponding to a Conv1D.")
  if len(input_shape) != len(filter_shape):
    raise ValueError("Inconsistent number of dimensions between input and "
                     "filter for convolution")

  # Inputs are laid out as [batch, <spatial dims...>, channels].
  num_spatial_dims = len(input_shape) - 2

  if strides is None:
    spatial_strides = [1] * num_spatial_dims
  elif len(strides) == num_spatial_dims:
    # Spatial-only strides, as accepted by tf.nn.convolution().
    spatial_strides = list(strides)
  else:
    # Full-rank strides such as [1, stride_h, stride_w, 1]. Skip the batch
    # entry; any trailing (channel) entry is ignored.
    spatial_strides = list(strides[1:1 + num_spatial_dims])
    if len(spatial_strides) != num_spatial_dims:
      raise ValueError(
          "strides must have {} entries (spatial only) or at least {} entries "
          "(full rank) for a {}-D input_shape, got {}".format(
              num_spatial_dims, num_spatial_dims + 1, len(input_shape),
              list(strides)))

  valid_padding = padding is not None and padding.lower() == "valid"

  num_locations = 1
  for in_size, filter_size, stride in zip(input_shape[1:-1],
                                          filter_shape[:num_spatial_dims],
                                          spatial_strides):
    # With VALID padding only positions where the whole filter fits count.
    covered = in_size - filter_size + 1 if valid_padding else in_size
    # Use negative integer division to implement 'rounding up'.
    out_size = -(-covered // stride)
    # A filter larger than the input yields a non-positive size. Clamp it to
    # 0 so the product cannot turn two negatives into a bogus positive count.
    num_locations *= max(out_size, 0)

  return num_locations