optimum/bettertransformer/models/encoder_models.py
def __init__(self, bert_layer, config):
    r"""
    A simple conversion of the BERT layer to its `BetterTransformer` implementation.

    Args:
        bert_layer (`torch.nn.Module`):
            The original BERT layer from which the weights need to be retrieved.
        config (`transformers.PretrainedConfig`):
            The configuration of the original model.
    """
    # Initialize the `BetterTransformerBaseLayer` base class with the model config,
    # then the underlying `nn.Module`.
    super().__init__(config)
    super(BetterTransformerBaseLayer, self).__init__()
    # In_proj layer
    self.in_proj_weight = nn.Parameter(
        torch.cat(
            [
                bert_layer.attention.self.query.weight,
                bert_layer.attention.self.key.weight,
                bert_layer.attention.self.value.weight,
            ]
        )
    )
    self.in_proj_bias = nn.Parameter(
        torch.cat(
            [
                bert_layer.attention.self.query.bias,
                bert_layer.attention.self.key.bias,
                bert_layer.attention.self.value.bias,
            ]
        )
    )
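    # Note: with hidden_size = H, the fused in_proj weight has shape (3*H, H) and the
    # fused bias shape (3*H,), so a single matmul yields the query, key and value
    # projections at once. This is the fused layout consumed by PyTorch's fast-path
    # encoder kernel.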
    # Out proj layer
    self.out_proj_weight = bert_layer.attention.output.dense.weight
    self.out_proj_bias = bert_layer.attention.output.dense.bias
    # Linear layer 1
    self.linear1_weight = bert_layer.intermediate.dense.weight
    self.linear1_bias = bert_layer.intermediate.dense.bias
    # Linear layer 2
    self.linear2_weight = bert_layer.output.dense.weight
    self.linear2_bias = bert_layer.output.dense.bias
    # Layer norm 1
    self.norm1_eps = bert_layer.attention.output.LayerNorm.eps
    self.norm1_weight = bert_layer.attention.output.LayerNorm.weight
    self.norm1_bias = bert_layer.attention.output.LayerNorm.bias
    # Layer norm 2
    self.norm2_eps = bert_layer.output.LayerNorm.eps
    self.norm2_weight = bert_layer.output.LayerNorm.weight
    self.norm2_bias = bert_layer.output.LayerNorm.bias
    # Model hyper parameters
    self.num_heads = bert_layer.attention.self.num_attention_heads
    self.embed_dim = bert_layer.attention.self.all_head_size
    # Last step: set `is_last_layer` to `False`; it will be set to `True` for the final
    # layer when converting the model
    self.is_last_layer = False
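    # Mapping from the fused/renamed parameters above back to the parameter names of the
    # original `BertLayer`. Presumably used by optimum to tie the converted layer to the
    # original checkpoint (e.g. when reversing the transformation).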
    self.original_layers_mapping = {
        "in_proj_weight": [
            "attention.self.query.weight",
            "attention.self.key.weight",
            "attention.self.value.weight",
        ],
        "in_proj_bias": ["attention.self.query.bias", "attention.self.key.bias", "attention.self.value.bias"],
        "out_proj_weight": "attention.output.dense.weight",
        "out_proj_bias": "attention.output.dense.bias",
        "linear1_weight": "intermediate.dense.weight",
        "linear1_bias": "intermediate.dense.bias",
        "linear2_weight": "output.dense.weight",
        "linear2_bias": "output.dense.bias",
        "norm1_eps": "attention.output.LayerNorm.eps",
        "norm1_weight": "attention.output.LayerNorm.weight",
        "norm1_bias": "attention.output.LayerNorm.bias",
        "norm2_eps": "output.LayerNorm.eps",
        "norm2_weight": "output.LayerNorm.weight",
        "norm2_bias": "output.LayerNorm.bias",
    }
    self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
    self.hidden_dropout_prob = config.hidden_dropout_prob
    self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
    self.act_fn_callable = ACT2FN[self.act_fn]
    self.validate_bettertransformer()
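
For context, a minimal usage sketch that triggers this per-layer conversion (assuming the public `BetterTransformer.transform` API from optimum and a standard `bert-base-uncased` checkpoint; the variable names are illustrative):

    from transformers import AutoModel
    from optimum.bettertransformer import BetterTransformer

    model = AutoModel.from_pretrained("bert-base-uncased")
    # Each `BertLayer` is replaced by a BetterTransformer layer built via the `__init__` above.
    bt_model = BetterTransformer.transform(model, keep_original_model=False)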