MLP.__init__()

in modules/SwissArmyTransformer/sat/model/transformer.py


    def __init__(self, hidden_size, output_dropout_prob, init_method,
                 inner_hidden_size=None, output_layer_init_method=None, layer_id=None,
                 row_parallel_linear_final_bias=True, hooks={}, bias=True,
                 activation_func=gelu, transformer_pointer=None, is_gated_mlp=False,
                 num_experts=1, params_dtype=torch.float, skip_init=False,
                 device=torch.device('cpu')):
        super(MLP, self).__init__()
        self.layer_id = layer_id
        self.activation_func = activation_func
        # Set output layer initialization if not provided.
        if output_layer_init_method is None:
            output_layer_init_method = init_method
        self.hooks = hooks
        # Project to 4h.
        self.hidden_size = hidden_size
        if inner_hidden_size is None:
            inner_hidden_size = 4 * hidden_size
        self.inner_hidden_size = inner_hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(
            self.hidden_size,
            self.inner_hidden_size,
            gather_output=False,
            init_method=init_method,
            bias=bias,
            params_dtype=params_dtype,
            module=self,
            name="dense_h_to_4h",
            skip_init=skip_init,
            device=device
        )
        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            self.inner_hidden_size,
            self.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            bias=bias,
            params_dtype=params_dtype,
            module=self,
            name="dense_4h_to_h",
            skip_init=skip_init,
            device=device,
            final_bias=row_parallel_linear_final_bias
        )
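        # Tensor-parallel pattern: the column-parallel h->4h keeps its output sharded
        # (gather_output=False) and the row-parallel 4h->h consumes it directly
        # (input_is_parallel=True), so no all-gather is needed between the two projections.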
        self.is_gated_mlp = is_gated_mlp
        if is_gated_mlp:
            # Extra gate projection for gated (GLU-style) MLP variants; created without bias.
            self.dense_h_to_4h_gate = ColumnParallelLinear(
                self.hidden_size,
                self.inner_hidden_size,
                gather_output=False,
                init_method=init_method,
                bias=False,
                params_dtype=params_dtype,
                module=self,
                name="dense_h_to_4h_gate",
                skip_init=skip_init,
                device=device
            )
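            # Note (assumption): gated variants typically combine the gate projection with the
            # (activated) h->4h projection by an elementwise product; the exact formula lives in
            # this module's forward method, not here.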
        self.num_experts = num_experts
        # The dense_h_to_4h / dense_4h_to_h pair above serves as expert 0; register one extra
        # projection pair (and gate, if gated) per additional expert.
        for i in range(1, num_experts):
            self.register_module(f"dense_h_to_4h_{i}", ColumnParallelLinear(
                self.hidden_size,
                self.inner_hidden_size,
                gather_output=False,
                init_method=init_method,
                bias=bias,
                params_dtype=params_dtype,
                module=self,
                name=f"dense_h_to_4h_{i}",
                skip_init=skip_init,
                device=device
            ))
            # Project back to h.
            self.register_module(f"dense_4h_to_h_{i}", RowParallelLinear(
                self.inner_hidden_size,
                self.hidden_size,
                input_is_parallel=True,
                init_method=output_layer_init_method,
                bias=bias,
                params_dtype=params_dtype,
                module=self,
                name=f"dense_4h_to_h_{i}",
                skip_init=skip_init,
                device=device,
                final_bias=row_parallel_linear_final_bias
            ))
            if is_gated_mlp:
                self.register_module(f"dense_h_to_4h_gate_{i}", ColumnParallelLinear(
                    self.hidden_size,
                    self.inner_hidden_size,
                    gather_output=False,
                    init_method=init_method,
                    bias=False,
                    params_dtype=params_dtype,
                    module=self,
                    name=f"dense_h_to_4h_gate_{i}",
                    skip_init=skip_init,
                    device=device
                ))
        self.dropout = torch.nn.Dropout(output_dropout_prob)
        # Keep a plain (non-submodule) reference to the parent transformer: object.__setattr__
        # bypasses nn.Module's attribute registration, so the parent is not added to this
        # module's children or parameters.
        object.__setattr__(self, 'transformer', transformer_pointer)
        assert transformer_pointer is not None
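
For context, a minimal construction sketch follows. It is illustrative only: it assumes sat's model-parallel groups have already been initialized by the usual training entry points (the Column/RowParallelLinear layers cannot be built otherwise), and normal_init and parent_transformer are hypothetical stand-ins, not part of this file.

import torch
from sat.model.transformer import MLP

def normal_init(std=0.02):
    # Hypothetical helper: init_method only needs to be a callable that
    # initializes a weight tensor in place.
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_

# Placeholder for the owning transformer; the assert above only requires it to be non-None.
parent_transformer = torch.nn.Module()

mlp = MLP(
    hidden_size=1024,
    output_dropout_prob=0.1,
    init_method=normal_init(0.02),
    transformer_pointer=parent_transformer,
)

With the defaults above this builds the standard two-projection MLP (inner_hidden_size = 4 * hidden_size, no gate, a single expert); is_gated_mlp and num_experts only add the extra projections registered in the loops shown earlier.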