def run()

in src/hyperpod_nemo_adapter/scripts/merge_peft_checkpoint.py [0:0]


def run(args):
    """Merge a PEFT adapter checkpoint into its base HF model and save the result.

    The base model is loaded from ``args.hf_model_name_or_path``, wrapped with
    the adapter found at ``args.peft_adapter_checkpoint_path``, merged via
    ``merge_and_unload``, and written to ``args.output_model_path``.

    Args:
        args: Namespace-like object exposing ``deepseek_v3``,
            ``hf_model_name_or_path``, ``hf_access_token``,
            ``peft_adapter_checkpoint_path`` and ``output_model_path``.

    Raises:
        FileExistsError: If ``args.output_model_path`` already exists and is
            not a directory.
    """
    print("Loading the HF model...")

    # Loader options shared by both branches, kept in one place so they
    # cannot drift apart.
    load_kwargs = {
        "torch_dtype": "auto",
        "device_map": "auto",
        "token": args.hf_access_token,
    }

    if args.deepseek_v3:
        model_config = DeepseekV3Config.from_pretrained(
            args.hf_model_name_or_path, token=args.hf_access_token, trust_remote_code=True
        )
        # Drop any quantization settings carried by the checkpoint config
        # before loading. NOTE(review): presumably so the weights load
        # unquantized for merging -- confirm.
        if hasattr(model_config, "quantization_config"):
            delattr(model_config, "quantization_config")
        model = DeepseekV3ForCausalLM.from_pretrained(
            args.hf_model_name_or_path,
            config=model_config,
            trust_remote_code=True,
            **load_kwargs,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.hf_model_name_or_path,
            **load_kwargs,
        )

    print("Loading the PEFT adapter checkpoint...")
    model = PeftModel.from_pretrained(model, args.peft_adapter_checkpoint_path)

    print("Merging the PEFT adapter with the base model...")
    model = model.merge_and_unload(progressbar=True)

    print(f"Saving the merged model to {args.output_model_path}...")
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test and tolerates an already-present directory.
    os.makedirs(args.output_model_path, exist_ok=True)
    model.save_pretrained(args.output_model_path)
    print("Model saved successfully.")