review_sentiment_model.py (18 lines of code) (raw):
import torch
import torch.nn as nn
class ReviewSentimentModel(nn.Module):
    """Sentiment classifier: a transformer encoder followed by a linear head.

    The head reads the encoder's hidden state at sequence position 0
    (presumably the ``[CLS]`` token for BERT-style tokenization -- depends on
    how ``ids`` were produced), applies ``tanh`` and projects it to
    ``output_dim`` outputs.

    Args:
        transformer: Encoder module. It must expose ``config.hidden_size``
            and, when called with token ids, return an object with a
            ``last_hidden_state`` attribute (e.g. a Hugging Face model --
            assumption, confirm against callers).
        output_dim: Number of outputs of the linear head.
        freeze: If true, every encoder parameter gets ``requires_grad=False``
            so gradients flow only into the linear head.
    """

    def __init__(self, transformer: nn.Module, output_dim: int, freeze: bool):
        super().__init__()
        self.transformer = transformer
        hidden_dim = transformer.config.hidden_size
        self.fc = nn.Linear(hidden_dim, output_dim)
        if freeze:
            # Equivalent to setting requires_grad=False on each parameter of
            # the encoder; the head (self.fc) stays trainable.
            self.transformer.requires_grad_(False)

    def forward(self, ids: torch.Tensor, attention_mask=None) -> torch.Tensor:
        """Return head outputs for a batch of token ids.

        Args:
            ids: Token ids, [batch size, seq len].
            attention_mask: Optional mask forwarded to the encoder (typically
                1 for real tokens, 0 for padding -- encoder-defined). When
                omitted the encoder is called with ``ids`` only, as before.

        Returns:
            Tensor of shape [batch size, output dim].
        """
        # Attention weights are deliberately NOT requested: they were never
        # used, and asking for them costs memory/time (and, for Hugging Face
        # models, can force the eager attention implementation).
        if attention_mask is None:
            output = self.transformer(ids)
        else:
            output = self.transformer(ids, attention_mask=attention_mask)
        hidden = output.last_hidden_state
        # hidden = [batch size, seq len, hidden dim]
        cls_hidden = hidden[:, 0, :]
        # cls_hidden = [batch size, hidden dim]
        prediction = self.fc(torch.tanh(cls_hidden))
        # prediction = [batch size, output dim]
        return prediction