in weak_to_strong/model.py [0:0]
def __init__(self, name, linear_probe=False, **kwargs):
    """Load a causal LM backbone and attach a zero-initialised linear head.

    Args:
        name: Model identifier or path, passed to both
            ``AutoConfig.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        linear_probe: Stored on ``self.linear_probe``; how it is consumed is
            decided by other methods of this class.
        **kwargs: Forwarded unchanged to both ``from_pretrained`` calls.

    Raises:
        ValueError: If the loaded config exposes neither ``n_embd`` nor
            ``hidden_size``, so the head's input width cannot be determined.
    """
    config = AutoConfig.from_pretrained(name, **kwargs)
    super().__init__(config)
    self.num_labels = config.num_labels
    lm = AutoModelForCausalLM.from_pretrained(name, **kwargs)
    self.lm = lm
    # GPT-2-style models expose their backbone as ``.transformer``; other
    # architectures (e.g. Llama-family models use ``.model``) are reachable
    # through the generic ``base_model`` property. Prefer ``.transformer``
    # when present so existing behaviour is unchanged.
    transformer = getattr(lm, "transformer", None)
    self.transformer = transformer if transformer is not None else lm.base_model
    # GPT-2 configs name the width ``n_embd``; most others use ``hidden_size``.
    hidden_size = getattr(config, "n_embd", getattr(config, "hidden_size", None))
    if hidden_size is None:
        raise ValueError(
            f"Cannot determine hidden size for {name!r}: config has neither "
            "'n_embd' nor 'hidden_size'."
        )
    # Cast the new head to the same dtype as the LM's output head (e.g. when
    # the model was loaded in reduced precision).
    self.score = torch.nn.Linear(hidden_size, self.num_labels, bias=False).to(
        lm.lm_head.weight.dtype
    )
    # A normal distribution with std=0.0 puts every weight exactly at the
    # mean (0), so the bias-free head starts out emitting all-zero logits.
    # Left as normal_ rather than zeros_ in case normal_ advances the global
    # RNG, which would otherwise shift randomness in downstream code.
    torch.nn.init.normal_(self.score.weight, std=0.0)
    self.linear_probe = linear_probe