in src/peft/peft_model.py [0:0]
def _update_offload(self, offload_index: dict[str, dict[str, str]], adapters_weights: dict[str, torch.tensor]):
"""
Update the offload_index and safetensors files for loading and mergine PeftModels with disk-offloaded modules.
Args:
offload_index (Dict[str: str]):
Dictionary of disk-offloaded modules with their metadata and safetensors filenames
adapters_weights (Dict[str: torch.tensor]):
Dictionary of Peft adapter module names and weights
"""
if not offload_index:
return offload_index
prefix = "base_model.model."
# rename offload index weight and model names
adapter_names = list(self.peft_config.keys())
for adapter_name in adapter_names:
keys = list(offload_index.keys())
block_id = keys[0].split(".")[0] + "." # for writing safetensors key,
# replace original offload index keys with PeftModel keys
for key in keys:
suffix_pos = key.rfind(".")
extended_prefix = prefix + key[:suffix_pos]
module = dict(self.named_modules())[extended_prefix]
if isinstance(module, BaseTunerLayer):
new_key = prefix + key[:suffix_pos] + ".base_layer" + key[suffix_pos:]
else:
new_key = prefix + key
offload_index[key]["weight_name"] = new_key
offload_index[new_key] = offload_index[key]
del offload_index[key]
files_seen = set()
# rename safetensors for dispatch
for new_key in list(offload_index.keys()):
fname = offload_index[new_key]["safetensors_file"]
# make a new file name
new_fname_list = list(fname.split(os.sep))
for i, name in enumerate(new_fname_list):
if "--" in name:
new_fname_list[i] += "-peft"
break
new_fname = os.path.join(*new_fname_list)
if fname in files_seen:
continue
safe_dict = {}
with safe_open(fname, framework="pt") as f:
for safe_key in f.keys():
safe_tensor = f.get_tensor(safe_key)
metadata = f.metadata()
suffix_pos = safe_key.rfind(".")
extended_prefix = prefix + block_id + safe_key[:suffix_pos]
safe_module = dict(self.named_modules())[extended_prefix]
if isinstance(safe_module, BaseTunerLayer):
final_key = extended_prefix + ".base_layer" + safe_key[suffix_pos:]
lora_dict = {key: val for key, val in adapters_weights.items() if extended_prefix in key}
# add LoRA keys and values to disk offload
for lora_key, lora_val in lora_dict.items():
divide = lora_key.rfind(".")
new_key = lora_key[:divide] + f".{adapter_name}" + lora_key[divide:]
safe_dict[new_key] = lora_val
else:
final_key = prefix + block_id + safe_key
safe_dict[final_key] = safe_tensor
files_seen.add(new_fname)
# avoid overwriting original safetensors
for key in safe_dict.keys():
offload_index[key] = {"safetensors_file": new_fname, "weight_name": key}
base_name = os.path.dirname(new_fname)
if not os.path.exists(base_name):
os.makedirs(base_name)
safe_save_file(safe_dict, new_fname, metadata=metadata)