def refactor_embeddings()

in opacus_lab/models/GPT2/refactor.py [0:0]


    def refactor_embeddings(GPT2, num_positions=None):
        """Build standalone positional/token embedding modules from a GPT-2 model.

        Constructs a ``PositionalEmbedding`` and a ``TokenEmbedding`` and replaces
        their ``emb.weight`` with the pretrained ``wpe`` / ``wte`` weights taken
        from ``GPT2.transformer``.

        Args:
            GPT2: Model exposing ``transformer.wte.weight`` and
                ``transformer.wpe.weight`` (presumably a Hugging Face GPT-2
                LM-head model -- TODO confirm against callers).
            num_positions: Number of positional embeddings to allocate. Defaults
                to the number of rows in the pretrained ``wpe`` weight, so the
                module's declared size always matches the weight loaded into it
                (this value was previously hard-coded to 1024).

        Returns:
            Tuple ``(wpe, wte)``: the positional and token embedding modules.

        Note:
            ``dim`` and ``vocab_size`` come from the enclosing scope.
            ``nn.Parameter(tensor)`` shares storage with ``tensor`` rather than
            copying it, so the returned modules alias the original GPT-2 weights.
        """
        if num_positions is None:
            # Size the positional table from the pretrained weight instead of a
            # hard-coded constant, so it cannot disagree with the weight below.
            num_positions = GPT2.transformer.wpe.weight.shape[0]

        # Keep the original construction order (wpe first, then wte) so any
        # random initialization consumes the RNG in the same sequence.
        wpe = PositionalEmbedding(num_positions, dim)
        wte = TokenEmbedding(vocab_size, dim)

        # Overwrite the freshly initialized weights with the pretrained ones.
        wte.emb.weight = nn.Parameter(GPT2.transformer.wte.weight)
        wpe.emb.weight = nn.Parameter(GPT2.transformer.wpe.weight)
        return wpe, wte