def save_peft_model()

in src/hyperpod_nemo_adapter/collections/parts/fsdp_strategy.py

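For reference, the method below depends on the following imports. This is a minimal sketch: the file's import block is not shown in this excerpt, so the exact module path for the SageMaker state-dict helpers is an assumption.

    from contextlib import ExitStack

    import logging  # the adapter may use NeMo's logging wrapper instead

    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Assumption: sm_state_dict_type and SMStateDictType ship with the SageMaker
    # model-parallel (SMP) library; this import path is not shown in the excerpt.
    from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
        SMStateDictType,
        sm_state_dict_type,
    )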

    def save_peft_model(self, checkpoint_dir):
        """
        Save the FSDP wrapped PEFT model to the specified directory. Note that this method will
        only save the adapter weights.
        """
        logging.info(f"Saving PEFT checkpoint to {checkpoint_dir}")
        logging.debug(f"Model to save: {self.pytorch_model}")

        def is_peft_adapter(module):
            return (
                not list(module.named_children())
                and getattr(module, "weight", None) is not None
                and module.weight.requires_grad
            )

        def is_peft_fsdp_wrapper(module):
            return hasattr(module, "_fsdp_wrapped_module") and is_peft_adapter(module._fsdp_wrapped_module)

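        # Find every FSDP unit whose wrapped module is a trainable adapter leaf.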
        adapter_modules = list(filter(is_peft_fsdp_wrapper, self.pytorch_model.modules()))
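        # Summon the full (unsharded) params for each adapter module only, gathered
        # on rank 0 and offloaded to CPU, without writing anything back to the shards.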
        context_managers = [
            FSDP.summon_full_params(
                module,
                writeback=False,
                rank0_only=True,
                offload_to_cpu=True,
            )
            for module in adapter_modules
        ]

        """
        we don't want to use FSDP FULL state dict because gathering of frozen params
        is needlessly expensive and causes OOM issues. we also need to avoid the FSDP
        state_dict hooks as they won't return full tensors even with summon_full_params.
        so we use SM_LOCAL_STATE_DICT to disable FSDP state_dict hooks.
        """
        with ExitStack() as stack, sm_state_dict_type(self.pytorch_model, SMStateDictType.SM_LOCAL_STATE_DICT):
            for cm in context_managers:
                stack.enter_context(cm)
            if dist.get_rank() == 0:
                """
                Need to extract the PEFT model from the FSDP wrapper to call save_pretrained()

                Example of what the _fsdp_wrapped_module looks like:
                    FullyShardedDataParallel(
                        (_fsdp_wrapped_module): PeftModelForCausalLM(

                The model needs to be unwrapped in order to extract the PeftModelForCausalLM
                """
                self.pytorch_model.module.save_pretrained(checkpoint_dir)
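            # All ranks wait until rank 0 has finished writing the checkpoint.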
            dist.barrier()
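Since only the adapter weights are written, the checkpoint can be restored by attaching it to a freshly loaded base model with PEFT's standard from_pretrained API. A minimal sketch, assuming a Hugging Face causal-LM base model; "base-model-id" and the checkpoint path are hypothetical placeholders:

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    # Hypothetical identifiers: substitute the base model that was fine-tuned
    # and the checkpoint_dir that was passed to save_peft_model above.
    base = AutoModelForCausalLM.from_pretrained("base-model-id")
    model = PeftModel.from_pretrained(base, "/path/to/checkpoint_dir")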