in modules/SwissArmyTransformer/sat/transformer_defaults.py [0:0]
import math

import torch
import torch.nn.functional as F

from sat import mpu  # model-parallel utilities; provides the CUDA RNG tracker used below


def standard_attention(query_layer, key_layer, value_layer, attention_mask,
                       attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs):
    # We disable PB-relax attention and only change the order of computation, because that is enough for most training.
    # The implementation in the paper can be added easily if you really need it to train very deep transformers.
    if scaling_attention_score:
        # Scale queries by 1/sqrt(head_dim) before the dot product instead of scaling the scores afterwards.
        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    if log_attention_weights is not None:
        attention_scores += log_attention_weights

    if not (attention_mask.shape[-2] == 1 and (attention_mask > 0).all()):
        # Skip masking only when the mask is a single all-positive row (i.e. full attention);
        # otherwise zero out masked scores and push them to a large negative value before softmax.
        attention_scores = torch.mul(attention_scores, attention_mask) - \
            10000.0 * (1.0 - attention_mask)

    attention_probs = F.softmax(attention_scores, dim=-1)

    if attention_dropout is not None:
        if mpu.get_cuda_rng_tracker is not None:
            # Fork the model-parallel RNG tracker so dropout masks stay consistent across tensor-parallel ranks.
            with mpu.get_cuda_rng_tracker().fork():
                attention_probs = attention_dropout(attention_probs)
        else:
            attention_probs = attention_dropout(attention_probs)

    context_layer = torch.matmul(attention_probs, value_layer)
    return context_layer
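
A minimal usage sketch (not part of the repository), assuming query/key/value are laid out as [batch, num_heads, seq_len, head_dim] and that attention_mask is a broadcastable 0/1 tensor; the shapes and the causal mask below are hypothetical example inputs, and attention_dropout is left at None so the mpu RNG tracker is not exercised:

batch, heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)
# Hypothetical lower-triangular 0/1 causal mask, broadcast over batch and heads.
causal_mask = torch.tril(torch.ones(1, 1, seq_len, seq_len))
context = standard_attention(q, k, v, causal_mask)
print(context.shape)  # torch.Size([2, 8, 16, 64])

Dividing the query by sqrt(head_dim) before the matmul, rather than scaling the scores afterwards, is the "order of computation" change the comment refers to: it keeps the intermediate products smaller without the full PB-relax trick from the paper.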