in src/sal/models/skywork_o1_prm/modeling_base.py [0:0]
def save_pretrained(self, *args, **kwargs):
    r"""
    Save the pretrained model to a directory. This method is a wrapper around
    `transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation
    of `transformers.PreTrainedModel.save_pretrained` for more information.

    If no `state_dict` keyword is supplied, `self.state_dict()` is used and
    injected into the forwarded keyword arguments.

    When `self.is_peft_model` is true, that state dict is additionally written
    to `<save_directory>/pytorch_model.bin` with `torch.save`. In that case
    `state_dict` is *not* forwarded to the wrapped model's `save_pretrained`.

    Args:
        *args (`list`, *optional*):
            Positional arguments passed along to the underlying model's
            `save_pretrained` method. The first one, if present, is taken as
            the target directory.
        **kwargs (`dict`, *optional*):
            Keyword arguments passed along to the underlying model's
            `save_pretrained` method. The target directory may be given here
            as `save_directory` instead of positionally.

    Returns:
        Whatever `self.pretrained_model.save_pretrained` returns.

    Raises:
        TypeError: If this is a PEFT model and no target directory was given,
            either positionally or as `save_directory=`.
    """
    state_dict = kwargs.get("state_dict")
    if state_dict is None:
        state_dict = self.state_dict()
        kwargs["state_dict"] = state_dict

    if self.is_peft_model:
        # The directory may arrive positionally or as the `save_directory`
        # keyword; indexing `args[0]` alone would raise IndexError for the
        # keyword form.
        save_directory = args[0] if args else kwargs.get("save_directory")
        if save_directory is None:
            raise TypeError(
                "save_pretrained() missing required argument: 'save_directory'"
            )
        # torch.save runs before the wrapped model's save_pretrained (which
        # would otherwise create the directory), so ensure it exists first.
        os.makedirs(save_directory, exist_ok=True)
        # NOTE(review): the intent is to save only the `v_head` weights here.
        # That presumably relies on `self.state_dict()` being overridden to
        # return just those weights -- confirm in the concrete subclass.
        torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
        # Do not forward `state_dict` to the wrapped (peft) model, to avoid
        # silent bugs with `peft`.
        kwargs.pop("state_dict", None)

    return self.pretrained_model.save_pretrained(*args, **kwargs)