in weak_to_strong/model.py [0:0]
def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
    """
    Forward pass of the model with a linear head.

    Runs ``self.transformer`` on ``input_ids``, takes the hidden state at the
    last non-padding position of each row, and applies ``self.score`` to it.

    Parameters:
    input_ids (torch.LongTensor): Input tensor containing the token ids,
        indexed as (batch, seq). A row's length is the number of non-zero
        ids, so id 0 is treated as padding and rows are assumed to be
        right-padded. NOTE(review): confirm against the tokenizer's pad id
        and padding side. A genuine id-0 token inside a sequence would shift
        the selected position.
    Returns:
    torch.Tensor: The logits produced by ``self.score``, one row per input row.
    """
    input_lens = (input_ids != 0).sum(dim=-1)
    transformer_outputs = self.transformer(input_ids)
    # The first element of the transformer output holds the per-position
    # hidden states, indexed (batch, seq, hidden).
    all_hidden = transformer_outputs[0]
    # Gather each row's last non-pad position with one advanced-indexing op
    # instead of a Python loop + torch.stack. Indices are moved to the
    # hidden states' device in case it differs from input_ids' device.
    # A row with no non-zero ids yields index -1, which selects the final
    # position (the same result as per-row indexing with -1).
    last_pos = (input_lens - 1).to(all_hidden.device)
    batch_idx = torch.arange(last_pos.shape[0], device=all_hidden.device)
    hidden_states = all_hidden[batch_idx, last_pos]
    # Keep the head on the same device as its input.
    self.score.to(hidden_states.device)
    if self.linear_probe:
        # Cut the graph here so no gradient reaches the transformer via this path.
        hidden_states = hidden_states.detach()
    logits = self.score(hidden_states)
    return logits