def forward()

in xformers/triton/dropout.py [0:0]


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Convenience, catch a possible type or device mismatch
        if self.bias is not None:  # type: ignore
            self.bias = self.bias.to(dtype=x.dtype, device=x.device)  # type: ignore

        # Train/inference
        p = self.p if self.training else 0.0

        # The Triton kernel is slower than PyTorch for small buffers; bypass it in that case
        perf_check = x.shape[-1] > 512

        # Non-CUDA tensor or small buffer: fall back to the pure PyTorch implementation
        if not x.is_cuda or not perf_check:
            x = x + self.bias if self.bias is not None else x
            x = self.pytorch_activation(x)
            return torch.nn.functional.dropout(x, p)

        # The normal, Triton-backed path
        return _dropout.apply(
            x, p, self.bias, self.activation, self.activation_grad, True
        )
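
For reference, the PyTorch fallback branch (taken for non-CUDA tensors or when the last dimension is 512 or smaller) is equivalent to the standalone sketch below. The helper name and its explicit `activation` argument are hypothetical, introduced only for illustration of the bias-add, activation, dropout sequence:

    from typing import Callable, Optional

    import torch
    import torch.nn.functional as F


    def dropout_bias_reference(
        x: torch.Tensor,
        bias: Optional[torch.Tensor],
        activation: Callable[[torch.Tensor], torch.Tensor],
        p: float,
        training: bool,
    ) -> torch.Tensor:
        # Match the bias dtype/device to the input, then add it
        if bias is not None:
            bias = bias.to(dtype=x.dtype, device=x.device)
            x = x + bias

        # Activation, then dropout (disabled at inference time via p=0)
        x = activation(x)
        return F.dropout(x, p if training else 0.0)


    # Example: GELU activation with 10% dropout during training
    y = dropout_bias_reference(
        torch.randn(4, 1024), bias=None, activation=F.gelu, p=0.1, training=True
    )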