def get_mask()

in lib/masked_attention.py


def get_mask(first_b11: th.Tensor, state_mask: th.Tensor, t: int, T: int, maxlen: int, heads: int, device) -> tuple[th.Tensor, th.Tensor]:
    """Returns a band diagonal mask that respects masking past states (columns 0:T-t inclusive)
        if first_b11 is True. See get_band_diagonal_mask for how the base mask is computed.
        This function takes that mask and first zeros out any past context if first_b11 is True.

        Say our context is in chunks of length t (so here T = 4t). We see that in the second batch we recieved first=True
        context     t t t t
        first       F T F F
        Now, given this the mask should mask out anything prior to T < t; however since we don't have access to the past first_b11's
        we need to keep a state of the mask at those past timesteps. This is what state_mask is.

        In particular state_mask is a [b, t, T - t] mask matrix that contains the mask for the past T - t frames.

    Args: (See get_band_diagonal_mask for remaining args)
        first_b11: boolean tensor with shape [batchsize, 1, 1] indicating whether each batch
            element's first timestep in this chunk had first=True
        state_mask: boolean mask tensor of shape [b, 1, T - t] carrying the mask for the past
            T - t frames (may be None on the first call)
        t: number of mask rows (presumably number of frames for which we take gradient)
        T: number of mask columns (t + the number of past frames we keep in context)
        maxlen: actual context length
        heads: number of attention heads
        device: torch device

    Returns:
        m_btT: Boolean mask of shape (batchsize * heads, t, T)
        state_mask: updated state_mask
    """
    b = first_b11.shape[0]

    if state_mask is None:
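        # No carried state yet: an all-False past masks out everything before this first call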
        state_mask = th.zeros((b, 1, T - t), dtype=bool, device=device)

    m_btT = get_band_diagonal_mask(t, T, maxlen, b, device).clone()  # shape (b, t, T); cloned because it is modified in place below
    not_first = ~first_b11.to(device=device)
    m_btT[:, :, :-t] &= not_first  # Zero out anything in the past if first is true
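    # Also honor the carried state: past columns invalidated by an earlier first=True stay masked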
    m_btT[:, :, :-t] &= state_mask
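    # Replicate the mask across attention heads, then fold the head dim into the batch dim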
    m_bhtT = m_btT[:, None].repeat_interleave(heads, dim=1)
    m_btT = m_bhtT.reshape((b * heads), t, T)

    # Update state_mask so it reflects the most recent firsts: drop the t oldest columns,
    # invalidate what remains wherever first=True, and append t fresh all-valid columns
    state_mask = th.cat(
        [
            state_mask[:, :, t:] & not_first,
            th.ones((b, 1, min(t, T - t)), dtype=bool, device=device),
        ],
        dim=-1,
    )

    return m_btT, state_mask
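
To make the carried state concrete, here is a minimal runnable sketch that calls get_mask over two consecutive chunks. It assumes "import torch as th"; the get_band_diagonal_mask stub is a hypothetical stand-in (a causal band limited to maxlen frames of lookback) written for illustration, not necessarily the module's actual implementation.

import torch as th

def get_band_diagonal_mask(t, T, maxlen, batchsize, device):
    # Hypothetical stand-in: causal band where each of the t query rows may
    # attend to at most maxlen keys, ending at its own frame.
    m = th.ones(t, T, dtype=bool)
    m.tril_(T - t)  # causal: a row may not see columns after its own frame
    if maxlen is not None and maxlen < T:
        m.triu_(T - t - maxlen + 1)  # limit lookback to maxlen frames
    return m[None].repeat_interleave(batchsize, dim=0).to(device=device)

b, t, T, maxlen, heads = 2, 4, 16, 16, 2
device = th.device("cpu")

state_mask = None
# Batch element 1 starts a new episode on the second chunk.
chunk_firsts = [
    th.tensor([[[False]], [[False]]]),
    th.tensor([[[False]], [[True]]]),
]
for first_b11 in chunk_firsts:
    m_btT, state_mask = get_mask(first_b11, state_mask, t, T, maxlen, heads, device)
    print(m_btT.shape, state_mask.shape)  # torch.Size([4, 4, 16]) torch.Size([2, 1, 12])

On the second call, batch element 1 can no longer attend to any column carried over from the first chunk, while batch element 0 still sees the four frames it accumulated there.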