def pre_scale()

in xformers/components/attention/feature_maps/softmax.py


    def pre_scale(self, x: torch.Tensor) -> torch.Tensor:
        with record_function("feature_map::pre_scale"):
            # Re-draw counting logic
            if (
                (
                    self.iter_before_redraw is not None
                    and self._iter_counter > self.iter_before_redraw
                )
                or self.features is None
                or self.features.device != x.device
            ):
                # The feature map actually uses half the dimension; we'll concatenate the + and - features
                self._iter_counter = 1
                self.features = self._get_feature_map(
                    x.shape[-1], self.dim_feature_map, x.device
                )

            features = self.features
            assert features is not None

            if features.dtype != x.dtype:
                self.features = features.to(x.dtype)

            self._iter_counter += 1

            # Normalization / softmax
            if self.softmax_temp < 0:
                # A = exp(QK.t / √d), so each input is scaled down by √√d
                self.softmax_temp = x.shape[-1] ** -0.25
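                # Both Q and K go through this pre_scale, so the product
                # (Q * d**-0.25) @ (K * d**-0.25).t = (Q @ K.t) / sqrt(d),
                # i.e. the usual 1/sqrt(d) softmax attention temperature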

            x_scaled = x * self.softmax_temp

            # Compute the scaling factors in logspace, applied from within the exponential
            # - diminish possible exponential overflow
            # - remove a multiply across the batch, replace by an addition
            norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1)
            self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon

            if self.normalize_inputs:
                # L0 normalize the exponential term; this can be useful for numerical stability
                # It ensures that (features ± offset) stays below 1
                self.offset -= norm_x_2.max(1, keepdim=True)[0]

        # Return the scaled inputs; the rest depends on the kernel being used
        return x_scaled
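
To see why the offset is computed in log space, here is a minimal, standalone sketch rather than the actual xformers implementation: omega, h_scale and epsilon are assumed stand-ins for the class attributes used by pre_scale, with illustrative values only. It shows that folding -0.5 * ||x_scaled||^2 into the exponent keeps the softmax-kernel features finite where naive exponentiation overflows.

    # Standalone illustration with assumed values, not the library's internals
    import math

    import torch

    torch.manual_seed(0)

    d = 64                                  # head dimension
    x = 20.0 * torch.randn(2, 16, d)        # deliberately large-norm inputs
    omega = torch.randn(128, d)             # stand-in random feature projection

    # Same temperature as above: scaling Q and K each by d**-0.25 divides
    # their dot product by sqrt(d) overall
    softmax_temp = d ** -0.25
    x_scaled = x * softmax_temp

    # Log-space offset mirroring pre_scale: -0.5 * ||x_scaled||^2 plus constants
    h_scale = 0.5 * math.log(d)             # assumed constant
    epsilon = 1e-6                          # assumed constant
    norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1)
    offset = -0.5 * norm_x_2 - h_scale + epsilon

    naive = torch.exp(x_scaled @ omega.T)            # exp first, scale afterwards
    stable = torch.exp(x_scaled @ omega.T + offset)  # offset applied inside exp

    print(torch.isinf(naive).any())   # tensor(True): float32 exp overflows
    print(torch.isinf(stable).any())  # tensor(False): the exponent stays bounded

Multiplying the naive result by exp(offset) afterwards would turn the overflowed entries into inf or nan, whereas adding the offset before exponentiating avoids the overflow and also replaces a per-element multiply with an addition, as noted in the comments above.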