review_sentiment_model.py (18 lines of code) (raw):

import torch
import torch.nn as nn


class ReviewSentimentModel(nn.Module):
    """Classifier made of a transformer backbone and one linear output layer.

    The backbone's hidden state at the first sequence position is passed
    through ``tanh`` and then through a linear layer that produces
    ``output_dim`` scores per example.
    """

    def __init__(self, transformer: nn.Module, output_dim: int, freeze: bool) -> None:
        """Build the model around an existing transformer.

        Args:
            transformer: Backbone module. It must expose
                ``config.hidden_size`` and, when called with a batch of
                token ids, return an object with a ``last_hidden_state``
                attribute.
            output_dim: Number of values predicted per example.
            freeze: When true, ``requires_grad`` is switched off on every
                backbone parameter so only the linear head is trained.
        """
        super().__init__()
        self.transformer = transformer
        hidden_dim = transformer.config.hidden_size
        self.fc = nn.Linear(hidden_dim, output_dim)
        if freeze:
            # Only the gradient flag is changed; the backbone's train/eval
            # mode is left exactly as the caller set it.
            for param in self.transformer.parameters():
                param.requires_grad = False

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Return the prediction scores for a batch of token ids.

        Args:
            ids: Token ids, shape ``[batch size, seq len]``.

        Returns:
            Predictions, shape ``[batch size, output dim]``.
        """
        # Attention maps are deliberately not requested: nothing in this
        # method consumes them, so asking the backbone for them would only
        # cost extra memory and compute.
        output = self.transformer(ids)
        hidden = output.last_hidden_state
        # hidden = [batch size, seq len, hidden dim]
        first_token_hidden = hidden[:, 0, :]
        # first_token_hidden = [batch size, hidden dim]
        prediction = self.fc(torch.tanh(first_token_hidden))
        # prediction = [batch size, output dim]
        return prediction