def optimize_L_sk_gpu()

in src/sk_utils.py [0:0]


def _sample_gauss_cluster_sizes(N, K, gauss_sd, device):
    """Draw a (K, 1) float64 column of Gaussian cluster sizes, clamped to >= 1."""
    sizes = (torch.randn(size=(K, 1), dtype=torch.float64, device=device) * gauss_sd + 1) * N / K
    # Non-positive sizes would yield negative/infinite entries in r = 1 / sizes.
    return torch.clamp(sizes, min=1)


def optimize_L_sk_gpu(args, PS, hc, logger=None):
    """Compute labels from ``PS`` by alternating row/column scaling.

    The scaling vectors ``alpha`` (per column) and ``beta`` (per row) are
    updated in turn until the relative change of ``beta`` (checked every 10
    iterations) sums to at most 0.1, or 2000 iterations have run.

    Args:
        args: Namespace read for ``distribution``, ``dist``,
            ``diff_dist_every``, ``diff_dist_per_head``, ``gauss_sd``,
            ``headcount``, ``lamb`` and ``rank``. ``args.dist`` is written
            when a new distribution is sampled and is re-ordered in place
            whenever ``args.distribution != 'default'``.
        PS: 2-D tensor of shape (N, K). It is modified IN PLACE: on return it
            holds the input raised to the power ``0.5 * args.lamb`` (up to
            floating-point rounding). Assumed to be float64, since it is
            multiplied with float64 helpers -- TODO confirm with caller.
        hc: Index of the head; selects the entry of ``args.dist`` when
            ``args.diff_dist_per_head`` is set.
        logger: Logger used when ``args.rank == 0``. Defaults to the module
            logger so that the documented ``None`` default does not crash.

    Returns:
        Tuple ``(cost, newL)``: ``cost`` is a scalar computed from the log of
        the selected ``PS`` entries, ``newL`` is a length-N int64 tensor of
        per-row argmax labels on the same device as ``PS``.
    """
    if logger is None:
        import logging
        logger = logging.getLogger(__name__)

    print('doing optimization now', flush=True)

    N = PS.size(0)
    K = PS.size(1)
    # Allocate every helper on PS's device instead of mixing 'cuda' and
    # 'cuda:0', which breaks on ranks whose current device is not GPU 0.
    device = PS.device
    tt = time.time()

    _K_dist = torch.ones((K, 1), dtype=torch.float64, device=device)
    if args.distribution != 'default':
        # Column ranking is taken from the raw PS, before the in-place pow_.
        marginals_argsort = torch.argsort(PS.sum(0))
        if (args.dist is None) or args.diff_dist_every:
            if args.distribution == 'gauss':
                if args.diff_dist_per_head:
                    _K_dists = [_sample_gauss_cluster_sizes(N, K, args.gauss_sd, device)
                                for _ in range(args.headcount)]
                    args.dist = _K_dists
                    _K_dist = _K_dists[hc]
                else:
                    _K_dist = _sample_gauss_cluster_sizes(N, K, args.gauss_sd, device)
                    args.dist = _K_dist
            if args.rank == 0:
                logger.info(f"distribution used: {_K_dist}")
        elif args.diff_dist_per_head:
            _K_dist = args.dist[hc]
        else:
            _K_dist = args.dist
        # Give the i-th smallest size to the column with the i-th smallest
        # marginal. dim=0 is required: _K_dist is (K, 1), so the default
        # dim=-1 would sort a size-1 axis and do nothing.
        _K_dist[marginals_argsort] = torch.sort(_K_dist, dim=0)[0]

    beta = torch.ones((N, 1), dtype=torch.float64, device=device) / N
    PS.pow_(0.5 * args.lamb)
    # Target column marginals, normalised to sum to one.
    r = 1. / _K_dist
    r /= r.sum()

    c = 1. / N
    err = 1e6
    _counter = 0

    ones = torch.ones(N, dtype=torch.float64, device=device)
    while (err > 1e-1) and (_counter < 2000):
        alpha = r / torch.matmul(beta.t(), PS).t()
        beta_new = c / torch.matmul(PS, alpha)
        if _counter % 10 == 0:
            # .item() forces a host sync, hence only every 10th step.
            err = torch.sum(torch.abs((beta.squeeze() / beta_new.squeeze()) - ones)).item()
        beta = beta_new
        _counter += 1
    if args.rank == 0:
        logger.info(f"error: {err}, step : {_counter}")

    # Scale PS in place to avoid allocating a second (N, K) matrix.
    PS.mul_(beta)
    PS.mul_(alpha.t())
    newL = torch.argmax(PS, 1)

    # Undo the scaling so the cost is evaluated on the powered PS.
    PS.mul_((1. / alpha).t())
    PS.mul_(1. / beta)
    rows = torch.arange(N, device=device)
    sol = np.nansum(torch.log(PS[rows, newL]).cpu().numpy())
    cost = -(1. / args.lamb) * sol / N
    if args.rank == 0:
        logger.info(f"opt took {(time.time() - tt) / 60.} min, {_counter} iters")
    return cost, newL