in training/flax/distil_whisper/layers.py [0:0]
def _compute_fans(shape: "Sequence[int] | jax.core.NamedShape", in_axis=-2, out_axis=-1):
    """Inlined JAX `nn.initializer._compute_fans`.

    Compute the fan-in and fan-out of a weight of the given shape.

    Args:
      shape: The weight shape. This is either a plain sequence of ints or an
        object exposing a `.total` element count, such as the legacy
        `jax.core.NamedShape`. When `.total` is present it is used as the
        total number of elements. Otherwise the product of the dimensions
        is used.
      in_axis: An int, or a sequence of ints. It names the axis or axes whose
        size (product of sizes) is the input size.
      out_axis: An int, or a sequence of ints. It names the axis or axes whose
        size (product of sizes) is the output size.

    Returns:
      A `(fan_in, fan_out)` tuple. Each entry is the input or output size
      multiplied by the receptive field size, where the receptive field size
      is `total / in_size / out_size` (true division, so the results are
      floats).
    """
    if isinstance(in_axis, int):
        in_size = shape[in_axis]
    else:
        in_size = int(np.prod([shape[i] for i in in_axis]))
    if isinstance(out_axis, int):
        out_size = shape[out_axis]
    else:
        out_size = int(np.prod([shape[i] for i in out_axis]))
    # The legacy `jax.core.NamedShape` exposes its element count as `.total`.
    # Plain tuples of ints do not, so fall back to the product of the dims.
    total = getattr(shape, "total", None)
    if total is None:
        total = int(np.prod(shape))
    receptive_field_size = total / in_size / out_size
    fan_in = in_size * receptive_field_size
    fan_out = out_size * receptive_field_size
    return fan_in, fan_out