# training/flax/distil_whisper/layers.py
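# NOTE: excerpt from a larger module; it assumes the file-level imports and
# helpers used below, roughly:
#
#   from typing import List, Optional, Sequence, Tuple, Union
#   import jax
#   import jax.numpy as jnp
#   import numpy as np
#   from jax import lax
#
# plus the module-local `Array` alias and the `canonicalize_padding`,
# `_conv_dimension_numbers`, `param_with_axes`, and `promote_dtype` helpers.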
def __call__(self, inputs: Array) -> Array:
    """Applies a (potentially unshared) convolution to the inputs.

    Args:
      inputs: input data with dimensions (*batch_dims, spatial_dims...,
        features). This is the channels-last convention, i.e. NHWC for a 2d
        convolution and NDHWC for a 3D convolution. Note: this is different
        from the input convention used by `lax.conv_general_dilated`, which
        puts the spatial dimensions last.

    Note: If the input has more than 1 batch dimension, all batch dimensions
    are flattened into a single dimension for the convolution and restored
    before returning. In some cases directly vmap'ing the layer may yield
    better performance than this default flattening approach. If the input
    lacks a batch dimension, one is added for the convolution and removed on
    return, an allowance made to enable writing single-example code.

    Returns:
      The convolved data.
    """
    if isinstance(self.kernel_size, int):
        raise TypeError(
            "Expected Conv kernel_size to be a"
            " tuple/list of integers (e.g. [3, 3]) but got"
            f" {self.kernel_size}."
        )
    else:
        kernel_size = tuple(self.kernel_size)
    def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> Tuple[int, ...]:
        if x is None:
            # Backward compatibility with using None as a sentinel for
            # broadcasting to 1.
            x = 1
        if isinstance(x, int):
            return (x,) * len(kernel_size)
        return tuple(x)
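    # For example, maybe_broadcast(2) with a 2D kernel yields (2, 2), while
    # maybe_broadcast((1, 2)) passes through unchanged as (1, 2).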
    # Combine all input batch dimensions into a single leading batch axis.
    num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
    if num_batch_dimensions != 1:
        input_batch_shape = inputs.shape[:num_batch_dimensions]
        total_batch_size = int(np.prod(input_batch_shape))
        flat_input_shape = (total_batch_size,) + inputs.shape[num_batch_dimensions:]
        inputs = jnp.reshape(inputs, flat_input_shape)
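        # Note: when the input has no batch dimension at all,
        # `input_batch_shape` is empty and `np.prod(()) == 1.0`, so this same
        # branch adds a size-1 batch axis, which is stripped again on return.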
    # Broadcast scalar hyperparameters to one value per spatial dimension.
    strides = maybe_broadcast(self.strides)
    input_dilation = maybe_broadcast(self.input_dilation)
    kernel_dilation = maybe_broadcast(self.kernel_dilation)
    padding_lax = canonicalize_padding(self.padding, len(kernel_size))
    if padding_lax == "CIRCULAR":
        kernel_size_dilated = [(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)]
        zero_pad: List[Tuple[int, int]] = [(0, 0)]
        pads = zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)]
        inputs = jnp.pad(inputs, pads, mode="wrap")
        padding_lax = "VALID"
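        # Circular padding wrap-pads each spatial dimension by half the
        # dilated kernel size on either side, then runs a VALID convolution,
        # so (at stride 1) the output spatial shape matches the input.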
    elif padding_lax == "CAUSAL":
        if len(kernel_size) != 1:
            raise ValueError("Causal padding is only implemented for 1D convolutions.")
        left_pad = kernel_dilation[0] * (kernel_size[0] - 1)
        pads = [(0, 0), (left_pad, 0), (0, 0)]
        inputs = jnp.pad(inputs, pads)
        padding_lax = "VALID"
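        # Causal padding left-pads the single spatial dimension by the full
        # (dilated) receptive-field extent, so each output step depends only
        # on current and earlier inputs.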
    dimension_numbers = _conv_dimension_numbers(inputs.shape)
    in_features = jnp.shape(inputs)[-1]
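    # Two weight layouts follow: a single kernel shared across all output
    # positions (a standard convolution), or one kernel per output position
    # (a locally connected layer).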
    if self.shared_weights:
        # One shared convolutional kernel for all pixels in the output.
        assert in_features % self.feature_group_count == 0
        kernel_shape = kernel_size + (
            in_features // self.feature_group_count,
            self.features,
        )
    else:
        if self.feature_group_count != 1:
            raise NotImplementedError(
                "`lax.conv_general_dilated_local` does not support "
                f"`feature_group_count != 1`, got `{self.feature_group_count}`."
            )
        # Need to know the spatial output shape of a standard convolution to
        # create the unshared convolution kernel.
        conv_output_shape = jax.eval_shape(
            lambda lhs, rhs: self.conv_general_dilated(  # pylint: disable=g-long-lambda
                lhs=lhs,
                rhs=rhs,
                window_strides=strides,
                padding=padding_lax,
                dimension_numbers=dimension_numbers,
                lhs_dilation=input_dilation,
                rhs_dilation=kernel_dilation,
            ),
            inputs,
            jax.ShapeDtypeStruct(kernel_size + (in_features, self.features), inputs.dtype),
        ).shape
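        # `jax.eval_shape` traces the convolution abstractly, so the output
        # shape is computed without allocating buffers or running any FLOPs.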
        # One (unshared) convolutional kernel for each pixel in the output.
        kernel_shape = conv_output_shape[1:-1] + (
            np.prod(kernel_size) * in_features,
            self.features,
        )
    if self.mask is not None and self.mask.shape != kernel_shape:
        raise ValueError(
            f"Mask needs to have the same shape as weights. Shapes are: {self.mask.shape}, {kernel_shape}"
        )
    kernel = param_with_axes(
        "kernel",
        self.kernel_init,
        kernel_shape,
        self.params_dtype,
        axes=self.kernel_axes,
    )
    if self.mask is not None:
        kernel *= self.mask
    if self.use_bias:
        if self.shared_weights:
            # One bias weight per output channel, shared between pixels.
            bias_shape = (self.features,)
        else:
            # One bias weight per output entry, unshared between pixels.
            bias_shape = conv_output_shape[1:]
        bias = param_with_axes(
            "bias",
            self.bias_init,
            bias_shape,
            self.params_dtype,
            axes=(self.kernel_axes[-1],),
        )
    else:
        bias = None
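    # Promote inputs, kernel, and bias to a common computation dtype
    # (`self.dtype`), e.g. for bfloat16 compute with float32 parameters.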
    inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
    if self.shared_weights:
        y = self.conv_general_dilated(
            inputs,
            kernel,
            strides,
            padding_lax,
            lhs_dilation=input_dilation,
            rhs_dilation=kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )
    else:
        y = lax.conv_general_dilated_local(
            lhs=inputs,
            rhs=kernel,
            window_strides=strides,
            padding=padding_lax,
            filter_shape=kernel_size,
            lhs_dilation=input_dilation,
            rhs_dilation=kernel_dilation,
            dimension_numbers=dimension_numbers,
            precision=self.precision,
        )
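        # `lax.conv_general_dilated_local` applies a different (unshared)
        # kernel at every output spatial position, i.e. a locally connected
        # layer rather than a true convolution.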
    if self.use_bias:
        # Left-pad the bias shape with singleton axes so it broadcasts
        # against the (batch, spatial..., features) output.
        bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape)
        y += bias
    if num_batch_dimensions != 1:
        # Restore the original leading batch dimensions.
        output_shape = input_batch_shape + y.shape[1:]
        y = jnp.reshape(y, output_shape)
    return y
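
# A minimal usage sketch (hypothetical; assumes a concrete Flax module built
# on this __call__, e.g. a `Conv` with `features` and `kernel_size`
# attributes as in `flax.linen.Conv`), illustrating the batch-flattening
# behaviour documented above:
#
#   import jax
#   import jax.numpy as jnp
#
#   conv = Conv(features=32, kernel_size=(3,))
#   x = jnp.ones((2, 4, 100, 16))  # two leading batch dims, 100 steps, 16 ch
#   params = conv.init(jax.random.PRNGKey(0), x)
#   y = conv.apply(params, x)      # flattened to (8, 100, 16) internally
#   assert y.shape == (2, 4, 100, 32)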