in chatlearn/models/megatron/lora/layers.py [0:0]
def convert_layer_to_lora(model,
                          part_module_name=None,
                          lora_dim=None,
                          lora_scaling=None,
                          lora_dropout=None,
                          lora_layer=None,
                          column_only_qkv=None):
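    """Replace supported sublayers of `model` with their LoRA counterparts.

    Arguments left as None fall back to the active LoRA config
    (get_runtime_args().active_module_args.lora when the runtime is
    initialized, otherwise LoraConfig). `lora_layer` is a comma-separated
    list of layer-class names to convert; a non-positive `lora_dim`
    disables the conversion and returns the model unchanged.
    """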
    if is_initialized():
        default_args = get_runtime_args().active_module_args.lora
    else:
        default_args = LoraConfig
    part_module_name = part_module_name if part_module_name is not None else default_args.part_module_name
    lora_dim = lora_dim if lora_dim is not None else default_args.lora_dim
    lora_scaling = lora_scaling if lora_scaling is not None else default_args.lora_scaling
    lora_dropout = lora_dropout if lora_dropout is not None else default_args.lora_dropout
    layers_to_convert = lora_layer if lora_layer is not None else default_args.lora_layer
    column_only_qkv = column_only_qkv if column_only_qkv is not None else default_args.column_only_qkv
    if lora_dim <= 0:
        return model
    layers_to_convert = layers_to_convert.split(",")
    assert all(layer in LORA_LAYER_MAP for layer in layers_to_convert), \
        "Unsupported layer to enable LoRA: {}. Only {} are supported for now.".format(layers_to_convert, ALL_LORA_LAYER)
    MegatronOptimizer.allreduce_word_embedding_grads = MegatronOptimizer_LoRA.allreduce_word_embedding_grads
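    # First pass: record, per module name, which LoRA class should replace it.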
    replace_name = {}
    for name, module in model.named_modules():
        if part_module_name is not None and part_module_name not in name:
            continue
        if isinstance(module, nn.Linear) and "LinearLayer" in layers_to_convert:
            replace_name[name] = LinearLayer_LoRA
        elif isinstance(module, RowParallelLinear) and "RowParallelLinear" in layers_to_convert:
            replace_name[name] = RowParallelLinear_LoRA
        elif isinstance(module, ColumnParallelLinear) and "ColumnParallelLinear" in layers_to_convert:
            if column_only_qkv and any(ele not in name for ele in QKV_LAYER_NAME):
                continue
            replace_name[name] = ColumnParallelLinear_LoRA
        elif isinstance(module, VocabParallelEmbedding) and "VocabParallelEmbedding" in layers_to_convert:
            replace_name[name] = VocabParallelEmbedding_LoRA
        elif isinstance(module, Embedding) and "Embedding" in layers_to_convert:
            replace_name[name] = Embedding_LoRA
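    # Second pass: rebuild each recorded module as its LoRA variant.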
    for name, func in replace_name.items():
        module = recursive_getattr(model, name)
        kwargs = {}
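        # Carry over whichever constructor arguments the original module exposes.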
if hasattr(module, "input_is_parallel"):
kwargs["input_is_parallel"] = module.input_is_parallel
if hasattr(module, "skip_bias_add"):
kwargs["skip_bias_add"] = module.skip_bias_add
if hasattr(module, "gather_output"):
kwargs["gather_output"] = module.gather_output
if hasattr(module, "input_size"):
kwargs["input_size"] = module.input_size
if hasattr(module, "output_size"):
kwargs["output_size"] = module.output_size
if hasattr(module, "padding_idx"):
kwargs["padding_idx"] = module.padding_idx
if hasattr(module, "max_norm"):
kwargs["max_norm"] = module.max_norm
if hasattr(module, "norm_type"):
kwargs["norm_type"] = module.norm_type
if hasattr(module, "scale_grad_by_freq"):
kwargs["scale_grad_by_freq"] = module.scale_grad_by_freq
if hasattr(module, "sparse"):
kwargs["sparse"] = module.sparse
if hasattr(module, "num_embeddings"):
kwargs["num_embeddings"] = module.num_embeddings
        tmp = func(
            module.weight, lora_dim, lora_scaling, lora_dropout,
            module.bias if hasattr(module, "bias") else None, **kwargs).to(module.weight.device).to(module.weight.dtype)
        recursive_setattr(model, name, tmp)
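    # Freeze everything except the newly added LoRA parameters.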
    only_optimize_lora_parameters(model)
    return model
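A minimal usage sketch, not part of the file above: it assumes `model` is an already-constructed Megatron module (`build_megatron_model()` is a hypothetical placeholder for the caller's own model construction) and the keyword values are illustrative; any argument left out falls back to the active LoRA config.

from chatlearn.models.megatron.lora.layers import convert_layer_to_lora

model = build_megatron_model()  # hypothetical helper; build the model as usual

# Convert the parallel linear layers to their LoRA counterparts; unspecified
# arguments fall back to the active LoraConfig at runtime.
model = convert_layer_to_lora(
    model,
    lora_dim=16,        # rank of the low-rank update
    lora_scaling=1.0,   # scaling applied to the LoRA branch
    lora_dropout=0.05,  # dropout on the LoRA branch
    lora_layer="ColumnParallelLinear,RowParallelLinear",  # comma-separated layer-class names
    column_only_qkv=False,
)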