in modules/SwissArmyTransformer/sat/model/transformer.py [0:0]
    def __init__(self, hidden_size, output_dropout_prob, init_method,
                 inner_hidden_size=None, output_layer_init_method=None, layer_id=None,
                 row_parallel_linear_final_bias=True, hooks={}, bias=True,
                 activation_func=gelu, transformer_pointer=None, is_gated_mlp=False,
                 num_experts=1, params_dtype=torch.float, skip_init=False,
                 device=torch.device('cpu')):
super(MLP, self).__init__()
self.layer_id = layer_id
self.activation_func = activation_func
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
self.hooks = hooks
        self.hidden_size = hidden_size
        # Default the intermediate size to 4h.
        if inner_hidden_size is None:
            inner_hidden_size = 4 * hidden_size
        self.inner_hidden_size = inner_hidden_size
        # Project to the inner dimension (4h by default).
self.dense_h_to_4h = ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name="dense_h_to_4h",
skip_init=skip_init,
device=device
)
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
self.inner_hidden_size,
self.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name="dense_4h_to_h",
skip_init=skip_init,
device=device,
final_bias=row_parallel_linear_final_bias
)
        self.is_gated_mlp = is_gated_mlp
        if is_gated_mlp:
            # Extra bias-free gate projection used by gated activations (e.g. SwiGLU-style MLPs).
            self.dense_h_to_4h_gate = ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method,
bias=False,
params_dtype=params_dtype,
module=self,
name="dense_h_to_4h_gate",
skip_init=skip_init,
device=device
)
        self.num_experts = num_experts
        # Expert 0 reuses the projections above; experts 1..num_experts-1 get their own copies.
        for i in range(1, num_experts):
self.register_module(f"dense_h_to_4h_{i}", ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name=f"dense_h_to_4h_{i}",
skip_init=skip_init,
device=device
))
# Project back to h.
self.register_module(f"dense_4h_to_h_{i}", RowParallelLinear(
self.inner_hidden_size,
self.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
bias=bias,
params_dtype=params_dtype,
module=self,
name=f"dense_4h_to_h_{i}",
skip_init=skip_init,
device=device,
final_bias=row_parallel_linear_final_bias
))
            if is_gated_mlp:
                # Per-expert gate projection.
                self.register_module(f"dense_h_to_4h_gate_{i}", ColumnParallelLinear(
self.hidden_size,
self.inner_hidden_size,
gather_output=False,
init_method=init_method,
bias=False,
params_dtype=params_dtype,
module=self,
name=f"dense_h_to_4h_gate_{i}",
skip_init=skip_init,
device=device
))
        self.dropout = torch.nn.Dropout(output_dropout_prob)
        # Bypass nn.Module.__setattr__ so the parent transformer is not registered as a submodule.
        object.__setattr__(self, 'transformer', transformer_pointer)
        assert transformer_pointer is not None
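
For reference, below is a minimal single-device sketch of the same MLP structure in plain PyTorch, with nn.Linear standing in for ColumnParallelLinear/RowParallelLinear and no expert routing. The forward pass shows one common gated-MLP convention (activation applied to the gate branch, then multiplied into the up-projection); the actual forward() in transformer.py may differ. GatedMLPSketch and its parameter names are illustrative, not part of SwissArmyTransformer.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLPSketch(nn.Module):
    """Single-GPU analogue of the tensor-parallel MLP above (illustrative only)."""
    def __init__(self, hidden_size, inner_hidden_size=None, output_dropout_prob=0.0,
                 is_gated_mlp=False, activation_func=F.gelu, bias=True):
        super().__init__()
        # Default the intermediate size to 4h, mirroring the __init__ above.
        inner_hidden_size = inner_hidden_size if inner_hidden_size is not None else 4 * hidden_size
        self.dense_h_to_4h = nn.Linear(hidden_size, inner_hidden_size, bias=bias)
        self.dense_4h_to_h = nn.Linear(inner_hidden_size, hidden_size, bias=bias)
        self.is_gated_mlp = is_gated_mlp
        if is_gated_mlp:
            # Gate projection is bias-free, mirroring dense_h_to_4h_gate above.
            self.dense_h_to_4h_gate = nn.Linear(hidden_size, inner_hidden_size, bias=False)
        self.activation_func = activation_func
        self.dropout = nn.Dropout(output_dropout_prob)

    def forward(self, hidden_states):
        intermediate = self.dense_h_to_4h(hidden_states)
        if self.is_gated_mlp:
            # act(gate(x)) * up(x), the usual SwiGLU/GeGLU formulation (an assumed convention here).
            intermediate = self.activation_func(self.dense_h_to_4h_gate(hidden_states)) * intermediate
        else:
            intermediate = self.activation_func(intermediate)
        return self.dropout(self.dense_4h_to_h(intermediate))

# Example usage:
#   x = torch.randn(2, 16, 1024)
#   y = GatedMLPSketch(1024, is_gated_mlp=True, activation_func=F.silu)(x)   # y.shape == (2, 16, 1024)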