modules/SwissArmyTransformer/sat/model/finetune/ffadd.py:

# -*- encoding: utf-8 -*-
# @File    :   ffadd.py
# @Time    :   2022/6/16
# @Author  :   Zhuoyi Yang
# @Contact :   yangzhuo18@mails.tsinghua.edu.cn

import torch
import torch.nn as nn

from sat.model.base_model import BaseMixin, non_conflict


class FFADDMixin(BaseMixin):
    def __init__(
            self,
            hidden_size: int,
            layer_num: int = 24,
            r: int = 0,
            layer_range=None,
    ):
        super().__init__()
        # Actual trainable parameters: a rank-r bottleneck feed-forward branch per layer.
        self.r = r
        self.ffadd_linear = nn.ModuleList([
            nn.ModuleList() for layer_id in range(layer_num)
        ])
        if layer_range is None:
            layer_range = [i for i in range(layer_num)]
        self.layer_range = layer_range
        for i in layer_range:
            # Down-projection (hidden_size -> r) and up-projection (r -> hidden_size).
            self.ffadd_linear[i].append(nn.Linear(hidden_size, r, bias=True))
            self.ffadd_linear[i].append(nn.Linear(r, hidden_size, bias=True))
            # Zero-init the up-projection so the added branch contributes nothing at the start.
            nn.init.zeros_(self.ffadd_linear[i][1].weight)
            nn.init.zeros_(self.ffadd_linear[i][1].bias)

    def mlp_forward(self, hidden_states, layer_id, attention_output=None, **kw_args):
        # Original MLP path of the transformer layer.
        layer = self.transformer.layers[layer_id].mlp
        intermediate_parallel = layer.dense_h_to_4h(hidden_states)
        intermediate_parallel = layer.activation_func(intermediate_parallel)
        output = layer.dense_4h_to_h(intermediate_parallel)
        if layer_id in self.layer_range:
            # Added bottleneck branch, summed onto the original MLP output.
            ffadd_layer = self.ffadd_linear[layer_id]
            intermediate_add = ffadd_layer[0](hidden_states)
            intermediate_add = layer.activation_func(intermediate_add)
            if attention_output is not None:
                kw_args["output_this_layer"]["0"] = intermediate_add.data.cpu().numpy()
            output2 = ffadd_layer[1](intermediate_add)
            output = output + output2
        if self.training:
            output = layer.dropout(output)
        return output
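

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). It assumes the
# usual sat workflow in which BaseModel.add_mixin() registers a mixin so that
# its mlp_forward hook replaces the layer's default MLP forward; the model
# class name and the hyperparameter values below are placeholders.
if __name__ == "__main__":
    # Stand-alone instantiation: a bottleneck of rank 32 added to every layer
    # of a hypothetical 24-layer, 1024-dim transformer.
    mixin = FFADDMixin(hidden_size=1024, layer_num=24, r=32)
    print(sum(p.numel() for p in mixin.parameters()), "trainable parameters added")

    # Attaching to a model (outline only, since building a concrete model
    # needs an `args` namespace):
    #   model = SomeSatModel(args)          # hypothetical BaseModel subclass
    #   model.add_mixin('ffadd', mixin)
    #   # finetune only the added branch:
    #   for name, p in model.named_parameters():
    #       p.requires_grad = 'ffadd_linear' in name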