def refactor_attention()

in opacus_lab/models/GPT2/refactor.py [0:0]
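
Copies the weights of a Hugging Face GPT2Attention module into a Linear-based AttentionLayer. The fused c_attn Conv1D is split into separate query, key, and value projections, and every Conv1D weight is transposed to match the nn.Linear layout, presumably so that Opacus can attach per-sample gradient hooks to standard linear layers.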


    import torch
    from torch import nn

    def refactor_attention(GPT2Attention):
        # The fused QKV projection (a transformers Conv1D) and the output projection.
        Conv1D = GPT2Attention.c_attn
        Proj = GPT2Attention.c_proj

        # n_heads and dim are undefined in the snippet; deriving them from the module
        # here is an assumption (the original presumably takes them from the GPT-2 config).
        n_heads = GPT2Attention.num_heads
        dim = Proj.weight.shape[0]

        # AttentionLayer is the Opacus-friendly attention module defined elsewhere in
        # opacus_lab; 0.1 is the attention dropout probability.
        Attention = AttentionLayer(n_heads, dim, 0.1)

        # Conv1D stores weights as (in_features, out_features); nn.Linear expects
        # (out_features, in_features), hence the transposes below.
        Attention.linear.weight = nn.Parameter(Proj.weight.t())
        Attention.linear.bias = nn.Parameter(Proj.bias)

        # c_attn fuses Q, K, and V; split its weight and bias into three dim-sized chunks.
        q_weight, k_weight, v_weight = torch.split(Conv1D.weight, [dim] * 3, dim=-1)
        q_bias, k_bias, v_bias = torch.split(Conv1D.bias, [dim] * 3, dim=-1)

        Attention.proj_q.weight = nn.Parameter(q_weight.t())
        Attention.proj_k.weight = nn.Parameter(k_weight.t())
        Attention.proj_v.weight = nn.Parameter(v_weight.t())

        Attention.proj_q.bias = nn.Parameter(q_bias)
        Attention.proj_k.bias = nn.Parameter(k_bias)
        Attention.proj_v.bias = nn.Parameter(v_bias)

        return Attention
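
A minimal usage sketch (not from the original file): assuming a Hugging Face GPT2LMHeadModel, whose transformer blocks expose their attention module as block.attn, each GPT2Attention could be swapped in place:

    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Replace every Conv1D-based GPT2Attention with the Linear-based AttentionLayer
    # produced by refactor_attention, keeping the pretrained weights.
    for block in model.transformer.h:
        block.attn = refactor_attention(block.attn)

In the actual opacus_lab pipeline the surrounding refactor code presumably rewrites the rest of the block as well; the loop above only illustrates where refactor_attention fits.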