in xformers/components/attention/feature_maps/softmax.py [0:0]
def pre_scale(self, x: torch.Tensor) -> torch.Tensor:
    with record_function("feature_map::pre_scale"):
        # Re-draw counting logic
        if (
            (
                self.iter_before_redraw is not None
                and self._iter_counter > self.iter_before_redraw
            )
            or self.features is None
            or self.features.device != x.device
        ):
            # The feature map actually uses half the dimension: the + and - features are concatenated
            self._iter_counter = 1
            self.features = self._get_feature_map(
                x.shape[-1], self.dim_feature_map, x.device
            )

        features = self.features
        assert features is not None

        if features.dtype != x.dtype:
            self.features = features.to(x.dtype)

        self._iter_counter += 1

        # Normalization / softmax
        if self.softmax_temp < 0:
            # A = exp(QKᵀ / √d), so each of Q and K is scaled by d ** -0.25 (i.e. 1/√√d)
            self.softmax_temp = x.shape[-1] ** -0.25

        x_scaled = x * self.softmax_temp

        # Compute the scaling factors in logspace, applied from within the exponential:
        # - diminishes the risk of exponential overflow
        # - replaces a multiply across the batch with an addition
        norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1)
        self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon
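        # Note (added for context): for Gaussian-distributed features w, the identity
        # exp(q·k) = E_w[exp(wᵀq - ||q||²/2) · exp(wᵀk - ||k||²/2)] underlies this estimator,
        # so the -0.5 * norm_x_2 term is that Gaussian correction kept in logspace and added
        # inside the exponential by the downstream kernel, rather than multiplied in afterwards.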

        if self.normalize_inputs:
            # Normalize the exponential term, which can be useful for numerical stability:
            # this ensures that features ± offset stay below 1
            self.offset -= norm_x_2.max(1, keepdim=True)[0]

        # Return the scaled inputs; the rest depends on the kernel being used
        return x_scaled
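
For context, a minimal sketch of how a downstream kernel might consume the two quantities prepared here: the scaled inputs returned by pre_scale, and the self.offset stored as a side effect. The helper name softmax_feature_sketch and its standalone signature are illustrative assumptions, not the library's API; in xformers the equivalent logic lives in the concrete feature-map subclasses of this same file.

import torch

def softmax_feature_sketch(
    x_scaled: torch.Tensor, features: torch.Tensor, offset: torch.Tensor
) -> torch.Tensor:
    # Hypothetical helper, illustration only: project the pre-scaled inputs onto the
    # random features and apply the exponential, with the logspace offset added inside
    # the exponent. `features` stands in for the matrix drawn by _get_feature_map and
    # `offset` for the value computed in pre_scale above.
    projection = x_scaled @ features.transpose(-2, -1)
    return torch.exp(projection + offset)

Because the offset already carries the -0.5 * ||x_scaled||² term, the exponential never needs a separate multiplicative rescaling per element, which is exactly the point of the logspace comment in pre_scale.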