def combine_masks()

in training/flax/distil_whisper/layers.py [0:0]


def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
    """Combine attention masks with an elementwise logical AND.

    Args:
      *masks: attention masks to combine; ``None`` entries are ignored. All
        remaining masks must have the same rank (``ndim``).
      dtype: dtype the combined mask is cast to before it is returned.

    Returns:
      The logical AND of all non-``None`` masks, cast to ``dtype``, or ``None``
      if no mask was given. When exactly one mask is given it is only cast to
      ``dtype``; no logical operation is applied to it.

    Raises:
      ValueError: if the given masks do not all have the same rank.
    """
    present = [m for m in masks if m is not None]
    if not present:
        return None

    ranks = tuple(m.ndim for m in present)
    # Raise explicitly instead of using `assert`: assertions are removed under
    # `python -O`, and masks of different rank could then be broadcast
    # together by `jnp.logical_and` instead of being rejected.
    if any(rank != ranks[0] for rank in ranks):
        raise ValueError(f"masks must have same rank: {ranks}")

    combined, *remaining = present
    for other in remaining:
        combined = jnp.logical_and(combined, other)
    return combined.astype(dtype)