def forward()

in models/spatial/attention.py [0:0]


    def forward(self, x, attn_mask=None, rm_nonself_grads=False, attn_multiplier=None):
        """
        Args:
            x: (T, N, D)
            attn_mask: (T, T) added to pre-softmax logits.
        """

        T, N, _ = x.shape

        # Project to queries, keys, values and split into heads: (T, N, H, head_dim).
        q, k, v = map(
            lambda a: a.reshape(T, N, self.num_heads, self.head_dim),
            torch.split(self.in_proj(x), self.embed_dim, dim=-1),
        )
        # Scaled dot-product logits for every (query t, key s) pair: (T, S, N, H).
        attn_logits = torch.einsum('tbhd,sbhd->tsbh', q, k) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            attn_mask = attn_mask[..., None, None]
            attn_logits += attn_mask
        attn_weights = F.softmax(attn_logits, dim=1)  # (T, S, N, H)
        attn_weights = update_attn_weights(attn_weights, attn_multiplier)

        attn = torch.einsum("tsbh,sbhd->tbhd", attn_weights, v).reshape(T, N, -1)

        if rm_nonself_grads:
            # Construct self-only gradient paths w.r.t. keys and queries. Both
            # logit tensors have identical forward values; they differ only in
            # which of q and k receive gradients.
            attn_logits_keyonly = torch.einsum('tbhd,sbhd->tsbh', q.detach(), k) / math.sqrt(self.head_dim)
            attn_logits_queryonly = torch.einsum('tbhd,sbhd->tsbh', q, k.detach()) / math.sqrt(self.head_dim)

            # Restrict the key-side gradients to the self (t == s) entries.
            attn_logits_keyonly = SelfonlyGradients.apply(attn_logits_keyonly)
            # Straight-through combination: the second term is zero in the forward
            # pass, so the logits keep their value while gradients reach q through
            # the query-only path and k through the self-only key path.
            attn_logits = attn_logits_queryonly + (attn_logits_keyonly - attn_logits_keyonly.detach())
            if attn_mask is not None:
                attn_logits += attn_mask
            attn_weights = F.softmax(attn_logits, dim=1)
            attn_weights = update_attn_weights(attn_weights, attn_multiplier)

            # Zero out the non-self weights: selfonly_mask is True only on the diagonal (t == s).
            selfonly_mask = ~(torch.triu(torch.ones(T, T), diagonal=1) + torch.tril(torch.ones(T, T), diagonal=-1)).bool()
            selfonly_attn_weights = attn_weights * selfonly_mask[..., None, None].to(attn_weights.device)
            # Self-only gradient path w.r.t. values: v receives gradients only
            # through the (detached) diagonal attention weights.
            attn_vpath = torch.einsum("tsbh,sbhd->tbhd", selfonly_attn_weights.detach(), v).reshape(T, N, -1)
            # Gradient path for the attention weights themselves, with v detached.
            attn_spath = torch.einsum("tsbh,sbhd->tbhd", attn_weights, v.detach()).reshape(T, N, -1)

            # Forward value equals attn_spath (i.e. the usual attention output);
            # gradients flow through both the weight path and the self-only value path.
            modified_attn = attn_spath + (attn_vpath - attn_vpath.detach())

            # Keep the original forward output, but source all gradients from the
            # modified, self-only paths.
            attn = attn.detach() + (modified_attn - modified_attn.detach())

        attn = self.out_proj(attn)
        return attn, attn_weights.detach()
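
SelfonlyGradients and update_attn_weights are referenced above but defined elsewhere in the file. The sketch below is an assumption based on how SelfonlyGradients is used here, not the source implementation: an identity in the forward pass whose backward keeps only the diagonal (t == s) gradient entries, so each key receives gradients from its own timestep only.

    import torch

    class SelfonlyGradients(torch.autograd.Function):
        """Illustrative sketch only: identity forward, self-only (diagonal) backward."""

        @staticmethod
        def forward(ctx, attn_logits):
            # Pass the (T, S, N, H) logits through unchanged.
            return attn_logits

        @staticmethod
        def backward(ctx, grad_output):
            # Zero every off-diagonal (t != s) entry of the incoming gradient so
            # keys only get gradients from their own attention logit.
            T = grad_output.shape[0]
            eye = torch.eye(T, dtype=torch.bool, device=grad_output.device)
            return grad_output * eye[..., None, None]

Because the forward pass of such a function is the identity, using it inside the straight-through expression above leaves the computed attention values untouched and only reroutes the backward pass.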