chatlearn/models/megatron/lora/layers.py (586 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""lora layers."""

import math
import importlib.util

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Embedding
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from chatlearn.models.megatron.lora.initializer import distributed_kaiming_uniform_
from chatlearn.models.megatron.lora.utils import recursive_getattr, recursive_setattr
from chatlearn.utils.arguments import LoraConfig
from chatlearn.utils.constant import LORA_WEIGHT_PREFIX
from chatlearn.utils.constant import QKV_LAYER_NAME
from chatlearn.utils.global_vars import get_args as get_runtime_args
from chatlearn.utils.global_vars import is_initialized

megatron_exist = importlib.util.find_spec("megatron")
if megatron_exist:
    from chatlearn.utils.megatron_import_helper import get_args
    from chatlearn.utils.megatron_import_helper import mpu
    from chatlearn.utils.megatron_import_helper import Float16Module
    from chatlearn.utils.megatron_import_helper import MegatronOptimizer
    from chatlearn.utils.megatron_import_helper import unwrap_model
    from chatlearn.utils.megatron_import_helper import (
        get_tensor_model_parallel_rank,
        get_tensor_model_parallel_world_size
    )
    from chatlearn.utils.megatron_import_helper import (  # pylint: disable=unused-import
        ColumnParallelLinear,
        linear_with_frozen_weight,
        linear_with_grad_accumulation_and_async_allreduce,
        LinearWithGradAccumulationAndAsyncCommunication,
        RowParallelLinear,
        VocabParallelEmbedding
    )
    from chatlearn.utils.megatron_import_helper import (
        copy_to_tensor_model_parallel_region,
        gather_from_tensor_model_parallel_region,
        reduce_from_tensor_model_parallel_region,
        scatter_to_tensor_model_parallel_region,
        reduce_scatter_to_sequence_parallel_region
    )
    from chatlearn.utils.megatron_import_helper import VocabUtility


class LoraBase(torch.nn.Module):  # pylint: disable=abstract-method
    """Lora Base

    Holds the logic shared by every LoRA wrapper in this module: creation of
    the low-rank factors, train/eval switching, fusing/unfusing the low-rank
    update into the frozen weight, and ``state_dict`` filtering.

    Subclasses call :meth:`_init_lora` from ``__init__`` and implement
    ``forward``.
    """

    # When True the factors are stored as ``right: (columns, lora_dim)`` and
    # ``left: (lora_dim, rows)`` (embedding layers); otherwise as
    # ``right: (lora_dim, columns)`` and ``left: (rows, lora_dim)``.
    _lora_transposed = False

    def _init_lora(self, weight, bias, lora_dim, lora_scaling, lora_dropout, shape=None):
        """Attach the frozen weight/bias and create the trainable LoRA factors.

        Arguments:
            weight: weight of the wrapped module; its gradient is disabled.
            bias: bias of the wrapped module (may be None).
            lora_dim: lora rank dim, must be positive.
            lora_scaling: lora scaling value; stored divided by ``lora_dim``.
            lora_dropout: dropout probability; ``nn.Identity`` is used when it
                is not positive.
            shape: optional ``(rows, columns)`` overriding ``weight.shape``.

        Raises:
            ValueError: if ``lora_dim`` is not positive.
        """
        self.weight = weight
        self.bias = bias

        if lora_dim <= 0:
            raise ValueError(
                "You are trying to use LoRA, whose reduced dim should be larger than 0"
            )

        rows, columns = weight.shape if shape is None else shape
        self.fan_in = columns
        if self._lora_transposed:
            right_shape, left_shape = (columns, lora_dim), (lora_dim, rows)
        else:
            right_shape, left_shape = (lora_dim, columns), (rows, lora_dim)
        # Registration order (right, then left) is kept stable so that
        # named_parameters()/state_dict() ordering does not change.
        self.lora_right_weight = nn.Parameter(torch.zeros(*right_shape))
        self.lora_left_weight = nn.Parameter(torch.zeros(*left_shape))
        self.lora_scaling = lora_scaling / lora_dim

        if lora_dropout > 0:
            self.lora_dropout = nn.Dropout(lora_dropout)
        else:
            self.lora_dropout = nn.Identity()

        self.reset_parameters()
        # disable the original weight gradient
        self.weight.requires_grad = False
        # fuse LoRA to the original weight
        self.fuse_lora = False

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        """Return the state dict, without the LoRA factors once they are fused.

        State-dict keys are ``prefix + parameter_name``, so the LoRA keys are
        matched against ``prefix + LORA_WEIGHT_PREFIX`` rather than the bare
        ``LORA_WEIGHT_PREFIX``.

        NOTE: the filtered mapping is a new ``dict``; ``destination`` itself is
        not modified.
        """
        sd = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        if self.fuse_lora:
            lora_prefix = prefix + LORA_WEIGHT_PREFIX
            sd = {key: value for key, value in sd.items() if not key.startswith(lora_prefix)}
        return sd

    def train(self, mode=True):
        """Switch the dropout module to training/eval mode and return ``self``.

        Returning ``self`` and updating ``self.training`` keeps the usual
        ``nn.Module.train`` contract, so calls can be chained.
        """
        self.lora_dropout.train(mode)
        self.training = mode
        # self.unfuse_lora_weight()
        return self

    def eval(self):
        """Switch to evaluation mode and return ``self``."""
        # self.fuse_lora_weight()
        return self.train(False)

    def reset_parameters(self):
        """Initialise the right factor with kaiming-uniform and zero the left factor."""
        distributed_kaiming_uniform_(self.lora_right_weight, self.fan_in, a=math.sqrt(5))
        nn.init.zeros_(self.lora_left_weight)

    def _lora_delta(self):
        """Return the unscaled low-rank update, shaped like ``self.weight``."""
        if self._lora_transposed:
            return torch.matmul(self.lora_left_weight.t(), self.lora_right_weight.t())
        return torch.matmul(self.lora_left_weight, self.lora_right_weight)

    def fuse_lora_weight(self):
        """Add the scaled low-rank update into ``self.weight`` (no-op if already fused)."""
        if not self.fuse_lora:
            self.weight.data += self.lora_scaling * self._lora_delta()
        self.fuse_lora = True

    def unfuse_lora_weight(self):
        """Subtract the scaled low-rank update from ``self.weight`` (no-op if not fused)."""
        if self.fuse_lora:
            self.weight.data -= self.lora_scaling * self._lora_delta()
        self.fuse_lora = False


class ColumnParallelLinear_LoRA(LoraBase):
    """LoRA version of megatron.core.tensor_parallel.layers.ColumnParallelLinear.

    Arguments:
        weight: weight of original ColumnParallelLinear module.
        lora_dim: lora rank dim.
        lora_scaling: lora scaling value.
        lora_dropout: dropout probability applied to the LoRA branch.
        bias: bias of original ColumnParallelLinear module.
        kwargs: args of original ColumnParallelLinear module
            (``gather_output`` and ``skip_bias_add`` are read).
    """

    def __init__(self, weight, lora_dim=0, lora_scaling=1, lora_dropout=0, bias=None, **kwargs):
        super().__init__()

        # Keep input parameters
        self.gather_output = kwargs.get("gather_output", True)
        world_size = get_tensor_model_parallel_world_size()
        self.skip_bias_add = kwargs.get("skip_bias_add", False)

        args = get_args()
        self.async_tensor_model_parallel_allreduce = (
            args.async_tensor_model_parallel_allreduce and world_size > 1)
        self.sequence_parallel = (
            args.sequence_parallel and world_size > 1)
        assert not self.async_tensor_model_parallel_allreduce or \
            not self.sequence_parallel
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

        self._init_lora(weight, bias, lora_dim, lora_scaling, lora_dropout)

    def forward(self, input_):
        """Return ``(output, output_bias)``.

        ``output_bias`` is ``self.bias`` when ``skip_bias_add`` is set and
        ``None`` otherwise (the bias is then already added to ``output``).
        """
        bias = self.bias if not self.skip_bias_add else None

        if self.async_tensor_model_parallel_allreduce or \
                self.sequence_parallel:
            input_parallel = input_
        else:
            input_parallel = copy_to_tensor_model_parallel_region(input_)

        # Matrix multiply.
        output_parallel = linear_with_frozen_weight(
            input=input_parallel,
            weight=self.weight,
            bias=bias,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
            sequence_parallel=self.sequence_parallel,
        )
        # Once fused, self.weight already contains the low-rank update, so the
        # residual branch must be skipped to avoid applying LoRA twice.
        if not self.fuse_lora:
            residual = linear_with_grad_accumulation_and_async_allreduce(
                input=input_parallel,
                weight=self.lora_right_weight,
                bias=None,
                gradient_accumulation_fusion=self.gradient_accumulation_fusion,
                async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
                sequence_parallel=self.sequence_parallel,
            )
            residual = linear_with_grad_accumulation_and_async_allreduce(
                input=residual,
                weight=self.lora_left_weight,
                bias=None,
                gradient_accumulation_fusion=self.gradient_accumulation_fusion,
                async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
                sequence_parallel=False,
            )
            residual = self.lora_dropout(residual)
            output_parallel = output_parallel + self.lora_scaling * residual

        if self.gather_output:
            # All-gather across the partitions.
            assert not self.sequence_parallel
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


class RowParallelLinear_LoRA(LoraBase):
    """LoRA version of megatron.core.tensor_parallel.layers.RowParallelLinear.

    Arguments:
        weight: weight of original RowParallelLinear module.
        lora_dim: lora rank dim.
        lora_scaling: lora scaling value.
        lora_dropout: dropout probability applied to the LoRA branch.
        bias: bias of original RowParallelLinear module.
        kwargs: args of original RowParallelLinear module
            (``input_is_parallel`` and ``skip_bias_add`` are read).
    """

    def __init__(self, weight, lora_dim=0, lora_scaling=1, lora_dropout=0, bias=None, **kwargs):
        super().__init__()

        self.input_is_parallel = kwargs.get("input_is_parallel", False)
        self.skip_bias_add = kwargs.get("skip_bias_add", False)

        args = get_args()
        self.sequence_parallel = args.sequence_parallel
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

        self._init_lora(weight, bias, lora_dim, lora_scaling, lora_dropout)

    def forward(self, input_):
        """Return ``(output, output_bias)``.

        ``output_bias`` is ``self.bias`` when ``skip_bias_add`` is set and
        ``None`` otherwise (the bias, if any, is then added to ``output``).
        """
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            assert not self.sequence_parallel
            input_parallel = scatter_to_tensor_model_parallel_region(input_)

        # Matrix multiply.
        output_parallel = linear_with_frozen_weight(
            input=input_parallel,
            weight=self.weight,
            bias=None,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=False,
            sequence_parallel=False,
        )
        # Once fused, self.weight already contains the low-rank update, so the
        # residual branch must be skipped to avoid applying LoRA twice.
        if not self.fuse_lora:
            residual = linear_with_grad_accumulation_and_async_allreduce(
                input=input_parallel,
                weight=self.lora_right_weight,
                bias=None,
                gradient_accumulation_fusion=self.gradient_accumulation_fusion,
                async_grad_allreduce=False,
                sequence_parallel=False,
            )
            residual = linear_with_grad_accumulation_and_async_allreduce(
                input=residual,
                weight=self.lora_left_weight,
                bias=None,
                gradient_accumulation_fusion=self.gradient_accumulation_fusion,
                async_grad_allreduce=False,
                sequence_parallel=False,
            )
            residual = self.lora_dropout(residual)
            output_parallel = output_parallel + self.lora_scaling * residual

        # All-reduce across all the partitions.
        if self.sequence_parallel:
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias


class LinearLayer_LoRA(LoraBase):
    """LoRA version of torch.nn.Linear.

    Arguments:
        weight: weight of original torch.nn.Linear module.
        lora_dim: lora rank dim.
        lora_scaling: lora scaling value.
        lora_dropout: dropout probability applied to the LoRA branch input.
        bias: bias of original torch.nn.Linear module.
    """

    def __init__(self, weight, lora_dim=0, lora_scaling=1, lora_dropout=0, bias=None):
        super().__init__()

        try:
            # for zero stage 3
            rows, columns = weight.ds_shape
        except (AttributeError, TypeError, ValueError):
            # No usable ``ds_shape``: fall back to the regular tensor shape.
            rows, columns = weight.shape

        self._init_lora(weight, bias, lora_dim, lora_scaling, lora_dropout,
                        shape=(rows, columns))

    def forward(self, inputs):
        """Apply the linear layer, adding the LoRA branch unless it is fused."""
        if self.fuse_lora:
            return F.linear(inputs, self.weight, self.bias)
        else:
            return F.linear(
                inputs, self.weight,
                self.bias) + (self.lora_dropout(inputs) @ self.lora_right_weight.t()
                              @ self.lora_left_weight.t()) * self.lora_scaling


class VocabParallelEmbedding_LoRA(LoraBase):
    """LoRA version of megatron.core.tensor_parallel.layers.VocabParallelEmbedding.

    Arguments:
        weight: weight of original VocabParallelEmbedding module.
        lora_dim: lora rank dim.
        lora_scaling: lora scaling value.
        lora_dropout: dropout probability (the dropout module is created but
            not applied in ``forward``).
        bias: bias of original VocabParallelEmbedding module.
        kwargs: args of original VocabParallelEmbedding module.
    """

    _lora_transposed = True

    def __init__(self, weight, lora_dim=0, lora_scaling=1, lora_dropout=0, bias=None, **kwargs):
        super().__init__()
        # Set the defaults for compatibility.
        self.padding_idx = kwargs.get("padding_idx", None)
        self.max_norm = kwargs.get("max_norm", None)
        self.norm_type = kwargs.get("norm_type", 2.)
        self.scale_grad_by_freq = kwargs.get("scale_grad_by_freq", False)
        self.sparse = kwargs.get("sparse", False)
        self.num_embeddings = kwargs.get("num_embeddings")
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_tensor_model_parallel_rank(),
                self.tensor_model_parallel_size)

        # Allocate weights and initialize.
        self._init_lora(weight, bias, lora_dim, lora_scaling, lora_dropout)

    def forward(self, input_):
        """Look up embeddings for ``input_`` and reduce them across model-parallel ranks."""
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        if not self.fuse_lora:
            after_A = F.embedding(
                masked_input, self.lora_left_weight.T, self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            output_parallel += (after_A @ self.lora_right_weight.T) * self.lora_scaling

        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output


class Embedding_LoRA(LoraBase):
    """LoRA version of torch.nn.Embedding.

    Arguments:
        weight: weight of original torch.nn.Embedding module.
        lora_dim: lora rank dim.
        lora_scaling: lora scaling value.
        lora_dropout: dropout probability (the dropout module is created but
            not applied in ``forward``).
        bias: bias of original torch.nn.Embedding module.
        kwargs: args of original torch.nn.Embedding module.
    """

    _lora_transposed = True

    def __init__(self, weight, lora_dim=0, lora_scaling=1, lora_dropout=0, bias=None, **kwargs):
        super().__init__()
        # Set the defaults for compatibility.
        self.padding_idx = kwargs.get("padding_idx", None)
        self.max_norm = kwargs.get("max_norm", None)
        self.norm_type = kwargs.get("norm_type", 2.)
        self.scale_grad_by_freq = kwargs.get("scale_grad_by_freq", False)
        self.sparse = kwargs.get("sparse", False)
        self.num_embeddings = kwargs.get("num_embeddings")

        self._init_lora(weight, bias, lora_dim, lora_scaling, lora_dropout)
        self.weight.shared = True

    def forward(self, input_):
        """Look up embeddings for ``input_``, adding the LoRA branch unless it is fused."""
        output = F.embedding(
            input_, self.weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)
        if not self.fuse_lora:
            after_A = F.embedding(
                input_, self.lora_left_weight.T, self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            output += (after_A @ self.lora_right_weight.T) * self.lora_scaling
        return output


class MegatronOptimizer_LoRA(MegatronOptimizer):
    """
    MegatronOptimizer for LoRA
    """

    def allreduce_word_embedding_grads(self, args):
        """
        All-reduce word embedding grads.
        Reduce grads across first and last stages to ensure that word_embeddings
        parameters stay in sync. This should only run for models that support
        pipelined model parallelism (BERT and GPT-2).

        Unlike a plain all-reduce, the reduction is skipped when the shared
        embedding weight does not require grad.
        """
        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = self.models[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = self.models[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = self.models[0]

            if hasattr(unwrapped_model, "share_word_embeddings"):
                from chatlearn.utils.megatron_import_helper import DistributedDataParallel as LocalDDP  # pylint: disable=import-outside-toplevel
                unwrapped_model = unwrap_model(
                    unwrapped_model, (torchDDP, LocalDDP, Float16Module))
                if unwrapped_model.share_word_embeddings:
                    word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                    if word_embeddings_weight.requires_grad:
                        if args.DDP_impl == 'local':
                            grad = word_embeddings_weight.main_grad
                        else:
                            grad = word_embeddings_weight.grad
                        torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
            elif hasattr(unwrapped_model, "share_embeddings_and_output_weights"):
                unwrapped_model = unwrap_model(unwrapped_model)
                if unwrapped_model.share_embeddings_and_output_weights:
                    weight = unwrapped_model.shared_embedding_or_output_weight()
                    if weight.requires_grad:
                        grad = weight.main_grad
                        torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())


ALL_LORA_LAYER = (
    ColumnParallelLinear_LoRA,
    Embedding_LoRA,
    LinearLayer_LoRA,
    RowParallelLinear_LoRA,
    VocabParallelEmbedding_LoRA
)

LORA_LAYER_MAP = {
    "ColumnParallelLinear": ColumnParallelLinear_LoRA,
    "Embedding": Embedding_LoRA,
    "LinearLayer": LinearLayer_LoRA,
    "RowParallelLinear": RowParallelLinear_LoRA,
    "VocabParallelEmbedding": VocabParallelEmbedding_LoRA
}

# Attributes copied from the wrapped module into the LoRA layer's kwargs
# when the wrapped module defines them.
_FORWARDED_MODULE_ATTRS = (
    "input_is_parallel",
    "skip_bias_add",
    "gather_output",
    "input_size",
    "output_size",
    "padding_idx",
    "max_norm",
    "norm_type",
    "scale_grad_by_freq",
    "sparse",
    "num_embeddings",
)


# convert layer to LoRA
def convert_layer_to_lora(model,
                          part_module_name=None,
                          lora_dim=None,
                          lora_scaling=None,
                          lora_dropout=None,
                          lora_layer=None,
                          column_only_qkv=None):
    """Replace the supported layers of ``model`` with their LoRA versions, in place.

    Any argument left as ``None`` is taken from the runtime LoRA config when
    the global runtime is initialized, otherwise from ``LoraConfig``.

    Arguments:
        model: module whose sub-modules are converted.
        part_module_name: when not None, only modules whose name contains it
            are converted.
        lora_dim: lora rank dim; the model is returned untouched when ``<= 0``.
        lora_scaling: lora scaling value.
        lora_dropout: dropout probability of the LoRA branch.
        lora_layer: comma separated keys of ``LORA_LAYER_MAP`` to convert.
        column_only_qkv: when true, ColumnParallelLinear modules are converted
            only if their name contains an entry of ``QKV_LAYER_NAME``.

    Returns:
        ``model``. As a side effect
        ``MegatronOptimizer.allreduce_word_embedding_grads`` is replaced by the
        LoRA-aware implementation and only LoRA parameters stay trainable
        (see ``only_optimize_lora_parameters``).
    """
    if is_initialized():
        default_args = get_runtime_args().active_module_args.lora
    else:
        default_args = LoraConfig

    part_module_name = part_module_name if part_module_name is not None else default_args.part_module_name
    lora_dim = lora_dim if lora_dim is not None else default_args.lora_dim
    lora_scaling = lora_scaling if lora_scaling is not None else default_args.lora_scaling
    lora_dropout = lora_dropout if lora_dropout is not None else default_args.lora_dropout
    layers_to_convert = lora_layer if lora_layer is not None else default_args.lora_layer
    column_only_qkv = column_only_qkv if column_only_qkv is not None else default_args.column_only_qkv

    if lora_dim <= 0:
        return model

    # Tolerate whitespace around the commas ("A, B").
    layers_to_convert = [layer.strip() for layer in layers_to_convert.split(",")]
    assert all(layer in LORA_LAYER_MAP for layer in layers_to_convert), \
        "Unsupport layer to enable lora, {}. Only support {} for now.".format(layers_to_convert, ALL_LORA_LAYER)

    MegatronOptimizer.allreduce_word_embedding_grads = MegatronOptimizer_LoRA.allreduce_word_embedding_grads

    # Collect first, replace afterwards: modules must not be swapped while
    # named_modules() is being iterated.
    replace_name = {}
    for name, module in model.named_modules():
        if part_module_name is not None and part_module_name not in name:
            continue
        if isinstance(module, nn.Linear) and "LinearLayer" in layers_to_convert:
            replace_name[name] = LinearLayer_LoRA
        elif isinstance(module, RowParallelLinear) and "RowParallelLinear" in layers_to_convert:
            replace_name[name] = RowParallelLinear_LoRA
        elif isinstance(module, ColumnParallelLinear) and "ColumnParallelLinear" in layers_to_convert:
            # Skip layers whose name matches none of the QKV names. (Using
            # "any name is missing" would skip every layer as soon as
            # QKV_LAYER_NAME lists more than one alternative.)
            if column_only_qkv and not any(ele in name for ele in QKV_LAYER_NAME):
                continue
            replace_name[name] = ColumnParallelLinear_LoRA
        elif isinstance(module, VocabParallelEmbedding) and "VocabParallelEmbedding" in layers_to_convert:
            replace_name[name] = VocabParallelEmbedding_LoRA
        elif isinstance(module, Embedding) and "Embedding" in layers_to_convert:
            replace_name[name] = Embedding_LoRA

    for name, func in replace_name.items():
        module = recursive_getattr(model, name)
        kwargs = {
            attr: getattr(module, attr)
            for attr in _FORWARDED_MODULE_ATTRS
            if hasattr(module, attr)
        }
        tmp = func(
            module.weight, lora_dim, lora_scaling, lora_dropout,
            module.bias if hasattr(module, "bias") else None,
            **kwargs).to(module.weight.device).to(module.weight.dtype)
        recursive_setattr(model, name, tmp)

    only_optimize_lora_parameters(model)

    return model


def fuse_lora_layer(model):
    """Fuse every LoRA layer of ``model`` into its frozen weight.

    If ``model`` is a list, only its first element is processed.
    """
    if isinstance(model, list):
        model = model[0]
    for _, module in model.named_modules():
        if isinstance(module, ALL_LORA_LAYER):
            module.fuse_lora_weight()


def unfuse_lora_layer(model):
    """Undo ``fuse_lora_layer`` for every LoRA layer of ``model``.

    If ``model`` is a list, only its first element is processed.
    """
    if isinstance(model, list):
        model = model[0]
    for _, module in model.named_modules():
        if isinstance(module, ALL_LORA_LAYER):
            module.unfuse_lora_weight()


def only_optimize_lora_parameters(model, excluded_flags="bias", excluded_attrs="sequence_parallel", is_training=True):
    """Turn off the gradient of all the parameters except the LoRA parameters.

    Arguments:
        model: module whose parameters are updated in place.
        excluded_flags: comma separated substrings; a parameter whose name
            contains one of them is treated like a LoRA parameter.
        excluded_attrs: comma separated attribute names; a parameter having a
            truthy attribute of that name is treated like a LoRA parameter.
        is_training: ``requires_grad`` value given to the selected parameters;
            all other parameters get ``requires_grad = False``.

    Returns:
        ``model``.
    """
    excluded_flags = excluded_flags.split(",")
    excluded_attrs = excluded_attrs.split(",")
    for name, param in model.named_parameters():
        if "lora_right_weight" in name or "lora_left_weight" in name or \
                any(getattr(param, ele, False) for ele in excluded_attrs) or \
                any(ele in name for ele in excluded_flags):
            param.requires_grad = is_training
        else:
            param.requires_grad = False
    print_trainable_parameters(model)
    return model


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.

    The summary is printed on rank 0 only; when ``torch.distributed`` is not
    available or not initialized it is printed unconditionally.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    distributed_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
    if not distributed_ready or torch.distributed.get_rank() == 0:
        # Guard against models without parameters.
        trainable_percent = 100 * trainable_params / all_param if all_param else 0.0
        print(
            f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_percent}"
        )