in modules/SwissArmyTransformer/sat/model/official/gptneo_model.py [0:0]
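# Imports used below (assumed from the surrounding file, which is not shown in this
# excerpt; the `mpu` import path in particular is an assumption based on the
# SwissArmyTransformer layout).
import math

import torch
import torch.nn.functional as F

from sat import mpu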
def attention_fn(self, query_layer, key_layer, value_layer, attention_mask,
                 attention_dropout=None, log_attention_weights=None, scaling_attention_score=False, **kwargs):
    attention_type = self.attention_types[kwargs['layer_id']]
    if attention_type not in ["global", "local"]:
        raise NotImplementedError(
            "Only attn layer types 'global' and 'local' exist, but got `attention_type`: "
            f"{attention_type}. Select attn layer types from ['global', 'local'] only."
        )
    # We disable PB-relax attention and only change the order of computation, which is sufficient for most training.
    # The implementation from the paper can be added easily if you really need it to train very deep transformers.
    if scaling_attention_score:
        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    '''
    2022/08/02
    The difference from SAT-Base is the causal_mask below.
    '''
    query_length, key_length = query_layer.size(-2), key_layer.size(-2)
    if attention_type == 'global':
        bias = self.bias_global
    else:
        bias = self.bias_local
    # Slice the precomputed causal-mask buffer so that rows cover the current queries and columns cover
    # all keys; with cached keys (incremental decoding), key_length can exceed query_length.
    causal_mask = bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool).to(attention_scores.device)
    mask_value = torch.finfo(attention_scores.dtype).min
    # mask_value needs to be a tensor, otherwise we get: `RuntimeError: expected scalar type float but found double`.
    # It also needs to be on the same device, otherwise: `RuntimeError: ..., x and y to be on the same device`.
    mask_value = torch.tensor(mask_value, dtype=attention_scores.dtype).to(attention_scores.device)
    attention_scores = torch.where(causal_mask, attention_scores, mask_value)
    if log_attention_weights is not None:
        attention_scores += log_attention_weights
    # Apply the attention mask only when it is non-trivial; in the pure auto-regressive case
    # (a single all-ones row), it is skipped.
    if not (attention_mask.shape[-2] == 1 and (attention_mask > 0).all()):
        attention_scores = torch.where(attention_mask.to(attention_scores.device), attention_scores, mask_value)
        # attention_scores = torch.mul(attention_scores, attention_mask) - \
        #                    10000.0 * (1.0 - attention_mask)
    attention_probs = F.softmax(attention_scores, dim=-1)
    if attention_dropout is not None:
        if mpu.get_cuda_rng_tracker is not None:
            # Fork the model-parallel RNG state (Megatron-style) so dropout is handled
            # consistently under tensor parallelism.
            with mpu.get_cuda_rng_tracker().fork():
                attention_probs = attention_dropout(attention_probs)
        else:
            attention_probs = attention_dropout(attention_probs)
    # Weighted sum over values: [..., query_len, key_len] x [..., key_len, head_dim] -> [..., query_len, head_dim].
    context_layer = torch.matmul(attention_probs, value_layer)
    return context_layer
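

# A minimal sketch (not part of the original file) of how the `bias_global` / `bias_local`
# causal-mask buffers used above could be built, assuming GPT-Neo-style global vs.
# sliding-window local attention; `max_positions` and `window_size` are assumed parameter names.
def build_causal_bias(max_positions, attention_type, window_size=256):
    # Lower-triangular mask: position i may attend to every position j <= i.
    bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
    if attention_type == "local":
        # Keep only the last `window_size` positions: i - window_size < j <= i.
        bias = torch.bitwise_xor(bias, torch.tril(bias, -window_size))
    # Shape (1, 1, max_positions, max_positions) so it broadcasts over batch and heads and
    # supports the slicing `bias[:, :, key_length - query_length : key_length, :key_length]` above.
    return bias.view(1, 1, max_positions, max_positions)

# Hypothetical usage: register the buffers once in the model, e.g.
#   self.register_buffer('bias_global', build_causal_bias(max_positions, 'global'))
#   self.register_buffer('bias_local', build_causal_bias(max_positions, 'local', window_size))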