in xformers/triton/dropout.py [0:0]
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the optional bias, the activation and dropout to ``x``.

    Non-CUDA inputs, small buffers (last dimension <= 512) and the
    no-dropout case (``p == 0.0``, e.g. inference) are handled with plain
    PyTorch ops. Everything else goes through the Triton-backed ``_dropout``
    autograd function.

    Args:
        x: input tensor. On the PyTorch path the bias, if any, is added with
            standard broadcasting (presumably it matches the last dimension
            of ``x`` -- TODO confirm against the constructor).

    Returns:
        ``dropout(activation(x + bias))``. The drop probability is
        ``self.p`` in training mode and ``0.0`` otherwise.
    """
    # Convenience, catch a possible type or device mismatch.
    # NOTE(review): this rebinds ``self.bias`` to the cast copy, so the
    # stored bias follows the dtype/device of the latest input. If ``bias``
    # is registered as an ``nn.Parameter`` elsewhere in the class, assigning
    # a converted (plain) tensor here may raise -- confirm.
    if self.bias is not None:  # type: ignore
        self.bias = self.bias.to(dtype=x.dtype, device=x.device)  # type: ignore

    # Train/inference: dropout is disabled outside of training.
    p = self.p if self.training else 0.0

    # This kernel is slower than pytorch for small buffers, bypassing it in that case
    perf_check = x.shape[-1] > 512

    # Fall back to PyTorch on non-CUDA inputs and small buffers. Also fall
    # back when there is nothing to drop (p == 0.0): the dropout kernel
    # cannot do any useful masking then.
    if not x.is_cuda or not perf_check or p == 0.0:
        x = x + self.bias if self.bias is not None else x
        x = self.pytorch_activation(x)
        # With p == 0.0 dropout is the identity, so skip the call entirely.
        return torch.nn.functional.dropout(x, p) if p > 0.0 else x

    # The normal, Triton-backed path
    return _dropout.apply(
        x, p, self.bias, self.activation, self.activation_grad, True
    )