in modules/SwissArmyTransformer/sat/transformer_defaults.py [0:0]
def layer_forward_default(self, hidden_states, mask, *args, **kw_args):
    '''
        hidden_states: [batch, seq_len, hidden_size]
        mask: [(1, 1), seq_len, seq_len]
    '''
    self = self.transformer.layers[kw_args['layer_id']]
    # Layer norm at the beginning of the transformer layer.
    attention_input = self.input_layernorm(hidden_states)
    # Self attention.
    attention_output = self.attention(attention_input, mask, **kw_args)
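
    # layernorm_order controls where the norms sit: 'post' re-normalizes after
    # each residual add, 'sandwich' keeps the pre-norm placement but adds an
    # extra norm on the attention/MLP outputs (third/fourth layernorm below),
    # and any other value behaves as plain pre-LN.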
    # Third LayerNorm
    if self.layernorm_order == 'sandwich':
        attention_output = self.third_layernorm(attention_output)

    # DropPath for attention
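    # Stochastic depth: during training each sample keeps this residual branch
    # with probability 1 - drop_path; kept samples are rescaled by
    # 1 / (1 - drop_path) so the branch's expected value is unchanged.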
    if self.training and self.drop_path > 0.:
        # dropped samples get factor 0, kept samples get 1/(1-p)
        random_tensor = (1 - self.drop_path
                         + torch.rand((attention_output.shape[0],), dtype=attention_output.dtype,
                                      device=attention_output.device)).floor_() / (1 - self.drop_path)
        attention_output = random_tensor.view(-1, 1, 1) * attention_output

    # Residual connection.
    if self.layernorm_order == 'post':
        hidden_states = attention_input + attention_output
        mlp_input = self.post_attention_layernorm(hidden_states)
    else:
        hidden_states = hidden_states + attention_output
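
    # Decoder layers additionally run cross attention over the encoder outputs,
    # but only when encoder_outputs is actually passed in through kw_args.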
    if self.is_decoder:
        encoder_outputs = kw_args['encoder_outputs']
        if encoder_outputs is not None:
            assert 'cross_attention_mask' in kw_args
            # Cross attention
            if self.layernorm_order == 'post':
                attention_output = self.cross_attention(mlp_input, **kw_args)
                # Residual connection.
                hidden_states = mlp_input + attention_output
                # Layer norm after the cross attention.
                mlp_input = self.post_cross_attention_layernorm(hidden_states)
            else:
                cross_input = self.post_cross_attention_layernorm(hidden_states)
                attention_output = self.cross_attention(cross_input, **kw_args)
                hidden_states = hidden_states + attention_output
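
    # For pre-LN / sandwich, normalize here before the MLP; in post-LN mode the
    # post_attention_layernorm (or post_cross_attention_layernorm) above has
    # already produced mlp_input.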
    if self.layernorm_order != 'post':
        mlp_input = self.post_attention_layernorm(hidden_states)

    # MLP.
    mlp_output = self.mlp(mlp_input, **kw_args)

    # Fourth LayerNorm
    if self.layernorm_order == 'sandwich':
        mlp_output = self.fourth_layernorm(mlp_output)

    # DropPath for mlp
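    # Same per-sample keep/rescale trick as the attention DropPath above.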
    if self.training and self.drop_path > 0.:
        random_tensor = (1 - self.drop_path
                         + torch.rand((mlp_output.shape[0],), dtype=mlp_output.dtype,
                                      device=mlp_output.device)).floor_() / (1 - self.drop_path)
        mlp_output = random_tensor.view(-1, 1, 1) * mlp_output

    # Second residual connection.
    if self.layernorm_order == 'post':
        output = mlp_input + mlp_output
    else:
        output = hidden_states + mlp_output

    return output
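
A minimal, self-contained sketch of the DropPath arithmetic used twice above, for trying it in isolation; the helper name drop_path_ and the probability p are illustrative placeholders, not part of this file.

import torch

def drop_path_(x, p):
    # x: [batch, seq_len, hidden_size]. Each sample's branch is zeroed with
    # probability p; survivors are rescaled by 1 / (1 - p), so the result
    # equals x in expectation.
    keep = (1 - p + torch.rand((x.shape[0],), dtype=x.dtype, device=x.device)).floor_()
    return keep.view(-1, 1, 1) / (1 - p) * x

# With p = 0.1, roughly 10% of samples get a zero branch and the remaining
# samples are multiplied by 1 / 0.9.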