in models/spatial/attention.py [0:0]
def forward(self, x, attn_mask=None, rm_nonself_grads=False, attn_multiplier=None):
"""
Args:
x: (T, N, D)
attn_mask: (T, T) added to pre-softmax logits.
"""
T, N, _ = x.shape
q = k = torch.einsum("tbm,mhd->tbhd", x, self.q_weight)
    # Pairwise squared distances ||q_t - k_s||^2 per batch element and head.
    squared_dist = (torch.einsum('tbhd,tbhd->tbh', q, q).unsqueeze(1)
                    + torch.einsum('sbhd,sbhd->sbh', k, k).unsqueeze(0)
                    - 2 * torch.einsum('tbhd,sbhd->tsbh', q, k))
    attn_logits = -squared_dist / math.sqrt(self.head_dim)
    if attn_mask is not None:
        attn_mask = attn_mask[..., None, None]  # (T, T) -> (T, T, 1, 1) to broadcast over N, H
        attn_logits += attn_mask
    attn_weights = F.softmax(attn_logits, dim=1)  # (T, S, N, H), normalized over the key dim S
    attn_weights = update_attn_weights(attn_weights, attn_multiplier)
A = torch.einsum("mhd,nhd->hmn", self.q_weight, self.q_weight) / math.sqrt(self.head_dim)
XA = torch.einsum("tbm,hmn->tbhn", x, A)
PXA = torch.einsum("tsbh,sbhm->tbhm", attn_weights, XA)
    if rm_nonself_grads:
        # Construct self-only gradient paths wrt keys and queries: recompute the
        # logits once with the queries detached (key path) and once with the
        # keys detached (query path). SelfonlyGradients then restricts the key
        # path to the self terms (see the sketch after this function).
        q_detach = q.detach()
        k_detach = k.detach()
        attn_logits_keyonly = -(torch.einsum('tbhd,tbhd->tbh', q_detach, q_detach).unsqueeze(1)
                                + torch.einsum('sbhd,sbhd->sbh', k, k).unsqueeze(0)
                                - 2 * torch.einsum('tbhd,sbhd->tsbh', q_detach, k)) / math.sqrt(self.head_dim)
        attn_logits_queryonly = -(torch.einsum('tbhd,tbhd->tbh', q, q).unsqueeze(1)
                                  + torch.einsum('sbhd,sbhd->sbh', k_detach, k_detach).unsqueeze(0)
                                  - 2 * torch.einsum('tbhd,sbhd->tsbh', q, k_detach)) / math.sqrt(self.head_dim)
        attn_logits_keyonly = SelfonlyGradients.apply(attn_logits_keyonly)
        # Same forward value as the full logits; gradients combine the query
        # path with the self-only key path.
        attn_logits = attn_logits_queryonly + (attn_logits_keyonly - attn_logits_keyonly.detach())
        if attn_mask is not None:
            attn_logits += attn_mask
        attn_weights = F.softmax(attn_logits, dim=1)
        attn_weights = update_attn_weights(attn_weights, attn_multiplier)
        # Zero out the nonself weights: keep only the diagonal (t == s) entries.
        selfonly_mask = ~(torch.triu(torch.ones(T, T), diagonal=1)
                          + torch.tril(torch.ones(T, T), diagonal=-1)).bool()
        selfonly_attn_weights = attn_weights * selfonly_mask[..., None, None].to(attn_weights.device)
        # Self-only gradient path wrt values: PXA_vpath lets gradients reach XA
        # only through the (detached) self attention weights, while PXA_spath
        # uses the full weights with XA detached.
        PXA_vpath = torch.einsum("tsbh,sbhm->tbhm", selfonly_attn_weights.detach(), XA)
        PXA_spath = torch.einsum("tsbh,sbhm->tbhm", attn_weights, XA.detach())
        modified_PXA = PXA_spath + (PXA_vpath - PXA_vpath.detach())
        # Keep the original forward value of PXA but route gradients through the
        # modified self-only paths.
        PXA = PXA.detach() + (modified_PXA - modified_PXA.detach())
    # Apply the value projection per head and flatten heads back to embed_dim.
    PXAV = torch.einsum("tbhm,mhd->tbhd", PXA, self.v_weight).reshape(T, N, self.embed_dim)
    return self.out_proj(PXAV), attn_weights.detach()
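

# SelfonlyGradients is used in forward() above but defined elsewhere. The sketch
# below is an assumption based on its name and usage: an identity forward whose
# backward keeps only the diagonal (t == s) gradient entries of the (T, S, N, H)
# logits. Illustrative only; the underscored name is hypothetical.
class _SelfonlyGradientsSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, attn_logits):
        # Pass the logits through unchanged; filtering happens in backward.
        return attn_logits

    @staticmethod
    def backward(ctx, grad_output):
        # Keep only the self (t == s) entries of the incoming gradient.
        T, S = grad_output.shape[0], grad_output.shape[1]
        eye = torch.eye(T, S, dtype=torch.bool, device=grad_output.device)
        return grad_output * eye[..., None, None]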
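

# update_attn_weights is also defined elsewhere; the sketch below is a hedged
# guess from its call sites above: optionally rescale the attention weights and
# renormalize over the key dimension. Illustrative only; name is hypothetical.
def _update_attn_weights_sketch(attn_weights, attn_multiplier=None):
    if attn_multiplier is None:
        return attn_weights
    attn_weights = attn_weights * attn_multiplier
    # Renormalize over the source/key dimension (dim=1 for (T, S, N, H)).
    return attn_weights / attn_weights.sum(dim=1, keepdim=True).clamp_min(1e-12)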
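

# The rm_nonself_grads branch relies on the value-preserving rewiring idiom
# `a + (b - b.detach())`: the forward value equals `a`, while gradients flow
# through both `a` and `b`. Minimal standalone demonstration (illustrative,
# not used by the module; name is hypothetical):
def _grad_rewire_demo():
    a = torch.tensor([1.0, 2.0], requires_grad=True)
    b = torch.tensor([10.0, 20.0], requires_grad=True)
    out = a + (b - b.detach())  # forward value equals `a`
    out.sum().backward()
    return a.grad, b.grad       # both are tensors of ones: gradients reach a and b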