modules/SwissArmyTransformer/sat/model/finetune/lora.py

# -*- encoding: utf-8 -*-
# @File    : lora.py
# @Time    : 2022/6/16
# @Author  : Zhuoyi Yang
# @Contact : yangzhuo18@mails.tsinghua.edu.cn

import math

import torch
import torch.nn as nn

from sat.model.transformer import standard_attention
from sat.model.base_model import BaseMixin
from sat.mpu.utils import split_tensor_along_last_dim


class LoRAMixin(BaseMixin):
    """Adds LoRA (low-rank adaptation) weights to the Q, K, V and output
    projections of the attention layers listed in `layer_range`.

    For each adapted projection W, the output becomes
    W x + (alpha / r) * B A x, where A has shape (r, hidden_size) and
    B has shape (hidden_size, r). Only A and B are trainable; the base
    model weights are left untouched.
    """

    def __init__(
        self,
        hidden_size: int,
        layer_num: int = 24,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        layer_range=None,
    ):
        super().__init__()
        # Actual trainable parameters. Note: r must be set to a value > 0;
        # the default of 0 would make the scaling factor below divide by zero.
        self.r = r
        self.lora_alpha = lora_alpha
        if lora_dropout and lora_dropout > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        if layer_range is None:
            layer_range = list(range(layer_num))
        self.layer_range = layer_range

        # One ParameterDict per transformer layer; only layers in
        # layer_range actually receive A/B matrices.
        self.lora_linear = nn.ModuleList([
            nn.ParameterDict() for _ in range(layer_num)
        ])
        matrices = ["Q", "K", "V", "O"]
        for i in layer_range:
            for matrix in matrices:
                self.lora_linear[i][matrix + "_A"] = nn.Parameter(torch.zeros((r, hidden_size)))
                self.lora_linear[i][matrix + "_B"] = nn.Parameter(torch.zeros((hidden_size, r)))
                # A is initialized like a regular linear weight, B starts at
                # zero, so each adapter is a no-op before training.
                nn.init.kaiming_uniform_(self.lora_linear[i][matrix + "_A"], a=math.sqrt(5))
                nn.init.zeros_(self.lora_linear[i][matrix + "_B"])
        self.scaling = self.lora_alpha / self.r

    def attention_forward(self, hidden_states, mask, layer_id, **kw_args):
        # Replaces the default attention forward; defers to the registered
        # 'attention_fn' hook if another mixin provides one.
        attention_fn = standard_attention
        if 'attention_fn' in self.transformer.hooks:
            attention_fn = self.transformer.hooks['attention_fn']
        layer = self.transformer.layers[layer_id].attention
        lora_layer = self.lora_linear[layer_id]

        # Base Q/K/V projections from the frozen fused weight.
        mixed_raw_layer = layer.query_key_value(hidden_states)
        (mixed_query_layer,
         mixed_key_layer,
         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)

        # Add the low-rank updates: x A^T B^T * (alpha / r).
        if layer_id in self.layer_range:
            mixed_query_layer = mixed_query_layer + (self.lora_dropout(hidden_states) @ lora_layer["Q_A"].T @ lora_layer["Q_B"].T) * self.scaling
            mixed_key_layer = mixed_key_layer + (self.lora_dropout(hidden_states) @ lora_layer["K_A"].T @ lora_layer["K_B"].T) * self.scaling
            mixed_value_layer = mixed_value_layer + (self.lora_dropout(hidden_states) @ lora_layer["V_A"].T @ lora_layer["V_B"].T) * self.scaling

        dropout_fn = layer.attention_dropout if self.training else None

        query_layer = layer._transpose_for_scores(mixed_query_layer)
        key_layer = layer._transpose_for_scores(mixed_key_layer)
        value_layer = layer._transpose_for_scores(mixed_value_layer)

        context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (layer.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # Output projection, again with its own low-rank update.
        output = layer.dense(context_layer)
        if layer_id in self.layer_range:
            output = output + (self.lora_dropout(context_layer) @ lora_layer["O_A"].T @ lora_layer["O_B"].T) * self.scaling

        if self.training:
            output = layer.output_dropout(output)
        return output
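
Usage note (not part of the file above): a minimal sketch of how such a mixin is typically attached to a SAT model via BaseModel.add_mixin. The helper name `add_lora`, the argument names `args.hidden_size` / `args.num_layers`, and the rank/alpha/dropout values are illustrative assumptions, not code from this repository.

from sat.model.base_model import BaseModel
from sat.model.finetune.lora import LoRAMixin

def add_lora(model: BaseModel, args):
    # Freeze the backbone first so only the LoRA A/B matrices stay trainable.
    model.requires_grad_(False)
    lora = LoRAMixin(
        hidden_size=args.hidden_size,
        layer_num=args.num_layers,
        r=8,              # rank must be > 0 (the class default of 0 would divide by zero)
        lora_alpha=16,    # effective scaling is lora_alpha / r
        lora_dropout=0.1,
    )
    # Registering the mixin makes its attention_forward hook replace the
    # default attention forward for every layer.
    model.add_mixin('lora', lora)
    return model

Freezing before `add_mixin` matters: the LoRA parameters are created inside the mixin and keep `requires_grad=True`, so only they receive gradients during fine-tuning.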