def refactor_block()

in opacus_lab/models/GPT2/refactor.py [0:0]


    def refactor_block(GPT2Block):
        # Build a refactored TransformerLayer from a HuggingFace GPT2Block.
        # GPT-2 small dimensions are hardcoded below: 12 attention heads,
        # a hidden size of 768, and a 4x feed-forward expansion rate.
        n_heads = 12
        Block = TransformerLayer(n_heads, 768, 4)

        # First copy the LayerNorm weights and biases; no refactoring needed.
        Block.ln_attn.weight = nn.Parameter(GPT2Block.ln_1.weight)
        Block.ln_attn.bias = nn.Parameter(GPT2Block.ln_1.bias)
        Block.ln_ff.weight = nn.Parameter(GPT2Block.ln_2.weight)
        Block.ln_ff.bias = nn.Parameter(GPT2Block.ln_2.bias)

        # Next refactor and copy the attention and feed-forward layers.
        Block.attn = refactor_attention(GPT2Block.attn)
        Block.ff = refactor_feedforward(GPT2Block.mlp)

        return Block
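
A minimal usage sketch, assuming refactor_block is importable as a module-level function from opacus_lab.models.GPT2.refactor and that the target is GPT-2 small (768 hidden units, 12 heads) loaded via HuggingFace transformers:

    from transformers import GPT2Model

    from opacus_lab.models.GPT2.refactor import refactor_block  # assumed import path

    # Load pretrained GPT-2 small; its transformer blocks live in model.h
    # as a ModuleList of GPT2Block instances.
    model = GPT2Model.from_pretrained("gpt2")

    # Refactor every block into the TransformerLayer form produced above.
    refactored_blocks = [refactor_block(block) for block in model.h]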