maga_transformer/models/gpt_util/prefix_encoder.py
import torch

from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters


class PrefixEncoder(torch.nn.Module):
"""
The torch.nn model to encode the prefix
Input shape: (batch-size, prefix-length)
Output shape: (batch-size, prefix-length, 2*layers*hidden)
"""
    def __init__(self, config: GptInitModelParameters):
        super().__init__()
        self.config = config
        self.prefix_projection = config.prefix_projection
        hidden_size = config.head_num * config.size_per_head
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix.
            # Note: the MLP output width is layer_num * head_num * size_per_head * 2,
            # which only matches the view in forward() when head_num == head_num_kv.
            self.embedding = torch.nn.Embedding(config.pre_seq_len, hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(hidden_size, hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(hidden_size, config.layer_num * hidden_size * 2)
            )
        else:
            # Directly embed each prefix position into the flattened per-layer
            # key/value vectors: layer_num * 2 * head_num_kv * size_per_head.
            self.embedding = torch.nn.Embedding(
                config.pre_seq_len,
                config.layer_num * config.size_per_head * config.head_num_kv * 2)
    # input shape:  [batch_size, pre_seq_len]
    # output shape: [batch_size, layer_num * 2, head_num_kv, pre_seq_len, size_per_head]
    def forward(self, prefix: torch.Tensor):
        batch_size = prefix.size(0)
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        past_key_values = past_key_values.view(batch_size, self.config.pre_seq_len, self.config.layer_num * 2,
                                               self.config.head_num_kv, self.config.size_per_head)
        past_key_values = past_key_values.permute(0, 2, 3, 1, 4).contiguous()
        return past_key_values
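

# -----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). The real
# GptInitModelParameters constructor is repo-specific, so this example stubs the
# handful of attributes PrefixEncoder actually reads with types.SimpleNamespace;
# every value below is an assumption chosen only to demonstrate the output shape.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import types

    _cfg = types.SimpleNamespace(
        prefix_projection=False,  # embedding-only variant (no MLP)
        head_num=8,
        head_num_kv=8,
        size_per_head=64,
        layer_num=12,
        pre_seq_len=16,
    )
    encoder = PrefixEncoder(_cfg)  # type: ignore[arg-type]
    prefix_ids = torch.arange(_cfg.pre_seq_len).unsqueeze(0)  # [1, pre_seq_len]
    past_key_values = encoder(prefix_ids)
    # Expected: torch.Size([1, 24, 8, 16, 64])
    #           == [batch, layer_num * 2, head_num_kv, pre_seq_len, size_per_head]
    print(past_key_values.shape)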