models/spatial/attention.py [52:61]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            # NOTE(review): this exact sequence also appears at models/spatial/attention.py
            # lines 128-137 (per this duplicate-code report); consider extracting a shared helper.
            # NOTE(review): SelfonlyGradients is a custom autograd Function defined elsewhere;
            # presumably it restricts the backward pass to self (diagonal) terms -- confirm
            # against its definition.
            attn_logits_keyonly = SelfonlyGradients.apply(attn_logits_keyonly)
            # `x - x.detach()` is exactly zero in the forward pass, so attn_logits equals
            # attn_logits_queryonly numerically; in the backward pass, gradients still flow
            # into attn_logits_keyonly through the non-detached term.
            attn_logits = attn_logits_queryonly + (attn_logits_keyonly - attn_logits_keyonly.detach())
            if attn_mask is not None:
                # Additive mask applied in place; presumably large negative entries block
                # attention -- verify against callers.
                attn_logits += attn_mask
            # Normalize over dim=1. Given the (T, T, ...) layout implied by the mask
            # broadcast below, this is presumably the key axis -- TODO confirm.
            attn_weights = F.softmax(attn_logits, dim=1)
            # update_attn_weights is defined elsewhere; its effect is not visible here.
            attn_weights = update_attn_weights(attn_weights, attn_multiplier)

            # Zero out the nonself weights.
            # Strictly-upper plus strictly-lower triangles of ones, cast to bool and inverted,
            # leaves True only on the main diagonal (i == j). This equals
            # torch.eye(T, dtype=torch.bool); it is built on CPU and moved to the device below.
            selfonly_mask = ~(torch.triu(torch.ones(T, T), diagonal=1) + torch.tril(torch.ones(T, T), diagonal=-1)).bool()
            # [..., None, None] appends two singleton dims so the (T, T) mask broadcasts over
            # attn_weights; assumes attn_weights is laid out as (T, T, *, *) -- TODO confirm.
            selfonly_attn_weights = attn_weights * selfonly_mask[..., None, None].to(attn_weights.device)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



models/spatial/attention.py [128:137]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            # NOTE(review): this exact sequence also appears at models/spatial/attention.py
            # lines 52-61 (per this duplicate-code report); consider extracting a shared helper.
            # NOTE(review): SelfonlyGradients is a custom autograd Function defined elsewhere;
            # presumably it restricts the backward pass to self (diagonal) terms -- confirm
            # against its definition.
            attn_logits_keyonly = SelfonlyGradients.apply(attn_logits_keyonly)
            # `x - x.detach()` is exactly zero in the forward pass, so attn_logits equals
            # attn_logits_queryonly numerically; in the backward pass, gradients still flow
            # into attn_logits_keyonly through the non-detached term.
            attn_logits = attn_logits_queryonly + (attn_logits_keyonly - attn_logits_keyonly.detach())
            if attn_mask is not None:
                # Additive mask applied in place; presumably large negative entries block
                # attention -- verify against callers.
                attn_logits += attn_mask
            # Normalize over dim=1. Given the (T, T, ...) layout implied by the mask
            # broadcast below, this is presumably the key axis -- TODO confirm.
            attn_weights = F.softmax(attn_logits, dim=1)
            # update_attn_weights is defined elsewhere; its effect is not visible here.
            attn_weights = update_attn_weights(attn_weights, attn_multiplier)

            # Zero out the nonself weights.
            # Strictly-upper plus strictly-lower triangles of ones, cast to bool and inverted,
            # leaves True only on the main diagonal (i == j). This equals
            # torch.eye(T, dtype=torch.bool); it is built on CPU and moved to the device below.
            selfonly_mask = ~(torch.triu(torch.ones(T, T), diagonal=1) + torch.tril(torch.ones(T, T), diagonal=-1)).bool()
            # [..., None, None] appends two singleton dims so the (T, T) mask broadcasts over
            # attn_weights; assumes attn_weights is laid out as (T, T, *, *) -- TODO confirm.
            selfonly_attn_weights = attn_weights * selfonly_mask[..., None, None].to(attn_weights.device)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



