in xformers/components/attention/ortho.py [0:0]
def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor:
    """
    Greedily pick a set of near-orthogonal landmarks from the queries.

    Starting from one randomly chosen sample per batch element, each step
    adds the sample whose largest absolute cosine similarity to the
    landmarks chosen so far is smallest.

    Args:
        q: queries of shape (B, N, D).

    Returns:
        Landmarks of shape (B, M, D) with M = ``self.num_landmarks``. The rows
        are the original (un-normalized) rows of ``q``, ordered by their
        position in the (sub)sampled sequence rather than by selection order.
        When ``self.num_landmarks`` exceeds the number of available samples,
        every sample is returned.

    Notes:
        * When ``self.subsample_fraction < 1.0`` the candidates are drawn
          with replacement from ``q``; at least ``self.num_landmarks``
          candidates are always drawn.
        * Only the final gather from the queries is differentiable; the
          index selection itself runs without autograd tracking.
    """
    if self.subsample_fraction < 1.0:
        # Need at least M samples of queries
        num_samples = max(
            int(self.subsample_fraction * q.size(-2)), self.num_landmarks
        )
        sample_idx = torch.randint(q.size(-2), (num_samples,), device=q.device)
        q_samples = q[:, sample_idx, :]
    else:
        # (B, N, D)
        q_samples = q

    # Choosing indices is a discrete operation: gradients only flow through
    # the final masked_select, so skip building an autograd graph here.
    with torch.no_grad():
        # may need to change default eps to eps=1e-8 for mixed precision compatibility
        q_normalized = Fn.normalize(q_samples, p=2, dim=-1)
        B, N, D = q_normalized.shape
        device = q_normalized.device

        # Loop-invariant batch index, created once on the right device.
        batch_idx = torch.arange(B, device=device)

        # (B, N) True where a sample has already been picked as a landmark
        selected_mask = torch.zeros((B, N), dtype=torch.bool, device=device)

        # Get initial random landmark, (B,)
        newest_idx = torch.randint(N, (B,), device=device)
        selected_mask[batch_idx, newest_idx] = True

        # (B, N) running maximum of |cos sim| between every sample and the
        # landmarks picked so far. Absolute similarities are >= 0, so zero is
        # a neutral starting value.
        max_cos_sim = torch.zeros((B, N), dtype=q_normalized.dtype, device=device)

        for _ in range(1, self.num_landmarks):
            with profiler.record_function("find new landmark"):
                # (B, D) most recently selected landmark
                newest_landmark = q_normalized[batch_idx, newest_idx, :]

                # Absolute cosine similarity of every sample to that landmark
                # (B, N, D) * (B, D) -> (B, N)
                cos_sim = torch.einsum(
                    "b n d, b d -> b n", q_normalized, newest_landmark
                ).abs()

                # Folding only the newest column into the running maximum is
                # equivalent to re-reducing the full similarity history.
                max_cos_sim = torch.maximum(max_cos_sim, cos_sim)

                # Already selected samples get a score > 1 so that they can
                # never win the argmin below.
                max_cos_sim.masked_fill_(selected_mask, 10)

                # (B,) most orthogonal remaining sample
                newest_idx = max_cos_sim.argmin(-1)
                selected_mask[batch_idx, newest_idx] = True

    # (B, M, D) - gather the un-normalized samples at the selected positions
    landmarks = torch.masked_select(
        q_samples, selected_mask.unsqueeze(-1)
    ).reshape(B, -1, D)
    return landmarks  # (B, M, D)