def make_full_layout()

in glide_text2im/clip/attention.py [0:0]


from itertools import product

import numpy as np

# `AttentionMask` is defined earlier in this module.


def make_full_layout(d: AttentionMask) -> np.ndarray:
    """
    Returns the `context_size x context_size` layout matrix described by `d`. If the layout is dependent on the index of
    the attention head, a `attention_head x context_size x context_size` layout matrix is returned instead.
    """

    if not d.is_head_specific:
        # u: the block-level visibility mask, broadcast over the positions inside each block.
        u = np.reshape(d.global_layout, [d.n_query_block, d.n_key_block, 1, 1])
        # v: the block_size x block_size layout of every (query block, key block) pair.
        r = product(range(d.n_query_block), range(d.n_key_block))
        v = np.array([d.block_layout(None, 0, i, j, 0) for i, j in r])
        v = np.reshape(v, [d.n_query_block, d.n_key_block, d.block_size, d.block_size])

        # Zero out the blocks that the global layout hides, then interleave the
        # block and intra-block axes into one dense query x key mask.
        w = u * v
        w = np.transpose(w, [0, 2, 1, 3])
        w = np.reshape(w, [d.query_context_size, d.key_context_size])
        return w
    else:
        # The global layout is either shared across heads (2-D) or given per head
        # (3-D); normalize it to shape [n_head, n_query_block, n_key_block, 1, 1].
        if len(d.global_layout.shape) == 2:
            u = np.reshape(d.global_layout, [1, d.n_query_block, d.n_key_block, 1, 1])
            u = np.tile(u, [d.n_head, 1, 1, 1, 1])
        elif len(d.global_layout.shape) == 3:
            u = np.reshape(d.global_layout, [d.n_head, d.n_query_block, d.n_key_block, 1, 1])
        else:
            raise RuntimeError()  # unsupported global_layout rank

        # v: per-head block layouts, indexed by (head, query block, key block).
        s = product(range(d.n_head), range(d.n_query_block), range(d.n_key_block))
        v = np.array([d.block_layout(None, i, j, k, 0) for i, j, k in s])
        v = np.reshape(v, [d.n_head, d.n_query_block, d.n_key_block, d.block_size, d.block_size])

        # Same masking and interleaving as above, with a leading head axis.
        w = u * v
        w = np.transpose(w, [0, 1, 3, 2, 4])
        w = np.reshape(w, [d.n_head, d.query_context_size, d.key_context_size])
        return w
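
As a sanity check on the shape bookkeeping, the sketch below feeds the function a tiny hand-rolled mask description. `_ToyMask` is a hypothetical stand-in, not the real `AttentionMask` from this module; it only duck-types the attributes `make_full_layout` reads, using two blocks of size 3, a lower-triangular block-level layout, and a causal layout inside the diagonal blocks.

from dataclasses import dataclass, field

import numpy as np


@dataclass
class _ToyMask:
    # Hypothetical stand-in mirroring only the attributes make_full_layout reads.
    n_query_block: int = 2
    n_key_block: int = 2
    block_size: int = 3
    n_head: int = 1
    is_head_specific: bool = False
    # Block-level layout matching the 2 x 2 defaults: query block q sees key block k iff k <= q.
    global_layout: np.ndarray = field(
        default_factory=lambda: np.tril(np.ones([2, 2], dtype=np.int64))
    )

    @property
    def query_context_size(self) -> int:
        return self.n_query_block * self.block_size

    @property
    def key_context_size(self) -> int:
        return self.n_key_block * self.block_size

    def block_layout(self, blk_shape, head_idx, query_idx, key_idx, blk_idx):
        # Causal layout inside diagonal blocks, dense everywhere else.
        if query_idx == key_idx:
            return np.tril(np.ones([self.block_size, self.block_size], dtype=np.int64))
        return np.ones([self.block_size, self.block_size], dtype=np.int64)


w = make_full_layout(_ToyMask())
assert w.shape == (6, 6)
print(w)

w_h = make_full_layout(_ToyMask(is_head_specific=True, n_head=2))
assert w_h.shape == (2, 6, 6)  # the 2-D global layout is tiled across both heads

Entry `(q * block_size + i, k * block_size + j)` of the result equals `global_layout[q, k] * block_layout(None, 0, q, k, 0)[i, j]`, which is exactly the interleaving the transpose-then-reshape above produces.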