def refactor_head()

in opacus_lab/models/GPT2/refactor.py [0:0]


    def refactor_head(GPT2):
        """Build the output projection and final LayerNorm from ``GPT2``.

        A bias-free ``nn.Linear(dim, vocab_size)`` and an ``nn.LayerNorm(dim)``
        are created, then their parameters are replaced by the tensors found
        on ``GPT2.lm_head`` and ``GPT2.transformer.ln_f``. Each source tensor
        is re-wrapped in a new ``nn.Parameter``; no explicit copy is made here.

        ``dim`` and ``vocab_size`` are not defined in this function; they are
        read from the surrounding scope.
        NOTE(review): presumably they match the shapes of the source weights;
        confirm against the caller.

        Args:
            GPT2: Object exposing ``transformer.ln_f.weight``,
                ``transformer.ln_f.bias`` and ``lm_head.weight``.

        Returns:
            A ``(head, ln_head)`` tuple: the linear projection followed by
            the LayerNorm.
        """
        # Final LayerNorm: take both affine parameters from the source ln_f.
        source_norm = GPT2.transformer.ln_f
        final_norm = nn.LayerNorm(dim)
        final_norm.weight = nn.Parameter(source_norm.weight)
        final_norm.bias = nn.Parameter(source_norm.bias)

        # Output projection: only a weight is replaced (bias is disabled).
        projection = nn.Linear(dim, vocab_size, bias=False)
        projection.weight = nn.Parameter(GPT2.lm_head.weight)

        return projection, final_norm