modules/SwissArmyTransformer/sat/model/finetune/adapter.py
# -*- encoding: utf-8 -*-
# @File : adapter.py
# @Time : 2022/6/16
# @Author : Zhuoyi Yang
# @Contact : yangzhuo18@mails.tsinghua.edu.cn
from sat.model.base_model import BaseModel, BaseMixin, non_conflict
import torch.nn as nn
class AdapterMixin(BaseMixin):
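    """Bottleneck adapter mixin in the style of adapter-bert (Houlsby et al., 2019).

    Every transformer layer gets two small adapters (down-project -> GELU ->
    up-project, each wrapped in a residual connection): ff1/ff2 after
    self-attention and ff3/ff4 after the MLP, adding roughly
    4 * hidden_size * adapter_hidden parameters per layer. During finetuning,
    the backbone can stay frozen while only the adapter weights are trained.
    """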
def __init__(self, num_layers, hidden_size, adapter_hidden):
super().__init__()
self.ff1 = nn.ModuleList([
nn.Linear(hidden_size, adapter_hidden) for _ in range(num_layers)
])
self.ff2 = nn.ModuleList([
nn.Linear(adapter_hidden, hidden_size) for _ in range(num_layers)
])
self.ff3 = nn.ModuleList([
nn.Linear(hidden_size, adapter_hidden) for _ in range(num_layers)
])
self.ff4 = nn.ModuleList([
nn.Linear(adapter_hidden, hidden_size) for _ in range(num_layers)
])
def layer_forward(self, hidden_states, mask, *args, **kw_args):
        '''
        Replacement layer forward with a bottleneck adapter applied after
        self-attention and another after the MLP.

        hidden_states: [batch, seq_len, hidden_size]
        mask: attention mask, e.g. [1, 1, seq_len, seq_len]
        '''
        layer_id = kw_args['layer_id']
        layer = self.transformer.layers[layer_id]
        # Layer norm at the beginning of the transformer layer.
        hidden_states = layer.input_layernorm(hidden_states)
        # Self attention.
        attention_output = layer.attention(hidden_states, mask, **kw_args)
        # Bottleneck adapter (down-project -> GELU -> up-project) on the
        # attention output, with a residual connection.
        attention_output = attention_output + self.ff2[layer_id](
            nn.functional.gelu(self.ff1[layer_id](attention_output)))
# Residual connection.
layernorm_input = hidden_states + attention_output
# Layer norm post the self attention.
layernorm_output = layer.post_attention_layernorm(layernorm_input)
# MLP.
        mlp_output = layer.mlp(layernorm_output, **kw_args)
        # Second bottleneck adapter, applied to the MLP output.
        mlp_output = mlp_output + self.ff4[layer_id](
            nn.functional.gelu(self.ff3[layer_id](mlp_output)))
# Second residual connection.
output = layernorm_output + mlp_output
return output
def reinit(self, parent_model=None):
        # Near-zero initialization so the adapters contribute almost nothing at
        # the start of finetuning; refer to
        # https://github.com/google-research/adapter-bert/blob/1a31fc6e92b1b89a6530f48eb0f9e1f04cc4b750/modeling.py#L321
        for ly in [*self.ff1, *self.ff2, *self.ff3, *self.ff4]:
            nn.init.trunc_normal_(ly.weight, std=1e-3)
            nn.init.zeros_(ly.bias)
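
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sketch of how this mixin is typically attached to a sat model.
# The argument values (bottleneck width 64) are assumptions, and the
# `add_mixin` / `mixins` calls follow the usual sat BaseMixin convention;
# check the installed sat version for the exact signatures.
#
#     model = BaseModel(args)
#     model.add_mixin('adapter', AdapterMixin(
#         num_layers=args.num_layers,
#         hidden_size=args.hidden_size,
#         adapter_hidden=64,      # hypothetical bottleneck width
#     ), reinit=True)             # reinit=True is assumed to call AdapterMixin.reinit
#     # Train only the adapter parameters; keep the backbone frozen.
#     for p in model.parameters():
#         p.requires_grad_(False)
#     for p in model.mixins['adapter'].parameters():
#         p.requires_grad_(True)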