modules/SwissArmyTransformer/sat/ops/memory_efficient_attention.py:
from xformers.ops import memory_efficient_attention as mea
from xformers.ops import LowerTriangularMask


def memory_efficient_attention(q, k, v, attention_dropout=0., mask=None, scale=None):
    """Wrap xformers' memory-efficient attention, translating a dense mask
    into an attention bias that xformers understands."""
    if mask is None:
        attn_bias = None
    else:
        # Treat positive entries as "attend", everything else as masked out.
        t = (mask > 0)
        if t.all():
            # Fully visible mask: no bias needed.
            attn_bias = None
        elif not t.triu(diagonal=1).any() and t.tril().all():
            # Nothing above the diagonal, everything on/below it: a causal mask.
            attn_bias = LowerTriangularMask()
        else:
            raise ValueError(f"Unsupported mask pattern {mask}")
    return mea(q, k, v, attn_bias=attn_bias, p=attention_dropout, scale=scale)
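

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a CUDA device, half-precision tensors, and the
# [batch, seq_len, num_heads, head_dim] layout that xformers expects;
# the shapes below are arbitrary example values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
    q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # A dense causal mask: 1 where attention is allowed, 0 elsewhere.
    # The wrapper recognises this pattern and maps it to LowerTriangularMask().
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, device="cuda"))

    out = memory_efficient_attention(q, k, v, attention_dropout=0., mask=causal_mask)
    print(out.shape)  # expected: (batch, seq_len, num_heads, head_dim)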