src/hyperpod_nemo_adapter/scripts/merge_peft_checkpoint.py (66 lines of code) (raw):
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import argparse
import os
from peft import PeftModel
from transformers import AutoModelForCausalLM
from hyperpod_nemo_adapter.collections.model.nlp.custom_models.configuration_deepseek import (
DeepseekV3Config,
)
from hyperpod_nemo_adapter.collections.model.nlp.custom_models.modeling_deepseek import (
DeepseekV3ForCausalLM,
)
def run(args):
    """Merge a PEFT adapter into its base Hugging Face model and save the result.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``hf_model_name_or_path``, ``hf_access_token``, ``deepseek_v3``,
            ``peft_adapter_checkpoint_path`` and ``output_model_path``.
    """
    print("Loading the HF model...")
    # Keyword arguments shared by both loading paths.
    load_kwargs = {
        "torch_dtype": "auto",
        "device_map": "auto",
        "token": args.hf_access_token,
    }
    if args.deepseek_v3:
        model_config = DeepseekV3Config.from_pretrained(
            args.hf_model_name_or_path, token=args.hf_access_token, trust_remote_code=True
        )
        # Drop any quantization settings from the config before loading.
        # NOTE(review): presumably so the weights are loaded unquantized for
        # the merge -- confirm against the DeepSeek V3 checkpoint format.
        if hasattr(model_config, "quantization_config"):
            delattr(model_config, "quantization_config")
        model = DeepseekV3ForCausalLM.from_pretrained(
            args.hf_model_name_or_path,
            config=model_config,
            trust_remote_code=True,
            **load_kwargs,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(args.hf_model_name_or_path, **load_kwargs)

    print("Loading the PEFT adapter checkpoint...")
    model = PeftModel.from_pretrained(model, args.peft_adapter_checkpoint_path)

    print("Merging the PEFT adapter with the base model...")
    model = model.merge_and_unload(progressbar=True)

    print(f"Saving the merged model to {args.output_model_path}...")
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test when the directory appears concurrently.
    os.makedirs(args.output_model_path, exist_ok=True)
    model.save_pretrained(args.output_model_path)
    print("Model saved successfully.")
def _parse_bool(value):
    """Convert a command-line string into a bool for ``argparse``.

    ``type=bool`` is not usable for flags because ``bool("False")`` is ``True``
    (any non-empty string is truthy). This converter maps the usual spellings
    explicitly and rejects everything else.

    Args:
        value: The raw option value (a string), or an actual bool.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, matching what bool("") returned before.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def main(argv=None):
    """Parse command-line arguments and run the merge.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads the arguments from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="Script for merging a Hugging Face model with a PEFT adapter checkpoint"
    )
    parser.add_argument(
        "--hf_model_name_or_path",
        type=str,
        required=True,
        help="The Hugging Face model name or path to load the model from.",
    )
    parser.add_argument(
        "--peft_adapter_checkpoint_path", type=str, required=True, help="Path to the PEFT adapter checkpoint."
    )
    parser.add_argument(
        "--output_model_path", type=str, required=True, help="Path where the merged model will be saved."
    )
    parser.add_argument(
        "--hf_access_token", type=str, default=None, help="Optional Hugging Face access token for authentication."
    )
    # nargs="?" with const=True lets a bare "--deepseek_v3" mean True, while
    # "--deepseek_v3 True" / "--deepseek_v3 False" are parsed explicitly.
    parser.add_argument(
        "--deepseek_v3",
        type=_parse_bool,
        nargs="?",
        const=True,
        default=False,
        help="Whether the model is DeepSeek V3 model.",
    )
    args = parser.parse_args(argv)
    run(args)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()