in modules/SwissArmyTransformer/sat/model/transformer.py [0:0]
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
max_sequence_length,
embedding_dropout_prob=0,
attention_dropout_prob=0,
output_dropout_prob=0,
drop_path=0,
checkpoint_activations=False,
checkpoint_num_layers=1,
checkpoint_skip_layers=0,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
cross_hidden_size_per_attention_head=None,
layernorm_order='pre',
parallel_output=False,
is_decoder=False,
cross_attn_hidden_size=None,
use_bias=True,
use_qkv_bias=False,
num_multi_query_heads=0,
cross_num_multi_query_heads=0,
row_parallel_linear_final_bias=True,
activation_func=gelu,
is_gated_mlp=False,
is_rotary_emb=False,
num_experts=1,
layernorm=LayerNorm,
init_method=None,
use_final_layernorm=True,
hooks={},
params_dtype=torch.float,
skip_init=False,
device=torch.device('cpu')
):
super(BaseTransformer, self).__init__()
# recording parameters
self.hidden_size = hidden_size
self.inner_hidden_size = inner_hidden_size
self.hidden_size_per_attention_head = hidden_size_per_attention_head
self.cross_hidden_size_per_attention_head = cross_hidden_size_per_attention_head
self.is_decoder = is_decoder
self.cross_attn_hidden_size = cross_attn_hidden_size
self.cross_num_multi_query_heads = cross_num_multi_query_heads
if not is_decoder and cross_attn_hidden_size is not None:
print('warning: cross_attn_hidden_size is set but is_decoder is False')
self.use_bias = use_bias
self.use_qkv_bias = use_qkv_bias
self.num_multi_query_heads = num_multi_query_heads
self.is_gated_mlp = is_gated_mlp
self.is_rotary_emb = is_rotary_emb
self.num_experts = num_experts
self.use_final_layernorm = use_final_layernorm
self.layernorm_epsilon = layernorm_epsilon
self.parallel_output = parallel_output
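# activation checkpointing trades extra compute for memory by recomputing
# activations in the backward pass; checkpoint_num_layers sets how many layers
# form one recomputed chunk, and checkpoint_skip_layers excludes that many
# layers from checkpointing (e.g. num_layers=24 with checkpoint_num_layers=1
# allows skipping at most 23 layers, per the assertion below)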
self.checkpoint_activations = checkpoint_activations
self.checkpoint_num_layers = checkpoint_num_layers
self.checkpoint_skip_layers = checkpoint_skip_layers
assert checkpoint_skip_layers <= num_layers - checkpoint_num_layers, f'checkpoint_skip_layers ({checkpoint_skip_layers}) is too large. Please consider removing checkpoint_activations.'
self.max_sequence_length = max_sequence_length
self.layernorm_order = layernorm_order
self.row_parallel_linear_final_bias = row_parallel_linear_final_bias
self.hooks = copy.copy(hooks) # hooks will be updated each forward
object.__setattr__(self, 'transformer', self) # to give the default hooks the same api as outer hooks
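# (object.__setattr__ bypasses nn.Module.__setattr__, so this self-reference is
# not registered as a submodule and cannot recurse through the module tree)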
# create embedding parameters
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
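# small vocabularies get a plain torch.nn.Embedding; otherwise VocabParallelEmbedding
# shards the embedding table across model-parallel ranks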
if vocab_size < 1000:
self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size, dtype=params_dtype, device=device)
torch.nn.init.normal_(self.word_embeddings.weight, mean=0.0, std=init_method_std)
else:
self.word_embeddings = VocabParallelEmbedding(
num_embeddings=vocab_size, embedding_dim=hidden_size,
params_dtype=params_dtype, skip_init=skip_init, device=device)
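# position encoding: rotary embeddings are applied per attention head
# (dimension hidden_size // num_attention_heads); otherwise a learned absolute
# position embedding of length max_sequence_length is created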
if self.is_rotary_emb:
from sat.model.position_embedding.triton_rotary_embeddings import FastRotaryEmbedding
self.position_embeddings = FastRotaryEmbedding(hidden_size // num_attention_heads)
else:
self.position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
# create all layers
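# default weight init: output projections use a depth-scaled normal init
# (scaled_init_method) so residual-branch contributions stay small in deep
# stacks, while all other weights use normal(0, init_method_std)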
if init_method is None:
self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
self.init_method = unscaled_init_method(init_method_std)
else:
self.output_layer_init_method = init_method
self.init_method = init_method
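# local factory: builds one BaseTransformerLayer per layer id, sharing the
# constructor configuration, the hooks dict, and a back-pointer to this module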
def get_layer(layer_id):
return BaseTransformerLayer(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
self.init_method,
layer_id,
inner_hidden_size=inner_hidden_size,
hidden_size_per_attention_head=hidden_size_per_attention_head,
cross_hidden_size_per_attention_head=cross_hidden_size_per_attention_head,
output_layer_init_method=self.output_layer_init_method,
is_decoder=self.is_decoder,
cross_attn_hidden_size=cross_attn_hidden_size,
layernorm_order=layernorm_order,
layernorm=layernorm,
use_bias=use_bias,
use_qkv_bias=use_qkv_bias,
num_multi_query_heads=num_multi_query_heads,
cross_num_multi_query_heads=cross_num_multi_query_heads,
row_parallel_linear_final_bias=row_parallel_linear_final_bias,
drop_path=drop_path,
activation_func=activation_func,
is_gated_mlp=is_gated_mlp,
num_experts=num_experts,
hooks=self.hooks,
transformer_pointer=self,
params_dtype=params_dtype,
skip_init=skip_init,
device=device
)
self.layers = torch.nn.ModuleList(
[get_layer(layer_id) for layer_id in range(num_layers)])
# Final layer norm before output.
if use_final_layernorm:
self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
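For orientation, here is a minimal construction sketch using only the arguments from the signature above. It is illustrative, not canonical: the name `tiny_transformer` is invented, the hyperparameters are arbitrary toy values, and sat's parallel layers (see row_parallel_linear_final_bias above) generally assume the distributed/model-parallel environment has already been initialized through sat's usual entry points. The small vocab_size keeps the embedding on the plain torch.nn.Embedding path.

import torch
from sat.model.transformer import BaseTransformer  # module path from the header above

# Construction-only sketch; no forward pass is shown here.
tiny_transformer = BaseTransformer(
    num_layers=2,
    vocab_size=512,                     # < 1000, so a plain torch.nn.Embedding is used
    hidden_size=64,
    num_attention_heads=4,
    max_sequence_length=128,
    layernorm_order='pre',
    params_dtype=torch.float,
    device=torch.device('cpu'),
)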