in training/flax/distil_whisper/layers.py [0:0]
def combine_biases(*masks: Optional[Array]) -> Optional[Array]:
    """Combine attention biases.

    Args:
      *masks: set of attention bias arguments to combine, some can be None.

    Returns:
      Combined mask, reduced by summation (elementwise ``+``) over every
      non-None argument. Returns None if no masks are given or all are None.
      A single non-None mask is returned as-is (same object).

    Raises:
      ValueError: if the non-None masks do not all have the same rank
        (``ndim``).
    """
    present = [m for m in masks if m is not None]
    if not present:
        return None
    ranks = tuple(m.ndim for m in present)
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let mixed-rank biases through unchecked.
    if any(rank != ranks[0] for rank in ranks):
        raise ValueError(f"masks must have same rank: {ranks}")
    mask, *other_masks = present
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask