# models/spatial/attention.py
def forward(self, x, attn_mask=None, rm_nonself_grads=False, attn_multiplier=None):
    """
    Args:
        x: (T, N, D) input sequence (T tokens, batch size N, embedding dim D).
        attn_mask: (T, T) additive mask applied to the pre-softmax logits.
        rm_nonself_grads: if True, keep the same forward output but restrict key/value
            gradients to each token's own (self) attention position.
        attn_multiplier: optional reweighting factor passed to update_attn_weights.
    """
    T, N, _ = x.shape
    # Project to queries, keys, values and split the heads: each is (T, N, H, head_dim).
    q, k, v = map(
        lambda a: a.reshape(T, N, self.num_heads, self.head_dim),
        torch.split(self.in_proj(x), self.embed_dim, dim=-1),
    )
    # Scaled dot-product logits over target (t) and source (s) positions: (T, S, N, H).
    attn_logits = torch.einsum('tbhd,sbhd->tsbh', q, k) / math.sqrt(self.head_dim)
    if attn_mask is not None:
        # Broadcast the (T, T) mask over batch and heads.
        attn_mask = attn_mask[..., None, None]
        attn_logits += attn_mask
    attn_weights = F.softmax(attn_logits, dim=1)  # (T, S, N, H), softmax over sources
    attn_weights = update_attn_weights(attn_weights, attn_multiplier)
    # Weighted sum over source positions, then merge the heads: (T, N, D).
    attn = torch.einsum("tsbh,sbhd->tbhd", attn_weights, v).reshape(T, N, -1)
    if rm_nonself_grads:
        # Construct self-only gradient paths wrt keys and queries.
        attn_logits_keyonly = torch.einsum('tbhd,sbhd->tsbh', q.detach(), k) / math.sqrt(self.head_dim)
        attn_logits_queryonly = torch.einsum('tbhd,sbhd->tsbh', q, k.detach()) / math.sqrt(self.head_dim)
        # Key gradients are restricted to the diagonal (self) logits.
        attn_logits_keyonly = SelfonlyGradients.apply(attn_logits_keyonly)
        # Straight-through: forward value of the query-only logits, gradients from both paths.
        attn_logits = attn_logits_queryonly + (attn_logits_keyonly - attn_logits_keyonly.detach())
        if attn_mask is not None:
            attn_logits += attn_mask
        attn_weights = F.softmax(attn_logits, dim=1)
        attn_weights = update_attn_weights(attn_weights, attn_multiplier)
        # Zero out the nonself weights: keep only the diagonal of the (T, T) attention map.
        selfonly_mask = torch.eye(T, dtype=torch.bool, device=attn_weights.device)
        selfonly_attn_weights = attn_weights * selfonly_mask[..., None, None]
        # Self-only gradient path wrt values.
        attn_vpath = torch.einsum("tsbh,sbhd->tbhd", selfonly_attn_weights.detach(), v).reshape(T, N, -1)
        attn_spath = torch.einsum("tsbh,sbhd->tbhd", attn_weights, v.detach()).reshape(T, N, -1)
        modified_attn = attn_spath + (attn_vpath - attn_vpath.detach())
        # Forward value from the standard path; gradients from the modified paths.
        attn = attn.detach() + (modified_attn - modified_attn.detach())
    attn = self.out_proj(attn)
    # The returned attention weights are detached, so no gradient flows through them.
    return attn, attn_weights.detach()
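
# Note: `SelfonlyGradients` is defined elsewhere in the repo and is not shown
# above. Judging from how it is used (identity on the forward value, self-only
# gradients wrt keys), a minimal sketch might look like the following; the
# actual implementation may differ.
import torch

class SelfonlyGradients(torch.autograd.Function):
    @staticmethod
    def forward(ctx, attn_logits):
        # (T, S, N, H) logits pass through unchanged.
        return attn_logits

    @staticmethod
    def backward(ctx, grad_output):
        # Keep only gradient entries where the target attends to itself (t == s).
        T, S = grad_output.shape[:2]
        eye = torch.eye(T, S, dtype=torch.bool, device=grad_output.device)
        return grad_output * eye[..., None, None]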
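
# Note: `update_attn_weights` is also defined elsewhere. A plausible minimal
# stand-in, assuming it simply rescales the post-softmax weights when a
# multiplier is given (whether the real helper also renormalizes afterwards is
# not visible from this excerpt):
def update_attn_weights(attn_weights, attn_multiplier=None):
    if attn_multiplier is not None:
        attn_weights = attn_weights * attn_multiplier
    return attn_weights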
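
# Note: the rm_nonself_grads branch relies on the straight-through pattern
# `a.detach() + (b - b.detach())`, whose value equals `a` while gradients flow
# only through `b`. A small standalone check of that identity:
import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([10.0, 20.0], requires_grad=True)
out = a.detach() + (b - b.detach())
print(out)        # tensor([1., 2.], grad_fn=...): the value of `a`
out.sum().backward()
print(a.grad)     # None: `a` only enters through .detach()
print(b.grad)     # tensor([1., 1.]): gradients flow through `b`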
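
# Note: since `attn_mask` is added to the pre-softmax logits, a causal mask for
# this interface is the usual (T, T) matrix with -inf above the diagonal. The
# commented call is illustrative only (`module` is a hypothetical instance of
# this attention class):
import torch

T = 4
causal_mask = torch.full((T, T), float('-inf')).triu(diagonal=1)  # 0 on/below the diagonal
# attn_out, attn_w = module(x, attn_mask=causal_mask)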