def forward()

in models/spatial/attention.py


    def forward(self, x, attn_mask=None, rm_nonself_grads=False, attn_multiplier=None):
        """
        Args:
            x: (T, N, D)
            attn_mask: (T, T) added to pre-softmax logits.
        """

        T, N, _ = x.shape

        # Tied query/key projection: queries and keys share self.q_weight, giving (T, N, H, head_dim).
        q = k = torch.einsum("tbm,mhd->tbhd", x, self.q_weight)
        # Pairwise squared distances ||q_t - k_s||^2 = ||q_t||^2 + ||k_s||^2 - 2 q_t . k_s, shape (T, S, N, H).
        squared_dist = (torch.einsum('tbhd,tbhd->tbh', q, q).unsqueeze(1)
                        + torch.einsum('sbhd,sbhd->sbh', k, k).unsqueeze(0)
                        - 2 * torch.einsum('tbhd,sbhd->tsbh', q, k))
        attn_logits = -squared_dist / math.sqrt(self.head_dim)
        if attn_mask is not None:
            attn_mask = attn_mask[..., None, None]  # (T, T) -> (T, T, 1, 1) to broadcast over batch and heads
            attn_logits += attn_mask
        attn_weights = F.softmax(attn_logits, dim=1)  # (T, S, N, H)
        attn_weights = update_attn_weights(attn_weights, attn_multiplier)
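        # Values reuse the tied query/key projection: per head, A = W_q @ W_q.T / sqrt(head_dim),
        # XA[s] = x[s] @ A, and PXA[t] = sum_s attn_weights[t, s] * XA[s].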
        A = torch.einsum("mhd,nhd->hmn", self.q_weight, self.q_weight) / math.sqrt(self.head_dim)
        XA = torch.einsum("tbm,hmn->tbhn", x, A)
        PXA = torch.einsum("tsbh,sbhm->tbhm", attn_weights, XA)

        if rm_nonself_grads:
            # Construct self-only gradient paths wrt keys and queries.
            q_detach = q.detach()
            k_detach = k.detach()
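            # Recompute the logits twice: with queries detached so gradients reach only
            # the keys, and with keys detached so gradients reach only the queries.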
            attn_logits_keyonly = -(torch.einsum('tbhd,tbhd->tbh', q_detach, q_detach).unsqueeze(1)
                                    + torch.einsum('sbhd,sbhd->sbh', k, k).unsqueeze(0)
                                    - 2 * torch.einsum('tbhd,sbhd->tsbh', q_detach, k)) / math.sqrt(self.head_dim)
            attn_logits_queryonly = -(torch.einsum('tbhd,tbhd->tbh', q, q).unsqueeze(1)
                                      + torch.einsum('sbhd,sbhd->sbh', k_detach, k_detach).unsqueeze(0)
                                      - 2 * torch.einsum('tbhd,sbhd->tsbh', q, k_detach)) / math.sqrt(self.head_dim)

            # SelfonlyGradients presumably restricts the key-path backward pass to the
            # diagonal (self) logits. The detach trick below leaves the forward value of
            # attn_logits unchanged while routing key gradients through the self-only path.
            attn_logits_keyonly = SelfonlyGradients.apply(attn_logits_keyonly)
            attn_logits = attn_logits_queryonly + (attn_logits_keyonly - attn_logits_keyonly.detach())
            if attn_mask is not None:
                attn_logits += attn_mask
            attn_weights = F.softmax(attn_logits, dim=1)
            attn_weights = update_attn_weights(attn_weights, attn_multiplier)

            # Zero out the nonself weights.
            # Identity mask: keep only the diagonal (self) attention weights.
            selfonly_mask = torch.eye(T, dtype=torch.bool, device=attn_weights.device)
            selfonly_attn_weights = attn_weights * selfonly_mask[..., None, None]
            # Self-only gradient path wrt values.
            PXA_vpath = torch.einsum("tsbh,sbhm->tbhm", selfonly_attn_weights.detach(), XA)
            # Gradient path wrt the attention weights (softmax path), with values detached.
            PXA_spath = torch.einsum("tsbh,sbhm->tbhm", attn_weights, XA.detach())

            # Combine so the forward value of PXA is unchanged while gradients flow only
            # through the softmax path and the self-only value path built above.
            modified_PXA = PXA_spath + (PXA_vpath - PXA_vpath.detach())
            PXA = PXA.detach() + (modified_PXA - modified_PXA.detach())

        # Apply the per-head value projection and merge heads back to (T, N, embed_dim).
        PXAV = torch.einsum("tbhm,mhd->tbhd", PXA, self.v_weight).reshape(T, N, self.embed_dim)
        return self.out_proj(PXAV), attn_weights.detach()
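
For reference, below is a minimal standalone sketch (not part of the module) of the distance-based logits used above. It assumes only torch and math, uses arbitrary tensor sizes, and checks that the three-einsum expansion matches the negative scaled pairwise squared distance computed directly.

    import math
    import torch

    T, N, H, d = 5, 2, 3, 4
    q = k = torch.randn(T, N, H, d)   # tied queries/keys, as in the forward pass above

    squared_dist = (torch.einsum('tbhd,tbhd->tbh', q, q).unsqueeze(1)
                    + torch.einsum('sbhd,sbhd->sbh', k, k).unsqueeze(0)
                    - 2 * torch.einsum('tbhd,sbhd->tsbh', q, k))
    logits = -squared_dist / math.sqrt(d)

    # Direct computation of -||q_t - k_s||^2 / sqrt(d) for comparison.
    reference = -(q[:, None] - k[None, :]).pow(2).sum(-1) / math.sqrt(d)
    assert torch.allclose(logits, reference, atol=1e-5)

Because queries and keys come from the same projection, q_t equals k_t, so the diagonal distance is zero and each position's largest unmasked logit is always its own.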