in convit.py [0:0]
def get_attention_map(self, x, return_map=False):
    B, N, C = x.shape
    # Project tokens to q, k, v: shape (3, B, num_heads, N, head_dim)
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # v is not used below

    # Scaled dot-product attention, averaged over the batch: (num_heads, N, N)
    attn_map = (q @ k.transpose(-2, -1)) * self.scale
    attn_map = attn_map.softmax(dim=-1).mean(0)

    # Pairwise Euclidean distances between patch positions on the img_size x img_size grid
    img_size = int(N ** .5)
    ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
    indx = ind.repeat(img_size, img_size)
    indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
    indd = indx ** 2 + indy ** 2
    distances = indd ** .5
    # Keep the distance map on the same device as the input tokens
    distances = distances.to(x.device)

    # Mean attention distance per head, normalised by the number of tokens
    dist = torch.einsum('nm,hnm->h', distances, attn_map)
    dist /= N

    if return_map:
        return dist, attn_map
    else:
        return dist
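
For reference, a minimal sketch of how this method can be exercised outside the full ConViT model. The MiniAttention class and its sizes (dim=192, num_heads=4, a 14x14 token grid) are hypothetical stand-ins that only supply the attributes get_attention_map reads (qkv, num_heads, scale); the real attention modules in convit.py contain more machinery.

import torch
import torch.nn as nn

class MiniAttention(nn.Module):
    # Hypothetical stand-in: just the attributes get_attention_map uses.
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)

attn = MiniAttention(dim=192, num_heads=4)
x = torch.randn(2, 196, 192)  # (B, N, C) with N = 14 * 14 patch tokens

# Call the function above directly, passing the module as `self`.
dist, attn_map = get_attention_map(attn, x, return_map=True)
print(dist.shape)      # torch.Size([4])          mean attention distance per head
print(attn_map.shape)  # torch.Size([4, 196, 196]) batch-averaged attention map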