in adaptive_span.py [0:0]
def forward(self, attn, normalize=True):
    """Apply the adaptive span mask to an attention tensor.

    The incoming ``attn`` has its batch and head dimensions merged into
    dim 0; they are separated before masking and merged again on return.

    Args:
        attn: attention weights of shape (batch * n_heads, block, span).
        normalize: if True, rescale each row so it sums to 1 after masking.

    Returns:
        Masked (and optionally renormalized) attention, same shape as input.
    """
    batch_x_heads, block_size = attn.size(0), attn.size(1)
    batch = batch_x_heads // self._nb_heads
    # Un-merge batch and head dims so the span mask sees per-head attention.
    attn = self._mask(attn.reshape(batch, self._nb_heads, block_size, -1))
    if normalize:
        # Renormalize rows to sum to 1; epsilon guards fully-masked rows.
        attn = attn / (attn.sum(-1, keepdim=True) + 1e-8)
    # Merge batch and head dims back together before returning.
    return attn.view(batch_x_heads, block_size, -1)