in trl/models/modeling_value_head.py [0:0]
def post_init(self, state_dict):
    r"""
    Load the value head weights from `state_dict` and, when the wrapped model carries an `hf_device_map`, place
    the value head and the forward outputs on the device of the language-model head.

    The value head weights are stored in the state dictionary of the wrapped model with their keys prepended
    by `v_head.`. This function removes `v_head.` from those keys before loading them into `self.v_head`.

    Args:
        state_dict (`dict`):
            State dictionary that may contain `v_head.`-prefixed keys. It is modified in place: every key
            containing `v_head.` is replaced by the same key with `v_head.` removed.

    Raises:
        `ValueError`:
            If the `hf_device_map` of the wrapped model contains `"cpu"` or `"disk"`, or if no module whose
            name contains one of `self.lm_head_namings` can be found in the wrapped model.
    """
    # Snapshot the keys first: the dictionary is mutated while we walk over it.
    for k in list(state_dict.keys()):
        if "v_head." in k:
            state_dict[k.replace("v_head.", "")] = state_dict.pop(k)

    # `strict=False` lets keys that do not belong to the value head remain in `state_dict` without raising.
    self.v_head.load_state_dict(state_dict, strict=False)
    del state_dict

    # Without a device map there is no device placement to do.
    if not hasattr(self.pretrained_model, "hf_device_map"):
        return

    device_targets = self.pretrained_model.hf_device_map.values()
    if "cpu" in device_targets or "disk" in device_targets:
        raise ValueError(
            "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models."
        )

    # Get the lm_head device: the first module whose name contains one of the known lm_head namings wins.
    lm_head_device = next(
        (
            module.weight.device
            for name, module in self.pretrained_model.named_modules()
            if any(attribute in name for attribute in self.lm_head_namings)
        ),
        None,
    )
    if lm_head_device is None:
        # Previously this case surfaced as an opaque `UnboundLocalError` on the `.to(...)` call below.
        raise ValueError(
            "Could not find the lm_head of the pretrained model: no module name contains any of "
            f"{list(self.lm_head_namings)}."
        )

    # Put v_head on the same device as the lm_head to avoid issues.
    self.v_head = self.v_head.to(lm_head_device)

    def set_device_hook(module, inputs, outputs):
        r"""
        A forward hook that moves every tensor in the outputs of the model to the device of the lm_head.
        Non-tensor entries are passed through unchanged.

        Args:
            module (`nn.Module`):
                The module to which the hook is attached.
            inputs (`tuple`):
                The input to the module.
            outputs (`tuple`):
                The output of the module.

        Returns:
            `tuple`: The entries of `outputs`, in order, with tensors moved to the lm_head device.
        """
        return tuple(
            output.to(lm_head_device) if isinstance(output, torch.Tensor) else output for output in outputs
        )

    self.register_forward_hook(set_device_hook)
    # NOTE(review): presumably read elsewhere to detect a model sharded across devices - confirm against callers.
    self.is_sequential_parallel = True