in training/flax/distil_whisper/layers.py [0:0]
def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
    """Canonicalize conv padding to a format supported by ``jax.lax``.

    Args:
      padding: One of:
        * a string, returned unchanged (left for ``jax.lax`` to interpret);
        * a single int, applied as ``(padding, padding)`` to each of the
          ``rank`` dimensions;
        * a sequence of exactly ``rank`` elements, where each element is
          either an int ``p`` (expanded to ``(p, p)``) or a 2-element
          non-string sequence ``(low, high)``. Pairs may be tuples or lists
          and are normalized to tuples.
      rank: Number of dimensions the padding applies to.

    Returns:
      The string unchanged, or a list of ``rank`` ``(low, high)`` tuples.

    Raises:
      ValueError: If ``padding`` matches none of the accepted formats.
    """
    if isinstance(padding, str):
        return padding
    if isinstance(padding, int):
        return [(padding, padding)] * rank
    if isinstance(padding, Sequence) and len(padding) == rank:
        new_pad = []
        for p in padding:
            if isinstance(p, int):
                new_pad.append((p, p))
            elif (
                isinstance(p, Sequence)
                and not isinstance(p, (str, bytes))
                and len(p) == 2
            ):
                # Accept list pairs too (e.g. padding loaded from a JSON/YAML
                # config); tuple() is a no-op for tuples already given.
                new_pad.append(tuple(p))
            else:
                # Unrecognized element: fall through to the error below.
                break
        else:
            # Every one of the `rank` elements was valid.
            return new_pad
    raise ValueError(
        f"Invalid padding format: {padding}, should be str, int,"
        f" or a sequence of len {rank} where each element is an"
        " int or pair of ints."
    )