in src/hyperpod_nemo_adapter/collections/parts/fsdp_strategy.py
# Imports this method relies on. `logging` may be the project's own logger in
# the real module; the torch.sagemaker path below comes from the SageMaker
# model parallelism (SMP) v2 library.
from contextlib import ExitStack
import logging

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    SMStateDictType,
    sm_state_dict_type,
)


def save_peft_model(self, checkpoint_dir):
    """
    Save the FSDP-wrapped PEFT model to the specified directory. Note that this
    method saves only the adapter weights.
    """
    logging.info(f"Saving PEFT checkpoint to {checkpoint_dir}")
    logging.debug(f"Model to save: {self.pytorch_model}")

    def is_peft_adapter(module):
        # Trainable leaf modules (no children, and a weight with
        # requires_grad=True) are the PEFT adapter layers; frozen
        # base-model modules fail the requires_grad check.
        return (
            not list(module.named_children())
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        )

    def is_peft_fsdp_wrapper(module):
        # FSDP wraps each sharded module; look through the wrapper for adapters.
        return hasattr(module, "_fsdp_wrapped_module") and is_peft_adapter(module._fsdp_wrapped_module)
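
    # For example, with LoRA (an illustrative assumption; any PEFT method with
    # trainable leaf weights matches) the trainable leaves are the lora_A and
    # lora_B Linear modules: each has no children, exposes a .weight, and has
    # requires_grad=True, so is_peft_adapter() selects them.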
    adapter_modules = list(filter(is_peft_fsdp_wrapper, self.pytorch_model.modules()))
    context_managers = [
        FSDP.summon_full_params(
            module,
            writeback=False,
            rank0_only=True,
            offload_to_cpu=True,
        )
        for module in adapter_modules
    ]
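    # Each context manager gathers only that adapter's shards, materializing the
    # full parameters on rank 0 and offloading them to CPU; writeback=False since
    # the gathered copies are only read for saving.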
"""
we don't want to use FSDP FULL state dict because gathering of frozen params
is needlessly expensive and causes OOM issues. we also need to avoid the FSDP
state_dict hooks as they won't return full tensors even with summon_full_params.
so we use SM_LOCAL_STATE_DICT to disable FSDP state_dict hooks.
"""
    with ExitStack() as stack, sm_state_dict_type(self.pytorch_model, SMStateDictType.SM_LOCAL_STATE_DICT):
        for cm in context_managers:
            stack.enter_context(cm)
        if dist.get_rank() == 0:
"""
Need to extract the PEFT model from the FSDP wrapper to call save_pretrained()
Example of what the _fsdp_wrapped_module looks like:
FullyShardedDataParallel(
(_fsdp_wrapped_module): PeftModelForCausalLM(
The model needs to be unwrapped in order to extract the PeftModelForCausalLM
"""
self.pytorch_model.module.save_pretrained(checkpoint_dir)

    dist.barrier()
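
For context, here is a minimal, hypothetical sketch of the save/load round trip this method enables. The `strategy` handle, checkpoint path, and base model ID are illustrative placeholders, and loading uses the Hugging Face PEFT API rather than anything from this module:

# Hypothetical usage sketch; `strategy`, paths, and model names are
# placeholders, not part of this module.
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Saving is collective: every rank must call save_peft_model() (the
# summon_full_params contexts and the barrier are collectives), but only
# rank 0 writes the adapter weights and adapter_config.json.
strategy.save_peft_model("/checkpoints/peft/step_1000")

# Later, outside FSDP: reattach the saved adapter to an unsharded base model.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = PeftModel.from_pretrained(base, "/checkpoints/peft/step_1000")
model = model.merge_and_unload()  # optional: fold the adapters into the base weights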