optimum/neuron/peft/tuners/lora/layer.py

# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any, Union

import torch
from torch import nn

from ....utils.import_utils import is_neuronx_distributed_available, is_peft_available


if is_peft_available():
    from peft.tuners.lora import Embedding as LoraEmbedding
    from peft.tuners.lora import Linear as LoraLinear
    from peft.tuners.lora import LoraLayer
    from peft.utils.integrations import gather_params_ctx
else:

    class LoraLinear:
        pass

    class LoraEmbedding:
        pass

    class LoraLayer:
        pass

    def gather_params_ctx(param):
        pass


if is_neuronx_distributed_available():
    from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear as NxDGQAQKVColumnParallelLinear
    from neuronx_distributed.parallel_layers.layers import (
        BaseParallelLinear,
        ColumnParallelLinear,
        RowParallelLinear,
    )
    from neuronx_distributed.parallel_layers.layers import ParallelEmbedding as NxDParallelEmbedding
    from neuronx_distributed.parallel_layers.mappings import scatter_to_sequence_parallel_region
else:

    class NxDParallelEmbedding:
        def __init__(self, *args, **kwargs):
            pass

    class BaseParallelLinear:
        def __init__(self, *args, **kwargs):
            pass

    class ColumnParallelLinear:
        def __init__(self, *args, **kwargs):
            pass

    class RowParallelLinear:
        def __init__(self, *args, **kwargs):
            pass

    class NxDGQAQKVColumnParallelLinear:
        def __init__(self, *args, **kwargs):
            pass

    def scatter_to_sequence_parallel_region(*args, **kwargs):
        pass


class NeuronLoraLayer(LoraLayer):
    def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None:
        self.base_layer = base_layer
        self.r = {}
        self.lora_alpha = {}
        self.scaling = {}
        self.lora_dropout = nn.ModuleDict({})
        self.lora_A = nn.ModuleDict({})
        self.lora_B = nn.ModuleDict({})
        # For Embedding layer
        self.lora_embedding_A = nn.ParameterDict({})
        self.lora_embedding_B = nn.ParameterDict({})
        # Mark the weight as unmerged
        self._disable_adapters = False
        self.merged_adapters = []
        self.use_dora: dict[str, bool] = {}
        self.lora_bias: dict[str, bool] = {}
        self.lora_magnitude_vector = torch.nn.ModuleDict()  # for DoRA
        self._caches: dict[str, Any] = {}
        self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload
        self.kwargs = kwargs

        base_layer = self.get_base_layer()
        if isinstance(base_layer, NxDGQAQKVColumnParallelLinear):
            in_features, out_features = base_layer.input_size, base_layer.output_sizes
        elif isinstance(base_layer, BaseParallelLinear):
            in_features, out_features = base_layer.input_size, base_layer.output_size
        elif isinstance(base_layer, nn.Conv2d):
            raise NotImplementedError("Conv2d is not supported for LoRA with optimum-neuron.")
        elif isinstance(base_layer, nn.Conv3d):
            raise NotImplementedError("Conv3d is not supported for LoRA with optimum-neuron.")
        elif isinstance(base_layer, NxDParallelEmbedding):
            in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
        elif isinstance(base_layer, nn.Conv1d):
            raise NotImplementedError("Conv1d is not supported for LoRA with optimum-neuron.")
        else:
            raise NotImplementedError(
                f"LoRA is not supported for {base_layer.__class__.__name__} with optimum-neuron."
            )

        self.in_features = in_features
        self.out_features = out_features

    def update_layer(
        self,
        adapter_name,
        r,
        lora_alpha,
        lora_dropout,
        init_lora_weights,
        use_rslora,
        use_dora: bool = False,
        lora_bias: bool = False,
    ):
        # This code works for linear layers, override for other layer types
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()
        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))

        # Actual trainable parameters
        # There are two cases:
        #   1. The base linear layer is a RowParallelLinear, then:
        #      - The lora A matrix needs to be a RowParallelLinear as well,
        #      - The lora B matrix does not need to be parallelized.
        #   2. The base linear layer is a ColumnParallelLinear, then:
        #      - The lora A matrix does not need to be parallelized,
        #      - The lora B matrix needs to be a ColumnParallelLinear as well.
        if isinstance(self.base_layer, RowParallelLinear):
            self.lora_A[adapter_name] = RowParallelLinear(
                self.in_features,
                r,
                bias=False,
                input_is_parallel=self.base_layer.input_is_parallel,
                sequence_parallel_enabled=self.base_layer.sequence_parallel_enabled,
                sequence_dimension=self.base_layer.sequence_dimension,
            )
            self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias)
        else:
            self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
            self.lora_B[adapter_name] = ColumnParallelLinear(
                r,
                self.out_features,
                bias=lora_bias,
                gather_output=self.base_layer.gather_output,
                sequence_parallel_enabled=self.base_layer.sequence_parallel_enabled,
                sequence_dimension=self.base_layer.sequence_dimension,
            )
        self.lora_bias[adapter_name] = lora_bias

        if use_rslora:
            self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
        else:
            self.scaling[adapter_name] = lora_alpha / r

        # for inits that require access to the base weight, use gather_params_ctx so that the weight is
        # gathered when using DeepSpeed
        if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
            with gather_params_ctx(self.get_base_layer().weight):
                self.pissa_init(adapter_name, init_lora_weights)
        elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
            with gather_params_ctx(self.get_base_layer().weight):
                self.olora_init(adapter_name)
        elif init_lora_weights == "loftq":
            with gather_params_ctx(self.get_base_layer().weight):
                self.loftq_init(adapter_name)
        elif init_lora_weights == "eva":
            nn.init.zeros_(self.lora_B[adapter_name].weight)
        elif init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)
        # call this before dora_init
        self._move_adapter_to_device_of_base_layer(adapter_name)

        if use_dora:
            self.dora_init(adapter_name)
            self.use_dora[adapter_name] = True
        else:
            self.use_dora[adapter_name] = False

        self.set_adapter(self.active_adapters)


class ParallelLinear(nn.Module, NeuronLoraLayer):
    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        is_target_conv_1d_layer: bool = False,
        init_lora_weights: Union[bool, str] = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        lora_bias: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        NeuronLoraLayer.__init__(self, base_layer, **kwargs)
        self.fan_in_fan_out = fan_in_fan_out
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
            lora_bias=lora_bias,
        )
        self.is_target_conv_1d_layer = is_target_conv_1d_layer

    merge = LoraLinear.merge
    unmerge = LoraLinear.unmerge
    get_delta_weight = LoraLinear.get_delta_weight
    forward = LoraLinear.forward

    def __repr__(self):
        rep = super().__repr__()
        return "lora." + rep


class GQAQKVColumnParallelLinear(nn.Module, NeuronLoraLayer):
    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        is_target_conv_1d_layer: bool = False,
        init_lora_weights: Union[bool, str] = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        lora_bias: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        NeuronLoraLayer.__init__(self, base_layer, **kwargs)
        self.fan_in_fan_out = fan_in_fan_out
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
            lora_bias=lora_bias,
        )
        self.is_target_conv_1d_layer = is_target_conv_1d_layer

    def update_layer(
        self,
        adapter_name,
        r,
        lora_alpha,
        lora_dropout,
        init_lora_weights,
        use_rslora,
        use_dora: bool = False,
        lora_bias: bool = False,
    ):
        # This code works for linear layers, override for other layer types
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()
        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))

        self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
        self.lora_B[adapter_name] = NxDGQAQKVColumnParallelLinear(
            input_size=r,
            output_sizes=self.out_features,
            bias=False,
            gather_output=self.base_layer.gather_output,
            dtype=self.base_layer.dtype,
            init_method=self.base_layer.arg_init_method,
            kv_size_multiplier=self.base_layer.kv_size_multiplier,
            sequence_parallel_enabled=self.base_layer.sequence_parallel_enabled,
            fuse_qkv=self.base_layer.fuse_qkv,
        )
        self.lora_bias[adapter_name] = lora_bias

        if use_rslora:
            self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
        else:
            self.scaling[adapter_name] = lora_alpha / r

        # for inits that require access to the base weight, use gather_params_ctx so that the weight is
        # gathered when using DeepSpeed
        if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
            with gather_params_ctx(self.get_base_layer().weight):
                self.pissa_init(adapter_name, init_lora_weights)
        elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora":
            with gather_params_ctx(self.get_base_layer().weight):
                self.olora_init(adapter_name)
        elif init_lora_weights == "loftq":
            with gather_params_ctx(self.get_base_layer().weight):
                self.loftq_init(adapter_name)
        elif init_lora_weights == "eva":
            nn.init.zeros_(self.lora_B[adapter_name].weight)
        elif init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)
        # call this before dora_init
        self._move_adapter_to_device_of_base_layer(adapter_name)

        if use_dora:
            self.dora_init(adapter_name)
            self.use_dora[adapter_name] = True
        else:
            self.use_dora[adapter_name] = False
        self.set_adapter(self.active_adapters)

    def reset_lora_parameters(self, adapter_name, init_lora_weights):
        if init_lora_weights is False:
            return

        if adapter_name in self.lora_A.keys():
            if init_lora_weights is True:
                # initialize A the same way as the default for nn.Linear and B to zero
                # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
                nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
            elif init_lora_weights.lower() == "gaussian":
                nn.init.normal_(self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name])
            else:
                raise ValueError(f"Unknown initialization {init_lora_weights=}")
            if self.base_layer.fuse_qkv:
                nn.init.zeros_(self.lora_B[adapter_name].weight_qkv)
                if self.lora_bias[adapter_name]:
                    nn.init.zeros_(self.lora_B[adapter_name].bias_qkv)
            else:
                nn.init.zeros_(self.lora_B[adapter_name].weight_q)
                nn.init.zeros_(self.lora_B[adapter_name].weight_k)
                nn.init.zeros_(self.lora_B[adapter_name].weight_v)
                if self.lora_bias[adapter_name]:
                    nn.init.zeros_(self.lora_B[adapter_name].bias_q)
                    nn.init.zeros_(self.lora_B[adapter_name].bias_k)
                    nn.init.zeros_(self.lora_B[adapter_name].bias_v)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        previous_dtype = x.dtype
        output_q, output_k, output_v = self.base_layer(x, *args, **kwargs)
        if not self.merged:
            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                lora_dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
                x = x.to(lora_A.weight.dtype)
                dropout_input = lora_A(lora_dropout(x))
                lora_q_output, lora_k_output, lora_v_output = lora_B(dropout_input)
                output_q += lora_q_output * scaling
                output_k += lora_k_output * scaling
                output_v += lora_v_output * scaling
        return output_q.to(previous_dtype), output_k.to(previous_dtype), output_v.to(previous_dtype)

    def __repr__(self):
        rep = super().__repr__()
        return "lora." + rep


class ParallelEmbedding(nn.Module, NeuronLoraLayer):
    def __init__(
        self,
        base_layer: nn.Module,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        init_lora_weights: Union[bool, str] = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        lora_bias: bool = False,
        **kwargs,
    ) -> None:
        if lora_bias:
            # lora_bias=True is not supported (yet) for embedding layers, as they use nn.Parameter
            raise ValueError(f"lora_bias={lora_bias} is not supported for {self.__class__.__name__}.")

        super().__init__()
        NeuronLoraLayer.__init__(self, base_layer)
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
            lora_bias=lora_bias,
        )

    update_layer = LoraEmbedding.update_layer
    dora_init = LoraEmbedding.dora_init
    merge = LoraEmbedding.merge
    unmerge = LoraEmbedding.unmerge
    get_delta_weight = LoraEmbedding.get_delta_weight
    _mixed_batch_forward = LoraEmbedding._mixed_batch_forward
    _embed = LoraEmbedding._embed

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        # TODO: no dtype conversion here, unlike in Linear, is that correct?
        self._check_forward_args(x, *args, **kwargs)
        adapter_names = kwargs.pop("adapter_names", None)

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif adapter_names is not None:
            result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)

            # If sequence parallelism is enabled, we need to scatter the input to the sequence parallel region.
            sequence_parallel_enabled = self.get_base_layer().sequence_parallel_enabled
            sequence_dimension = self.get_base_layer().sequence_dim
            if sequence_dimension is None:
                sequence_dimension = 0
            if sequence_parallel_enabled:
                if sequence_dimension == 0:
                    x = x.transpose(0, 1).contiguous()
                x = scatter_to_sequence_parallel_region(x, sequence_dimension=sequence_dimension)

            torch_result_dtype = result.dtype
            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_embedding_A:
                    continue
                embedding_A = self.lora_embedding_A[active_adapter].T
                embedding_B = self.lora_embedding_B[active_adapter].T
                scaling = self.scaling[active_adapter]

                if not self.use_dora[active_adapter]:
                    after_A = self._embed(x, embedding_A)
                    result = result + (after_A @ embedding_B) * scaling
                else:
                    mag_norm_scale, dora_result = self.lora_magnitude_vector[active_adapter](
                        x,
                        lora_A=embedding_A,
                        lora_B=embedding_B,
                        scaling=scaling,
                        base_layer=self.get_base_layer(),
                        embed_fn=self._embed,
                    )
                    result = mag_norm_scale * result + dora_result
            result = result.to(torch_result_dtype)

        return result

    def __repr__(self):
        rep = super().__repr__()
        return "lora." + rep


NEURON_LORA_MODULES = {
    NxDParallelEmbedding: ParallelEmbedding,
    ColumnParallelLinear: ParallelLinear,
    RowParallelLinear: ParallelLinear,
    NxDGQAQKVColumnParallelLinear: GQAQKVColumnParallelLinear,
}
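

# --- Illustrative usage sketch (not part of the original module) ---------------------
# The NEURON_LORA_MODULES mapping above associates each neuronx_distributed base layer
# class with its LoRA wrapper. The minimal dispatch helper below is only a sketch of how
# such a mapping could be consumed; the name `_example_wrap_with_lora` and the default
# hyper-parameters are illustrative assumptions, not part of optimum-neuron's actual
# PEFT dispatch logic.
def _example_wrap_with_lora(base_layer: nn.Module, adapter_name: str = "default", r: int = 8) -> nn.Module:
    # Exact type lookup avoids ambiguity between parallel-linear subclasses.
    lora_cls = NEURON_LORA_MODULES.get(type(base_layer))
    if lora_cls is None:
        # Layer types without an entry in the mapping are returned unchanged.
        return base_layer
    # Every wrapper takes the base layer, the adapter name, and the usual LoRA hyper-parameters.
    return lora_cls(base_layer, adapter_name, r=r, lora_alpha=2 * r)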