modules/SwissArmyTransformer/sat/model/finetune/lora2.py (221 lines of code):

""" In this mixin, I use a different implementation than sat/model/finetune/lora.py I just use a fake linear layer to replace any model with lora mixin. """ import torch import torch.nn as nn from sat.model.base_model import BaseMixin import math from sat.helpers import print_all, print_rank0 from sat.model.transformer import RowParallelLinear, ColumnParallelLinear from sat.mpu.layers import copy_to_model_parallel_region from sat import mpu class HackLinear(nn.Linear): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): if prefix + 'weight' in state_dict: self.weight.data.copy_(state_dict[prefix+'weight']) if prefix + 'bias' in state_dict: self.bias.data.copy_(state_dict[prefix+'bias']) class HackRowParallelLinear(RowParallelLinear): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): if prefix + 'weight' in state_dict: self.weight.data.copy_(state_dict[prefix+'weight']) if prefix + 'bias' in state_dict: self.bias.data.copy_(state_dict[prefix+'bias']) class HackColumnParallelLinear(ColumnParallelLinear): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): if prefix + 'weight' in state_dict: self.weight.data.copy_(state_dict[prefix+'weight']) if prefix + 'bias' in state_dict: self.bias.data.copy_(state_dict[prefix+'bias']) try: from bitsandbytes.nn import LinearNF4 from bitsandbytes.functional import QuantState class HackLinearNF4(LinearNF4): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): if prefix + 'weight' in state_dict: self.weight.data.copy_(state_dict[prefix+'weight']) if self.weight.data.dtype == torch.uint8: quant_dict = {k[len(prefix+'weight.'):]: v for k,v in state_dict.items() if k.startswith(prefix+'weight.')} self.weight.quant_state = QuantState.from_dict(quant_dict, device=self.weight.data.device) if prefix + 'bias' in state_dict: self.bias.data.copy_(state_dict[prefix+'bias']) except Exception as exception: print_all("Failed to load bitsandbytes:" + str(exception), level='WARNING') class HackParameterList(nn.ParameterList): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): for i in range(len(self)): if prefix + str(i) in state_dict: self[i].data.copy_(state_dict[prefix+str(i)]) map_cls = { nn.Linear: (HackLinear, {}), ColumnParallelLinear: (HackColumnParallelLinear, {'gather_output': False}), RowParallelLinear: (HackRowParallelLinear, {'input_is_parallel': True}) } class LoraLinear(nn.Module): def __init__(self, original_cls, partition, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., qlora=False, original_obj=None): super().__init__() assert original_obj is not None, "original linear object must be given!" if lora_dropout and lora_dropout > 0: self.lora_dropout = nn.Dropout(p=lora_dropout) else: self.lora_dropout = lambda x: x self.r = r self.lora_alpha = lora_alpha self.scaling = self.lora_alpha / self.r bias = original_obj.bias is not None dtype = original_obj.weight.dtype if qlora: try: self.original = HackLinearNF4(in_dim, out_dim, bias=bias) except: raise Exception('Build 4bit layer failed. You need to install the latest bitsandbytes. Try `pip install bitsandbytes`. 
class LoraLinear(nn.Module):
    def __init__(self, original_cls, partition, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., qlora=False, original_obj=None):
        super().__init__()
        assert original_obj is not None, "original linear object must be given!"
        if lora_dropout and lora_dropout > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        bias = original_obj.bias is not None
        dtype = original_obj.weight.dtype
        if qlora:
            try:
                self.original = HackLinearNF4(in_dim, out_dim, bias=bias)
            except:
                raise Exception('Build 4bit layer failed. You need to install the latest bitsandbytes. Try `pip install bitsandbytes`. If you still meet error after installation, try running `from bitsandbytes.nn import LinearNF4` with python and fix the error.')
        else:
            base_cls, kwargs = map_cls[original_cls]
            if original_cls is ColumnParallelLinear:
                kwargs['stride'] = partition
                kwargs['skip_init'] = True
                kwargs['params_dtype'] = dtype
            elif original_cls is RowParallelLinear:
                kwargs['final_bias'] = original_obj.final_bias
                kwargs['skip_init'] = True
                kwargs['params_dtype'] = dtype
            else:
                kwargs['dtype'] = dtype
            self.original = base_cls(in_dim, out_dim, **kwargs, bias=bias)
            self.original.weight.data.copy_(original_obj.weight.data.detach().clone())
            if bias:
                self.original.bias.data.copy_(original_obj.bias.data.detach().clone())
        if type(partition) is int:
            self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, original_obj.weight.shape[1]), dtype=dtype)) for _ in range(partition)])
            self.matrix_B = HackParameterList([nn.Parameter(torch.empty((original_obj.weight.shape[0] // partition, r), dtype=dtype)) for _ in range(partition)])
            for i in range(partition):
                nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
                nn.init.zeros_(self.matrix_B[i])
                self.matrix_B[i].model_parallel = True
                self.matrix_B[i].tensor_model_parallel = True
        else:
            new_sizes = [original_obj.weight.shape[0] // sum(partition) * i for i in partition]
            self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, original_obj.weight.shape[1]), dtype=dtype)) for _ in partition])
            self.matrix_B = HackParameterList([nn.Parameter(torch.empty((sz, r), dtype=dtype)) for sz in new_sizes])
            for i in range(len(partition)):
                nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
                nn.init.zeros_(self.matrix_B[i])
                self.matrix_B[i].model_parallel = True
                self.matrix_B[i].tensor_model_parallel = True
        self.partition = partition

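    # Shape sketch: for each partition p, matrix_A[p] is (r, in_dim) and
    # matrix_B[p] is (out_dim_p, r), with the out_dim_p summing to out_dim.
    # The layer computes original(x) + concat_p[(dropout(x) @ A_p.T) @ B_p.T] * (lora_alpha / r),
    # and since every B_p starts at zero, training begins exactly at the
    # original layer's behaviour.
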
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # This is not a perfect version, because it doesn't handle errors and unexpected keys.
        if prefix + 'weight' in state_dict:
            # load from normal Linear
            self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        else:
            # load from LoraLinear
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        mixed_raw_layer = self.original(x)
        x = self.lora_dropout(x)
        lora_outputs = []
        for mA, mB in zip(self.matrix_A, self.matrix_B):
            lora_outputs.append((copy_to_model_parallel_region(x @ mA.T) @ mB.T) * self.scaling)
        mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1)
        return mixed_raw_layer


def replace_linear_with_lora(lin, partition, r, *args, **kw_args):
    if kw_args.get('in_size', None) is not None:
        in_size = kw_args.pop('in_size')
        out_size = kw_args.pop('out_size')
        if out_size is None:
            out_size = in_size * partition
        out_dim, in_dim = out_size, in_size
    else:
        out_dim, in_dim = lin.weight.shape
    original_cls = type(lin)
    new_layer = LoraLinear(original_cls, partition, in_dim, out_dim, r, *args, **kw_args, original_obj=lin)
    device = lin.weight.device
    del lin
    return new_layer.to(device)


def merge_linear_lora(lin):
    if lin.original.weight.data.dtype is not torch.uint8:
        weight = lin.original.weight
        out_dim, in_dim = weight.shape
        new_lin = nn.Linear(in_dim, out_dim, dtype=weight.data.dtype, bias=lin.original.bias is not None)
    else:
        import bitsandbytes.functional as F
        weight = F.dequantize_fp4(lin.original.weight.data, lin.original.weight.quant_state).to(lin.original.bias.data.dtype)
        out_dim, in_dim = weight.shape
        new_lin = HackLinearNF4(in_dim, out_dim, bias=lin.original.bias is not None)
    if lin.original.bias is not None:
        new_lin.bias.data = lin.original.bias.data
    new_qkv = []
    for mA, mB in zip(lin.matrix_A, lin.matrix_B):
        new_qkv.append(mB.data.float() @ mA.data.float() * lin.scaling)
    new_qkv = torch.cat(new_qkv, -2)
    guess_type = lin.original.bias.data.dtype if lin.original.bias is not None else lin.original.weight.data.dtype
    if guess_type is torch.uint8:
        guess_type = torch.float32
    new_lin.weight.data = (weight + new_qkv).to(guess_type)
    return new_lin.cuda() if torch.cuda.is_available() else new_lin

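# Illustrative sketch (not executed here; assumes a plain nn.Linear and an
# initialized sat model-parallel environment for the forward pass):
#
#     lin = nn.Linear(16, 32)
#     lora_lin = replace_linear_with_lora(lin, 1, 4, lora_alpha=8)
#     ... train only lora_lin.matrix_A / lora_lin.matrix_B ...
#     merged = merge_linear_lora(lora_lin)  # dense layer with W + scaling * B @ A folded in
#
# After merging, inference needs no LoRA-specific code at all.
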
class LoraMixin(BaseMixin):
    def __init__(self,
                 layer_num,
                 r: int = 0,
                 lora_alpha: int = 1,
                 lora_dropout: float = 0.,
                 layer_range=None,
                 qlora=False,
                 cross_attention=True):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout

        if layer_range is None:
            layer_range = [i for i in range(layer_num)]
        self.layer_range = layer_range

        self.scaling = self.lora_alpha / self.r
        self.qlora = qlora
        self.cross_attention = cross_attention

    def reinit(self, parent_model):
        for i in self.layer_range:
            print_rank0(f'replacing layer {i} attention with lora')
            parent_model.transformer.layers[i].attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora, in_size=parent_model.transformer.hidden_size, out_size=None)
            parent_model.transformer.layers[i].attention.query_key_value = replace_linear_with_lora(parent_model.transformer.layers[i].attention.query_key_value, parent_model.transformer.layers[i].attention.stride, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora, in_size=parent_model.transformer.hidden_size, out_size=None if not parent_model.transformer.num_multi_query_heads else parent_model.transformer.layers[i].attention.inner_hidden_size + parent_model.transformer.layers[i].attention.hidden_size_per_attention_head * parent_model.transformer.layers[i].attention.num_multi_query_heads * 2)
            if self.cross_attention and parent_model.transformer.layers[i].is_decoder:
                print_rank0(f'replacing layer {i} cross attention with lora')
                kv_size = parent_model.transformer.layers[i].cross_attention.inner_hidden_size * 2 if not parent_model.transformer.cross_num_multi_query_heads else parent_model.transformer.layers[i].cross_attention.hidden_size_per_attention_head * parent_model.transformer.layers[i].cross_attention.cross_num_multi_query_heads * 2
                parent_model.transformer.layers[i].cross_attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora, in_size=parent_model.transformer.layers[i].cross_attention.inner_hidden_size, out_size=parent_model.transformer.hidden_size)
                parent_model.transformer.layers[i].cross_attention.query = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.query, 1, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora, in_size=parent_model.transformer.hidden_size, out_size=parent_model.transformer.layers[i].cross_attention.inner_hidden_size)
                parent_model.transformer.layers[i].cross_attention.key_value = replace_linear_with_lora(parent_model.transformer.layers[i].cross_attention.key_value, 2, self.r, self.lora_alpha, self.lora_dropout, qlora=self.qlora, in_size=parent_model.transformer.layers[i].cross_attention.cross_attn_hidden_size, out_size=kv_size)
        if self.qlora:
            print_rank0('replacing chatglm linear layer with 4bit')
            def replace_linear_with_nf4(model, name=None, cache={}):
                if type(model) in (nn.Linear, RowParallelLinear, ColumnParallelLinear):
                    out_dim, in_dim = model.weight.shape
                    bias = model.bias is not None
                    new_linear = HackLinearNF4(in_dim, out_dim, bias=bias)
                    new_linear.weight.data.copy_(model.weight.data.detach().clone())
                    if bias:
                        new_linear.bias.data.copy_(model.bias.data.detach().clone())
                    return new_linear
                names = set()
                for name, child in model.named_children():
                    if name not in names:
                        if child in cache:
                            new_child = cache[child]
                        else:
                            new_child = replace_linear_with_nf4(child, name=name, cache=cache)
                            cache[child] = new_child
                        setattr(model, name, new_child)
                        names.add(name)
                flag = True
                while flag:
                    flag = False
                    for name, child in model.named_children():
                        if name not in names:
                            setattr(model, name, cache[child])
                            names.add(name)
                            flag = True
                return model
            replace_linear_with_nf4(parent_model.transformer, None, {})

    def merge_lora(self):
        for i in self.layer_range:
            print_rank0(f'merge layer {i} lora attention back to linear')
            self.transformer.layers[i].attention.dense = merge_linear_lora(self.transformer.layers[i].attention.dense)
            self.transformer.layers[i].attention.query_key_value = merge_linear_lora(self.transformer.layers[i].attention.query_key_value)
            if self.cross_attention and self.transformer.layers[i].is_decoder:
                print_rank0(f'merge layer {i} lora cross attention back to linear')
                self.transformer.layers[i].cross_attention.dense = merge_linear_lora(self.transformer.layers[i].cross_attention.dense)
                self.transformer.layers[i].cross_attention.query = merge_linear_lora(self.transformer.layers[i].cross_attention.query)
                self.transformer.layers[i].cross_attention.key_value = merge_linear_lora(self.transformer.layers[i].cross_attention.key_value)
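
# A hedged usage sketch only: `model` and `args` are placeholders, and
# `add_mixin(..., reinit=True)` / `model.mixins` are assumed to behave as in
# the rest of sat (reinit=True calls reinit() so the attention linears get
# wrapped by LoraLinear):
#
#     model.add_mixin('lora', LoraMixin(args.num_layers, r=8, lora_alpha=16, lora_dropout=0.1), reinit=True)
#     # train only the LoRA matrices
#     for name, param in model.named_parameters():
#         param.requires_grad_('matrix_A' in name or 'matrix_B' in name)
#     ...
#     model.mixins['lora'].merge_lora()  # fold the updates back before inference/export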