def _compute_fans()

in training/flax/distil_whisper/layers.py [0:0]


def _compute_fans(shape: "jax.core.NamedShape", in_axis=-2, out_axis=-1):
    """Inlined JAX `nn.initializer._compute_fans`.

    Compute the fan-in and fan-out of a weight tensor with the given shape.

    Args:
        shape: the weight shape. Either an object exposing a ``total``
            attribute (such as ``jax.core.NamedShape``), whose ``total`` is
            used as the element count, or a plain sequence of ints, whose
            element count is the product of its entries.
        in_axis: an int, or a sequence of ints, selecting the input axis/axes.
            When several axes are given, their sizes are multiplied.
        out_axis: an int, or a sequence of ints, selecting the output
            axis/axes. When several axes are given, their sizes are multiplied.

    Returns:
        A ``(fan_in, fan_out)`` tuple. The receptive field size is the total
        element count divided by ``in_size`` and ``out_size``. ``fan_in`` is
        ``in_size`` times that size, and ``fan_out`` is ``out_size`` times
        that size. With Python-int sizes, both values are floats because true
        division is used.

    Note:
        A zero-sized input or output dimension makes the division fail. With
        Python ints this raises ``ZeroDivisionError``.
    """
    # The annotation is a string so that defining this function does not
    # evaluate `jax.core.NamedShape`. That attribute is not present in every
    # JAX release. NOTE(review): confirm against the pinned JAX version.
    if isinstance(in_axis, int):
        in_size = shape[in_axis]
    else:
        in_size = int(np.prod([shape[i] for i in in_axis]))
    if isinstance(out_axis, int):
        out_size = shape[out_axis]
    else:
        out_size = int(np.prod([shape[i] for i in out_axis]))
    # Prefer the shape's own element count when it provides one, as the
    # original NamedShape-based code did. Otherwise fall back to the product
    # of the dimensions so that plain tuples of ints are accepted too.
    total = getattr(shape, "total", None)
    if total is None:
        total = int(np.prod(shape))
    receptive_field_size = total / in_size / out_size
    fan_in = in_size * receptive_field_size
    fan_out = out_size * receptive_field_size
    return fan_in, fan_out