in trl/models/modeling_value_head.py [0:0]
def post_init(self, state_dict):
    r"""
    Load the value-head weights from `state_dict` into `self.v_head` and, when the wrapped model has been
    dispatched across devices (it carries an `hf_device_map`), move the value head onto the first mapped
    device and register a forward hook that moves tensor outputs onto that same device.

    The value-head weights live in `state_dict` under keys prefixed with `v_head.`; the prefix is stripped
    before loading. The caller's `state_dict` is NOT modified (the previous implementation renamed its keys
    in place as a side effect, and its trailing `del state_dict` only dropped the local reference).

    Args:
        state_dict (`dict`): full state dictionary of the wrapped model plus the `v_head.`-prefixed
            value-head entries.

    Raises:
        ValueError: if any module of the wrapped model is offloaded to CPU or disk — CPU/disk offloading
            is not supported for ValueHead models.
    """
    # Build a filtered copy with the "v_head." prefix stripped instead of mutating the caller's dict.
    # `strict=False` keeps loading tolerant of missing keys, exactly as before (extra keys were previously
    # passed through and ignored by `strict=False` anyway).
    v_head_state_dict = {
        k.replace("v_head.", ""): v for k, v in state_dict.items() if "v_head." in k
    }
    self.v_head.load_state_dict(v_head_state_dict, strict=False)

    if hasattr(self.pretrained_model, "hf_device_map"):
        device_map_values = self.pretrained_model.hf_device_map.values()
        if "cpu" in device_map_values or "disk" in device_map_values:
            raise ValueError(
                "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models."
            )
        # Use the device of the first dispatched module (dict views iterate in insertion order, so this is
        # deterministic — unlike the previous `list(set(...))[0]`, which depended on arbitrary set order).
        first_device = next(iter(device_map_values))
        if isinstance(first_device, int):
            # Integer entries are bare device indices; qualify them for the available accelerator backend.
            if is_torch_npu_available():
                first_device = f"npu:{first_device}"
            elif is_torch_xpu_available():
                first_device = f"xpu:{first_device}"
            else:
                first_device = f"cuda:{first_device}"
        self.v_head = self.v_head.to(first_device)

        def set_device_hook(module, input, outputs):
            # Move every tensor in the model's output onto the value head's device so the head can consume
            # them directly. NOTE(review): assumes `outputs` is an iterable (tuple / ModelOutput) — confirm
            # against the wrapped model's forward return type.
            new_output = ()
            for output in outputs:
                if isinstance(output, torch.Tensor):
                    new_output += (output.to(first_device),)
                else:
                    new_output += (output,)
            return new_output

        self.register_forward_hook(set_device_hook)
        # Mark that outputs/head live on an explicit device under sequential model parallelism.
        self.is_sequential_parallel = True