def convert_layer_to_lora()

in chatlearn/models/megatron/lora/layers.py


def convert_layer_to_lora(model,
                          part_module_name=None,
                          lora_dim=None,
                          lora_scaling=None,
                          lora_dropout=None,
                          lora_layer=None,
                          column_only_qkv=None):
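    # Arguments left as None fall back to the active module's LoRA config when the
    # ChatLearn runtime args are initialized, otherwise to the LoraConfig defaults.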
    if is_initialized():
        default_args = get_runtime_args().active_module_args.lora
    else:
        default_args = LoraConfig

    part_module_name = part_module_name if part_module_name is not None else default_args.part_module_name
    lora_dim = lora_dim if lora_dim is not None else default_args.lora_dim
    lora_scaling = lora_scaling if lora_scaling is not None else default_args.lora_scaling
    lora_dropout = lora_dropout if lora_dropout is not None else default_args.lora_dropout
    layers_to_convert = lora_layer if lora_layer is not None else default_args.lora_layer
    column_only_qkv = column_only_qkv if column_only_qkv is not None else default_args.column_only_qkv

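    # A non-positive lora_dim disables LoRA: return the model unchanged.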
    if lora_dim <= 0:
        return model

    layers_to_convert = layers_to_convert.split(",")
    assert all(layer in LORA_LAYER_MAP for layer in layers_to_convert), \
        "Unsupport layer to enable lora, {}. Only support {} for now.".format(layers_to_convert, ALL_LORA_LAYER)

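    # Patch MegatronOptimizer so the word-embedding gradient all-reduce uses the
    # LoRA-aware implementation from MegatronOptimizer_LoRA.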
    MegatronOptimizer.allreduce_word_embedding_grads = MegatronOptimizer_LoRA.allreduce_word_embedding_grads

    # Map each module name to be converted to its LoRA wrapper class.
    replace_name = {}
    for name, module in model.named_modules():
        if part_module_name is not None and part_module_name not in name:
            continue
        if isinstance(module, nn.Linear) and "LinearLayer" in layers_to_convert:
            replace_name[name] = LinearLayer_LoRA
        elif isinstance(module, RowParallelLinear) and "RowParallelLinear" in layers_to_convert:
            replace_name[name] = RowParallelLinear_LoRA
        elif isinstance(module, ColumnParallelLinear) and "ColumnParallelLinear" in layers_to_convert:
            # When column_only_qkv is set, skip ColumnParallelLinear modules that are not QKV projections.
            if column_only_qkv and all(ele not in name for ele in QKV_LAYER_NAME):
                continue
            replace_name[name] = ColumnParallelLinear_LoRA
        elif isinstance(module, VocabParallelEmbedding) and "VocabParallelEmbedding" in layers_to_convert:
            replace_name[name] = VocabParallelEmbedding_LoRA
        elif isinstance(module, Embedding) and "Embedding" in layers_to_convert:
            replace_name[name] = Embedding_LoRA

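    # Rebuild every selected module as its LoRA wrapper, copying the constructor
    # options the original module exposes so the replacement matches its configuration.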
    for name, func in replace_name.items():
        module = recursive_getattr(model, name)
        kwargs = {}
        if hasattr(module, "input_is_parallel"):
            kwargs["input_is_parallel"] = module.input_is_parallel
        if hasattr(module, "skip_bias_add"):
            kwargs["skip_bias_add"] = module.skip_bias_add
        if hasattr(module, "gather_output"):
            kwargs["gather_output"] = module.gather_output
        if hasattr(module, "input_size"):
            kwargs["input_size"] = module.input_size
        if hasattr(module, "output_size"):
            kwargs["output_size"] = module.output_size
        if hasattr(module, "padding_idx"):
            kwargs["padding_idx"] = module.padding_idx
        if hasattr(module, "max_norm"):
            kwargs["max_norm"] = module.max_norm
        if hasattr(module, "norm_type"):
            kwargs["norm_type"] = module.norm_type
        if hasattr(module, "scale_grad_by_freq"):
            kwargs["scale_grad_by_freq"] = module.scale_grad_by_freq
        if hasattr(module, "sparse"):
            kwargs["sparse"] = module.sparse
        if hasattr(module, "num_embeddings"):
            kwargs["num_embeddings"] = module.num_embeddings
        tmp = func(
            module.weight, lora_dim, lora_scaling, lora_dropout,
            module.bias if hasattr(module, "bias") else None, **kwargs).to(module.weight.device).to(module.weight.dtype)
        recursive_setattr(model, name, tmp)

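    # Freeze all non-LoRA parameters so that only the LoRA weights remain trainable.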
    only_optimize_lora_parameters(model)

    return model
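
A minimal usage sketch (not from the source): assuming a Megatron model has already been built, convert_layer_to_lora can be applied to it and returns the same model object with LoRA wrappers installed and, per only_optimize_lora_parameters, only the LoRA parameters left trainable. build_megatron_model and all argument values below are illustrative placeholders, not names from this repository.

# Hypothetical example: enable rank-16 LoRA on the parallel linear layers of a Megatron model.
model = build_megatron_model()  # placeholder for however the model is constructed
model = convert_layer_to_lora(
    model,
    lora_dim=16,                 # rank of the LoRA update matrices
    lora_scaling=1.0,            # scaling applied to the LoRA delta
    lora_dropout=0.0,            # dropout on the LoRA branch
    lora_layer="ColumnParallelLinear,RowParallelLinear",  # comma-separated layer types to convert
    column_only_qkv=True,        # restrict ColumnParallelLinear conversion to QKV projections
)
# Only LoRA parameters now require gradients; pass model.parameters() to the optimizer as usual.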