def __init__()

in modules/SwissArmyTransformer/sat/model/transformer.py [0:0]


    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 max_sequence_length,
                 embedding_dropout_prob=0,
                 attention_dropout_prob=0,
                 output_dropout_prob=0,
                 drop_path=0,
                 checkpoint_activations=False,
                 checkpoint_num_layers=1,
                 checkpoint_skip_layers=0,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 inner_hidden_size=None,
                 hidden_size_per_attention_head=None,
                 cross_hidden_size_per_attention_head=None,
                 layernorm_order='pre',
                 parallel_output=False,
                 is_decoder=False,
                 cross_attn_hidden_size=None,
                 use_bias=True,
                 use_qkv_bias=False,
                 num_multi_query_heads=0,
                 cross_num_multi_query_heads=0,
                 row_parallel_linear_final_bias=True,
                 activation_func=gelu,
                 is_gated_mlp=False,
                 is_rotary_emb=False,
                 num_experts=1,
                 layernorm=LayerNorm,
                 init_method=None,
                 use_final_layernorm=True,
                 hooks={},
                 params_dtype=torch.float,
                 skip_init=False,
                 device=torch.device('cpu')
                 ):
        super(BaseTransformer, self).__init__()

        # recording parameters
        self.hidden_size = hidden_size
        self.inner_hidden_size = inner_hidden_size
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
        self.cross_hidden_size_per_attention_head = cross_hidden_size_per_attention_head
        self.is_decoder = is_decoder
        self.cross_attn_hidden_size = cross_attn_hidden_size
        self.cross_num_multi_query_heads = cross_num_multi_query_heads
        if not is_decoder and cross_attn_hidden_size is not None:
            print('warning: cross_attn_hidden_size is set but is_decoder is False')
        self.use_bias = use_bias
        self.use_qkv_bias = use_qkv_bias
        self.num_multi_query_heads = num_multi_query_heads
        self.is_gated_mlp = is_gated_mlp
        self.is_rotary_emb = is_rotary_emb
        self.num_experts = num_experts
        self.use_final_layernorm = use_final_layernorm
        self.layernorm_epsilon = layernorm_epsilon
        self.parallel_output = parallel_output
        self.checkpoint_activations = checkpoint_activations
        self.checkpoint_num_layers = checkpoint_num_layers
        self.checkpoint_skip_layers = checkpoint_skip_layers
        assert checkpoint_skip_layers <= num_layers - checkpoint_num_layers, f'checkpoint_skip_layers ({checkpoint_skip_layers}) is too large. Please consider removing checkpoint_activations.'
        self.max_sequence_length = max_sequence_length
        self.layernorm_order = layernorm_order
        self.row_parallel_linear_final_bias = row_parallel_linear_final_bias
        self.hooks = copy.copy(hooks)  # hooks will be updated each forward
        object.__setattr__(self, 'transformer', self) # gives the default hooks the same API as outer hooks; object.__setattr__ bypasses nn.Module.__setattr__ so the self-reference is not registered as a submodule

        # create embedding parameters
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

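        # Word embeddings: tiny vocabularies get a plain nn.Embedding initialized here;
        # larger ones use VocabParallelEmbedding, which shards the vocabulary across
        # model-parallel ranks.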
        if vocab_size < 1000:
            self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size, dtype=params_dtype, device=device)
            torch.nn.init.normal_(self.word_embeddings.weight, mean=0.0, std=init_method_std)
        else:
            self.word_embeddings = VocabParallelEmbedding(
                num_embeddings=vocab_size, embedding_dim=hidden_size, 
                params_dtype=params_dtype, skip_init=skip_init, device=device)

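        # Position encoding: rotary embeddings work on a per-head dimension and are applied
        # inside attention; otherwise a learned absolute position embedding table of size
        # max_sequence_length is created.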
        if self.is_rotary_emb:
            from sat.model.position_embedding.triton_rotary_embeddings import FastRotaryEmbedding
            self.position_embeddings = FastRotaryEmbedding(hidden_size // num_attention_heads)
        else:
            self.position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
            torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)

        # create all layers
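        # Default weight init: an unscaled normal(0, init_method_std) for most parameters and a
        # depth-scaled variant for output-layer projections; a user-supplied init_method
        # overrides both.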
        if init_method is None:
            self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
            self.init_method = unscaled_init_method(init_method_std)
        else:
            self.output_layer_init_method = init_method
            self.init_method = init_method

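        # Every BaseTransformerLayer shares this module's hooks dict and keeps a back-pointer
        # to the transformer via transformer_pointer.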
        def get_layer(layer_id):
            return BaseTransformerLayer(
                hidden_size,
                num_attention_heads,
                attention_dropout_prob,
                output_dropout_prob,
                layernorm_epsilon,
                self.init_method,
                layer_id,
                inner_hidden_size=inner_hidden_size,
                hidden_size_per_attention_head=hidden_size_per_attention_head,
                cross_hidden_size_per_attention_head=cross_hidden_size_per_attention_head,
                output_layer_init_method=self.output_layer_init_method,
                is_decoder=self.is_decoder,
                cross_attn_hidden_size=cross_attn_hidden_size,
                layernorm_order=layernorm_order,
                layernorm=layernorm,
                use_bias=use_bias,
                use_qkv_bias=use_qkv_bias,
                num_multi_query_heads=num_multi_query_heads,
                cross_num_multi_query_heads=cross_num_multi_query_heads,
                row_parallel_linear_final_bias=row_parallel_linear_final_bias,
                drop_path=drop_path,
                activation_func=activation_func,
                is_gated_mlp=is_gated_mlp,
                num_experts=num_experts,
                hooks=self.hooks,
                transformer_pointer=self,
                params_dtype=params_dtype,
                skip_init=skip_init,
                device=device
            )

        self.layers = torch.nn.ModuleList(
            [get_layer(layer_id) for layer_id in range(num_layers)])

        # Final layer norm before output.
        if use_final_layernorm:
            self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
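
For orientation, a minimal construction sketch follows. The import path is assumed from the file location above, the hyperparameter values are illustrative, and with a vocabulary this large the VocabParallelEmbedding branch is taken, which expects SAT's model-parallel environment to be initialized beforehand.

import torch
from sat.model.transformer import BaseTransformer  # assumed import path, per the file location above

# Hypothetical GPT-2-small-sized configuration; values are illustrative only.
model = BaseTransformer(
    num_layers=12,
    vocab_size=50304,            # >= 1000, so the VocabParallelEmbedding branch is used
    hidden_size=768,
    num_attention_heads=12,
    max_sequence_length=1024,
    layernorm_order='pre',
    params_dtype=torch.float16,
    device=torch.device('cuda'),
)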