in src/sal/models/skywork_o1_prm/prm_model.py [0:0]
def post_init(self, state_dict):
    r"""
    Load the value-head weights and, for multi-device models, pin outputs to one device.

    We add the state dictionary of the value head to the state dictionary of the wrapped model
    by prepending the key with `v_head.`. This function strips the `v_head.` prefix from those
    keys and loads the result into ``self.v_head`` with ``strict=False``, so keys that do not
    belong to the value head are ignored.

    If the wrapped model exposes a non-empty ``hf_device_map`` (i.e. it was dispatched across
    devices), the value head is moved to the device of the first entry in that map. A forward
    hook is then registered on this module that moves its tensor outputs to that same device,
    and ``self.is_sequential_parallel`` is set to ``True``.

    Args:
        state_dict: Mapping of parameter names to tensors. It is read but not modified.

    Raises:
        ValueError: If any entry of ``hf_device_map`` is ``"cpu"`` or ``"disk"`` (offloading
            is not supported for value-head models).
    """
    prefix = "v_head."
    # Build a renamed copy rather than popping keys out of the caller's dict.
    # Prefixed keys are applied last, so they still override an identically named
    # un-prefixed key, matching the previous in-place behaviour.
    renamed = {k: v for k, v in state_dict.items() if prefix not in k}
    renamed.update((k.replace(prefix, ""), v) for k, v in state_dict.items() if prefix in k)
    self.v_head.load_state_dict(renamed, strict=False)

    device_map = getattr(self.pretrained_model, "hf_device_map", None)
    if not device_map:
        return

    devices = list(device_map.values())
    if "cpu" in devices or "disk" in devices:
        raise ValueError(
            "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models."
        )

    # Dicts preserve insertion order, so the first entry is reproducible. Taking
    # element 0 of a set depended on hash order, which varies between processes
    # for string device names.
    first_device = devices[0]
    if isinstance(first_device, int):
        # Integer entries are treated as CUDA device indices.
        first_device = f"cuda:{first_device}"
    self.v_head = self.v_head.to(first_device)

    def set_device_hook(_module, _inputs, outputs):
        # Gather tensor outputs onto a single device.
        if isinstance(outputs, torch.Tensor):
            # A bare tensor must not be iterated (that would slice it along dim 0).
            return outputs.to(first_device)
        if isinstance(outputs, (tuple, list)):
            return tuple(
                out.to(first_device) if isinstance(out, torch.Tensor) else out
                for out in outputs
            )
        # Other output types (e.g. None or mappings) are passed through unchanged.
        return outputs

    self.register_forward_hook(set_device_hook)
    self.is_sequential_parallel = True