in src/sal/models/skywork_o1_prm/prm_model.py [0:0]
def __init__(self, pretrained_model, **kwargs):
    r"""
    Wrap a causal language model and attach a value head to it.

    Args:
        pretrained_model (`transformers.PreTrainedModel`):
            The model to wrap. It should be a causal language model such as
            GPT-2, or any model mapped inside the `AutoModelForCausalLM` class.
        kwargs (`dict`, `optional`):
            Additional keyword arguments. All of them are forwarded to the
            parent constructor; the first group returned by
            `self._split_kwargs(kwargs)` is additionally passed to the
            `ValueHead` constructor and to `self._init_weights`.

    Raises:
        ValueError: If the wrapped model has none of the attributes listed in
            `self.lm_head_namings`.
    """
    super().__init__(pretrained_model, **kwargs)

    # Only the first of the three groups produced by the split is used here;
    # the strict three-way unpack is kept deliberately.
    value_head_kwargs, _, _ = self._split_kwargs(kwargs)

    # The wrapper requires an LM head; accept any of the known attribute names.
    lm_head_found = any(
        hasattr(self.pretrained_model, head_name)
        for head_name in self.lm_head_namings
    )
    if not lm_head_found:
        raise ValueError(
            "The model does not have a language model head, please use a model that has one."
        )

    # Build the value head from the wrapped model's config, then call
    # `_init_weights` with the same kwargs (presumably initialises the value
    # head's parameters — confirm against the class definition).
    self.v_head = ValueHead(self.pretrained_model.config, **value_head_kwargs)
    self._init_weights(**value_head_kwargs)