in opacus_lab/models/GPT2/refactor.py [0:0]
def refactor_embeddings(GPT2):
    """Rebuild GPT-2's positional and token embeddings as standalone modules.

    The sizes of the new modules are read from the pretrained embedding
    weights themselves instead of being hard-coded (previously 1024
    positions plus free ``dim`` / ``vocab_size`` names), so the declared
    module sizes always agree with the weights swapped into them.

    Args:
        GPT2: model exposing ``transformer.wpe`` and ``transformer.wte``
            embedding layers with a ``.weight`` attribute. Presumably a
            HuggingFace GPT-2 model whose weights are shaped
            ``(num_embeddings, embedding_dim)`` -- TODO confirm with callers.

    Returns:
        Tuple ``(wpe, wte)``: a ``PositionalEmbedding`` and a
        ``TokenEmbedding`` whose ``emb.weight`` wrap GPT2's existing
        positional and token embedding tensors.
    """
    pretrained_wpe = GPT2.transformer.wpe.weight
    pretrained_wte = GPT2.transformer.wte.weight

    # Embedding weights are (num_embeddings, embedding_dim); derive each
    # module's size from its own pretrained weight rather than hard-coding
    # 1024 positions.
    n_positions, pos_dim = pretrained_wpe.shape
    vocab_size, tok_dim = pretrained_wte.shape

    # NOTE(review): assumes both constructors take (num_embeddings, dim), as
    # the original call sites suggested -- confirm against their definitions.
    wpe = PositionalEmbedding(n_positions, pos_dim)
    wte = TokenEmbedding(vocab_size, tok_dim)

    # nn.Parameter wraps the existing tensor data without copying, so the new
    # modules alias GPT2's embedding storage rather than cloning it.
    wte.emb.weight = nn.Parameter(pretrained_wte)
    wpe.emb.weight = nn.Parameter(pretrained_wpe)
    return wpe, wte