in src/sal/models/skywork_o1_prm/prm_model.py [0:0]
def state_dict(self, *args, **kwargs):
    r"""
    Returns the state dictionary of the model. We add the state dictionary of the value head
    to the state dictionary of the wrapped model by prepending the key with `v_head.`.

    Arguments follow the ``torch.nn.Module.state_dict`` convention (``destination``,
    ``prefix``, ``keep_vars``, by keyword or in the deprecated positional form). They are
    forwarded unchanged to the wrapped model's ``state_dict`` when this is not a PEFT model.
    The value-head entries are stored under ``f"{prefix}v_head.{name}"``. When a
    ``destination`` mapping is supplied (as ``torch.nn.Module.state_dict`` does when it
    recurses into submodules), the entries are written into that mapping.

    Returns:
        The mapping that holds the wrapped model's entries, followed by the prefixed
        value-head entries. For a PEFT model only the value head is saved, so the mapping is
        ``destination`` if one was given, otherwise a new empty dict.
    """
    # Resolve torch's (destination, prefix, keep_vars) arguments, including the deprecated
    # positional form, so the value head can be serialized without receiving the caller's
    # destination dict.
    destination = kwargs.get("destination", args[0] if len(args) > 0 else None)
    prefix = kwargs.get("prefix", args[1] if len(args) > 1 else "")
    keep_vars = kwargs.get("keep_vars", args[2] if len(args) > 2 else False)

    if not self.is_peft_model:
        pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
    elif destination is not None:
        # if it is a peft model, only save the v_head (into the caller's destination)
        pretrained_model_state_dict = destination
    else:
        # if it is a peft model, only save the v_head
        pretrained_model_state_dict = {}

    # Serialize the value head into a fresh dict (no destination/prefix forwarded). Otherwise
    # its un-namespaced keys would land in the caller's destination, clobbering the wrapped
    # model's entries, and the loop below would insert into the dict it is iterating.
    v_head_state_dict = self.v_head.state_dict(keep_vars=keep_vars)
    for k, v in v_head_state_dict.items():
        pretrained_model_state_dict[f"{prefix}v_head.{k}"] = v
    return pretrained_model_state_dict