# modules/SwissArmyTransformer/sat/model/transformer.py (SelfAttention.__init__)
def __init__(self, hidden_size, num_attention_heads,
             attention_dropout_prob, output_dropout_prob,
             init_method, layer_id, hidden_size_per_attention_head=None,
             output_layer_init_method=None, bias=True, qkv_bias=False,
             num_multi_query_heads=0, row_parallel_linear_final_bias=True,
             hooks={}, transformer_pointer=None, params_dtype=torch.float,
             skip_init=False, device=torch.device('cpu')):
    super(SelfAttention, self).__init__()
    # Set output layer initialization if not provided.
    if output_layer_init_method is None:
        output_layer_init_method = init_method
    self.hooks = hooks
    self.layer_id = layer_id
    # Per attention head and per partition values.
    world_size = get_model_parallel_world_size()
    self.hidden_size = hidden_size
    self.num_attention_heads = num_attention_heads
    self.num_multi_query_heads = num_multi_query_heads
    if hidden_size_per_attention_head is None:
        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
    else:
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
    self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
    self.num_multi_query_heads_per_partition = divide(num_multi_query_heads, world_size)
    self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
    self.hidden_size_per_partition = self.hidden_size_per_attention_head * self.num_attention_heads_per_partition
    # Strided linear layer: a single fused projection produces Q, K and V.
    if num_multi_query_heads == 0:
        # Standard multi-head attention: Q, K and V each use the full inner hidden size.
        qkv_size = 3 * self.inner_hidden_size
        self.stride = 3
    else:  # multi-query attention: K and V use only num_multi_query_heads heads
        qkv_size = self.inner_hidden_size + self.hidden_size_per_attention_head * self.num_multi_query_heads * 2
        self.stride = [self.num_attention_heads_per_partition,
                       self.num_multi_query_heads_per_partition,
                       self.num_multi_query_heads_per_partition]
    self.query_key_value = ColumnParallelLinear(
        hidden_size,
        qkv_size,
        stride=self.stride,
        gather_output=False,
        init_method=init_method,
        bias=bias or qkv_bias,
        params_dtype=params_dtype,
        module=self,
        name="query_key_value",
        skip_init=skip_init,
        device=device
    )
    self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
    self.dense = RowParallelLinear(
        self.inner_hidden_size,
        hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        bias=bias,
        params_dtype=params_dtype,
        module=self,
        name="dense",
        skip_init=skip_init,
        device=device,
        final_bias=row_parallel_linear_final_bias
    )
    self.output_dropout = torch.nn.Dropout(output_dropout_prob)
    # Keep a back-reference to the parent transformer without registering it as a
    # submodule (object.__setattr__ bypasses nn.Module's attribute registration).
    object.__setattr__(self, 'transformer', transformer_pointer)
    assert transformer_pointer is not None
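
# --- Illustrative sketch (not part of transformer.py) -----------------------
# A minimal, hedged example of the size arithmetic used above: how the fused
# QKV projection width differs between the standard and multi-query branches,
# and how heads are divided across model-parallel partitions. All names and
# sizes below are hypothetical and chosen only for illustration.
_hidden_size = 4096
_num_heads = 32
_num_mq_heads = 2
_world_size = 2                                          # model-parallel ranks
_head_dim = _hidden_size // _num_heads                   # 128
_inner_hidden_size = _num_heads * _head_dim              # 4096
_heads_per_partition = _num_heads // _world_size         # 16

# Standard branch (num_multi_query_heads == 0): Q, K and V each span the full
# inner hidden size, so the fused projection is three times as wide.
_qkv_size_standard = 3 * _inner_hidden_size              # 12288

# Multi-query branch: K and V keep only _num_mq_heads heads each.
_qkv_size_mqa = _inner_hidden_size + 2 * _head_dim * _num_mq_heads  # 4608

# Downstream (in the forward pass, not shown here), a fused output of width
# _qkv_size_mqa would typically be split in proportions that mirror the stride
# list set above, e.g.:
#   q, k, v = fused.split([_inner_hidden_size,
#                          _head_dim * _num_mq_heads,
#                          _head_dim * _num_mq_heads], dim=-1)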