def factorize_linear_layer()

in opacus_lab/models/GPT2/model/transformer.py [0:0]


def factorize_linear_layer(LinearLayer, rank):
    """Build a low-rank ``FactorizedLinear`` approximating ``LinearLayer``.

    The layer's weight ``W`` is decomposed with a reduced SVD,
    ``W = U @ diag(S) @ V.T``, and only the leading ``rank`` singular
    triplets are kept.  The singular values are split evenly between the
    two factors so that ``L.weight @ R.weight`` reproduces the truncated
    reconstruction of ``W``:

    * ``R.weight = diag(sqrt(S[:rank])) @ V.T[:rank]``
    * ``L.weight = U[:, :rank] @ diag(sqrt(S[:rank]))``

    Args:
        LinearLayer: Source layer exposing a 2-D ``weight`` and a ``bias``
            attribute that may be ``None``.  The weight's first dimension
            is used as ``out_features`` and its second as ``in_features``
            (the ``nn.Linear`` layout).
        rank: Number of leading singular values/vectors to keep.  Slicing
            silently caps this at the number of singular values available,
            so a larger ``rank`` yields factors narrower than requested.

    Returns:
        A ``FactorizedLinear`` whose ``R`` and ``L`` weights hold the two
        factors and whose ``L.bias`` is an independent copy of the source
        bias when the source layer has one.
    """
    # Detach first: the factors are wrapped in fresh Parameters below, so any
    # autograd graph recorded through the SVD would be discarded anyway.
    weight = LinearLayer.weight.detach()

    # NOTE: torch.svd is deprecated upstream in favour of torch.linalg.svd;
    # it is kept here so the function keeps working on older PyTorch builds.
    U, S, V = torch.svd(weight)

    lr_U = U[:, :rank]
    lr_V = V.t()[:rank]
    # Each factor gets sqrt(S) so their product carries the full S.
    sqrt_S = torch.sqrt(S[:rank]).diag()

    out_features = lr_U.shape[0]
    in_features = lr_V.shape[1]
    bias = LinearLayer.bias is not None

    lr_LinearLayer = FactorizedLinear(in_features, out_features, rank, bias=bias)
    lr_LinearLayer.R.weight = nn.Parameter(sqrt_S @ lr_V)
    lr_LinearLayer.L.weight = nn.Parameter(lr_U @ sqrt_S)
    if bias:
        # Clone so the new layer does not share storage with the source
        # layer's bias; wrapping the existing Parameter directly would alias
        # it and in-place updates to one would leak into the other.
        lr_LinearLayer.L.bias = nn.Parameter(LinearLayer.bias.detach().clone())
    return lr_LinearLayer