modules/SwissArmyTransformer/sat/transformer_defaults.py

import contextlib

import torch

from sat import mpu  # model-parallel utilities (provides the CUDA RNG tracker)


def attention_fn_default(query_layer, key_layer, value_layer, attention_mask,
                         attention_dropout=None, log_attention_weights=None,
                         scaling_attention_score=True, **kwargs):
    # Expand the key/value head dimension to match the number of query heads.
    # Only needed for multi-query / grouped-query attention, where there are
    # fewer KV heads than query heads; with equal head counts it changes nothing.
    batch_size, num_query_heads = query_layer.shape[:2]  # query_layer: [b, np, s, hn]
    num_kv_heads = key_layer.shape[1]                     # key_layer:   [b, nkv, s, hn]
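    # For example (illustrative numbers only): with 32 query heads and 8 KV heads,
    # each KV head is shared by a group of 4 query heads, so the KV tensors go
    # [b, 8, s, hn] -> [b, 8, 1, s, hn] -> [b, 8, 4, s, hn] -> [b, 32, s, hn].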
    key_layer = key_layer.unsqueeze(2).expand(-1, -1, num_query_heads // num_kv_heads, -1, -1).contiguous().view(batch_size, num_query_heads, *key_layer.shape[2:])
    value_layer = value_layer.unsqueeze(2).expand(-1, -1, num_query_heads // num_kv_heads, -1, -1).contiguous().view(batch_size, num_query_heads, *value_layer.shape[2:])
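    # Classify the mask: the fused kernel below is only used when the mask is
    # either all-ones (no masking) or a standard lower-triangular causal mask,
    # in which case masking can be expressed through the is_causal flag.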
    is_low_triangle = (attention_mask == torch.ones_like(attention_mask, dtype=torch.float).tril()).all()
    is_full = (attention_mask is None) or (attention_mask > 0).all()
    if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
        # Note: PyTorch 2.0's fused attention uses a lot of memory when attention_mask
        # is a float tensor, and has a NaN bug when attention_mask is None, so the mask
        # is not passed through; causal masking is requested via is_causal below.
        dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
        if dropout_p > 0 and mpu.get_cuda_rng_tracker is not None:
            # Fork the model-parallel CUDA RNG tracker so that attention dropout
            # stays consistent with the rest of the tensor-parallel code.
            context = mpu.get_cuda_rng_tracker().fork()
        else:
            context = contextlib.nullcontext()
        with context:
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_layer, key_layer, value_layer,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=not is_full
            )
        return attn_output
    else:
        # Fall back to the non-fused implementation, which supports arbitrary masks
        # and log_attention_weights (standard_attention is defined in this same module).
        return standard_attention(
            query_layer, key_layer, value_layer, attention_mask,
            attention_dropout=attention_dropout, log_attention_weights=log_attention_weights,
            scaling_attention_score=scaling_attention_score, **kwargs
        )
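
A minimal usage sketch, not part of the repository: with the function above in scope, the illustrative sizes below (2 KV heads shared by 8 query heads, plus a lower-triangular causal mask) should route the call onto the fused scaled_dot_product_attention path on PyTorch >= 2.0. All tensor names and sizes here are assumptions for demonstration only.

import torch

b, nq, nkv, s, hn = 2, 8, 2, 16, 64           # illustrative sizes (assumed)
query = torch.randn(b, nq, s, hn)             # [b, np, s, hn]
key = torch.randn(b, nkv, s, hn)              # grouped-query: fewer KV heads
value = torch.randn(b, nkv, s, hn)
causal_mask = torch.ones(1, 1, s, s).tril()   # lower-triangular -> fused causal path

out = attention_fn_default(query, key, value, causal_mask)
print(out.shape)                              # torch.Size([2, 8, 16, 64])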