weak_to_strong/model.py
from dataclasses import dataclass
import torch
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
@dataclass
class HeadOutput:
logits: torch.FloatTensor
class TransformerWithHead(PreTrainedModel):
"""
    Causal language model with a linear classification head on top. The head
    weights are zero-initialized, so the logits are all zero before fine-tuning.
"""
def __init__(self, name, linear_probe=False, **kwargs):
config = AutoConfig.from_pretrained(name, **kwargs)
super().__init__(config)
self.num_labels = config.num_labels
lm = AutoModelForCausalLM.from_pretrained(name, **kwargs)
self.lm = lm
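        # Assumes a GPT-2-style causal LM that exposes its transformer body as .transformer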
self.transformer = lm.transformer
hidden_size = getattr(config, "n_embd", getattr(config, "hidden_size", None))
self.score = torch.nn.Linear(hidden_size, self.num_labels, bias=False).to(
lm.lm_head.weight.dtype
)
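        # std=0.0 makes this a zero initialization of the head weights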
torch.nn.init.normal_(self.score.weight, std=0.0)
self.linear_probe = linear_probe
@classmethod
def from_pretrained(cls, name, **kwargs):
return cls(name, **kwargs)
def gradient_checkpointing_enable(self):
model = self.transformer
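        # The transformer may be wrapped (e.g. by DataParallel); fall back to
        # .module when it does not expose save_pretrained directly.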
(
model if hasattr(model, "save_pretrained") else model.module
).gradient_checkpointing_enable()
def forward(self, input_ids: torch.LongTensor):
"""
Forward pass of the model with a linear head.
Parameters:
input_ids (torch.LongTensor): Input tensor containing the token ids.
        Returns:
            torch.FloatTensor: Classification logits of shape (batch_size, num_labels).
"""
input_lens = (input_ids != 0).sum(dim=-1)
transformer_outputs = self.transformer(input_ids)
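        # For each sequence, take the hidden state at the last non-padding position.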
hidden_states = torch.stack(
[transformer_outputs[0][i, input_lens[i] - 1, :] for i in range(len(input_lens))]
)
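        # nn.Module.to() moves parameters in place, keeping the head on the same
        # device as the activations.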
self.score.to(hidden_states.device)
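        # Linear-probe mode: detach so gradients do not flow into the transformer.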
if self.linear_probe:
hidden_states = hidden_states.detach()
logits = self.score(hidden_states)
return logits
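
A minimal usage sketch (not part of the module above): the checkpoint name "gpt2", num_labels=2, and the tokenizer call are illustrative assumptions, not fixed by the original code.

from transformers import AutoTokenizer

# "gpt2" and num_labels=2 are placeholders; num_labels is forwarded through
# **kwargs to AutoConfig / AutoModelForCausalLM in __init__.
model = TransformerWithHead.from_pretrained("gpt2", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# A single unpadded example; forward() treats token id 0 as padding when it
# computes per-sequence lengths.
batch = tokenizer(["an example input"], return_tensors="pt")
logits = model(batch["input_ids"])  # shape: (batch_size, num_labels)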