def lrp_linear_layer()

in opacus_lab/models/GPT2/model/transformer.py [0:0]


def lrp_linear_layer(LinearLayer, rank):
    """Build a ``LowRankPerturbedLinear`` whose frozen core reuses an existing layer.

    A fresh ``LowRankPerturbedLinear`` is constructed with the same input and
    output sizes as ``LinearLayer``. Its ``core`` then has its ``weight`` (and
    ``bias``, when the source layer has one) replaced by new ``nn.Parameter``
    objects that wrap the source layer's tensors with ``requires_grad=False``.

    Args:
        LinearLayer: Source layer exposing ``.weight`` and ``.bias`` (``bias``
            may be ``None``). The weight is unpacked as ``(out, in)``, which
            matches the ``torch.nn.Linear`` layout.
        rank: Passed straight through to ``LowRankPerturbedLinear``.
            Presumably this is the rank of the perturbation; confirm against
            that class.

    Returns:
        The newly constructed ``LowRankPerturbedLinear`` instance.
    """
    src_weight = LinearLayer.weight
    # The weight is unpacked as (out, in), but the constructor takes (in, out).
    out_features, in_features = src_weight.shape
    src_bias = LinearLayer.bias
    has_bias = src_bias is not None

    wrapped = LowRankPerturbedLinear(in_features, out_features, rank, bias=has_bias)
    core = wrapped.core

    # Freeze the copied-in parameters. Only the perturbation added by
    # LowRankPerturbedLinear is meant to be trainable; verify in that class.
    core.weight = nn.Parameter(src_weight, requires_grad=False)
    if has_bias:
        core.bias = nn.Parameter(src_bias, requires_grad=False)

    return wrapped