modules/SwissArmyTransformer/sat/model/official/gpt2_model.py
import torch.nn as nn
from sat.model.base_model import BaseMixin, BaseModel
from transformers.activations import ACT2FN

# GPT-2 uses the "gelu_new" (tanh-approximated GELU) activation from HF Transformers.
gelu = ACT2FN["gelu_new"]


class GPT2FinalMixin(BaseMixin):
    """Output head: projects final hidden states to vocabulary logits."""

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        # Untied LM head: a separate linear projection, not shared with the
        # input word embeddings.
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def final_forward(self, logits, **kwargs):
        # Despite the parameter name, `logits` receives the transformer's
        # final hidden states; this hook turns them into vocabulary logits.
        return self.lm_head(logits)
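
# Note (an assumption about sat internals, stated for orientation): BaseModel
# dispatches to mixin hooks by name, so registering this mixin below makes
# `final_forward` override the default output projection with the untied
# `lm_head` defined above.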


class GPT2Model(BaseModel):
    def __init__(self, args, transformer=None, **kwargs):
        super().__init__(args, transformer=transformer, activation_func=gelu, **kwargs)
        self.add_mixin("gpt2-final", GPT2FinalMixin(args.vocab_size, args.hidden_size))

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('GPT2', 'GPT2 Configurations')
        # group.add_argument('--num-types', type=int)
        return parser
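

if __name__ == "__main__":
    # Minimal smoke test, illustrative only: the sizes below are assumptions
    # (GPT-2 small), not read from a sat config, and this does not build the
    # full transformer.
    import argparse
    import torch

    # The argument-group hook is how sat training scripts attach
    # model-specific flags before parsing; GPT2 currently adds none.
    parser = argparse.ArgumentParser()
    parser = GPT2Model.add_model_specific_args(parser)
    args = parser.parse_args([])

    # Exercise the mixin's head in isolation: project final hidden states
    # of shape (batch, seq, hidden) to vocabulary logits.
    mixin = GPT2FinalMixin(vocab_size=50257, hidden_size=768)
    hidden_states = torch.randn(2, 16, 768)
    lm_logits = mixin.final_forward(hidden_states)
    print(lm_logits.shape)  # torch.Size([2, 16, 50257])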