in src/sk_utils.py [0:0]
import time

import numpy as np
import torch


def optimize_L_sk_gpu(args, PS, hc, logger=None):
    """Sinkhorn-Knopp optimization of pseudo-labels on the GPU.

    PS is an (N x K) matrix of predicted cluster probabilities and is
    modified in place; hc selects the head-specific target distribution
    when several heads are used. Returns the assignment cost and the
    hard pseudo-labels.
    """
    print('doing optimization now', flush=True)
    N = PS.size(0)  # number of samples
    K = PS.size(1)  # number of clusters
    tt = time.time()
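    # `_K_dist` is the target cluster-size distribution (K x 1), uniform by
    # default; its overall scale is immaterial (hence the commented-out
    # `/ K`) because the marginal `r` derived from it below is normalized.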
    _K_dist = torch.ones((K, 1), dtype=torch.float64, device='cuda')  # / K
    if args.distribution != 'default':
        # clusters ranked by their current predicted mass
        marginals_argsort = torch.argsort(PS.sum(0))
        if (args.dist is None) or args.diff_dist_every:
            if args.distribution == 'gauss':
                # draw target sizes around N/K with spread args.gauss_sd
                if args.diff_dist_per_head:
                    # one distribution per head, cached on args.dist for reuse
                    _K_dists = [(torch.randn(size=(K, 1), dtype=torch.float64,
                                             device='cuda')*args.gauss_sd + 1) * N / K
                                for _ in range(args.headcount)]
                    args.dist = _K_dists
                    _K_dist = _K_dists[hc]
                else:
                    _K_dist = (torch.randn(size=(K, 1), dtype=torch.float64,
                                           device='cuda')*args.gauss_sd + 1) * N / K
                    _K_dist = torch.clamp(_K_dist, min=1)  # at least one sample per cluster
                    args.dist = _K_dist
            if args.rank == 0 and logger is not None:
                logger.info(f"distribution used: {_K_dist}")
        else:
            # reuse the distribution cached by an earlier call
            if args.diff_dist_per_head:
                _K_dist = args.dist[hc]
            else:
                _K_dist = args.dist
        # match the sorted target sizes to the clusters' current mass ranking
        # (sort along dim 0; the default dim=-1 would be a no-op on a (K, 1) tensor)
        _K_dist[marginals_argsort] = torch.sort(_K_dist, dim=0)[0]
    beta = torch.ones((N, 1), dtype=torch.float64, device='cuda') / N  # initial row scalings
    PS.pow_(0.5*args.lamb)  # sharpen in place: PS <- PS**(lamb/2)
    r = 1./_K_dist
    r /= r.sum()  # normalized target column (cluster) marginals
    c = 1./N      # target row marginal: every sample carries equal mass
    err = 1e6
    _counter = 0
    ones = torch.ones(N, device='cuda', dtype=torch.float64)
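    # Sinkhorn-Knopp fixed point: find scalings alpha (K x 1) and beta (N x 1)
    # such that Q = diag(beta) @ PS @ diag(alpha) has row sums c and column
    # sums r, by alternating
    #     alpha <- r / (PS.T @ beta),    beta <- c / (PS @ alpha).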
    while (err > 1e-1) and (_counter < 2000):
        alpha = r / torch.matmul(beta.t(), PS).t()
        beta_new = c / torch.matmul(PS, alpha)
        if _counter % 10 == 0:
            # convergence check every 10 iterations (each check syncs the GPU)
            err = torch.sum(torch.abs((beta.squeeze() / beta_new.squeeze()) - ones)).cpu().item()
        beta = beta_new
        _counter += 1
    if args.rank == 0 and logger is not None:
        logger.info(f"error: {err}, step: {_counter}")
    # apply the scalings in place: PS <- diag(beta) @ PS @ diag(alpha)
    torch.mul(PS, beta, out=PS)
    torch.mul(alpha.t(), PS, out=PS)
    # hard pseudo-labels: argmax over clusters for each sample
    newL = torch.argmax(PS, 1)
    # undo the scalings to get back PS**(lamb/2) and obtain the cost (optional)
    torch.mul((1./alpha).t(), PS, out=PS)
    torch.mul(PS, 1./beta, out=PS)
    sol = np.nansum(torch.log(PS[torch.arange(0, len(newL)).long(), newL]).cpu().numpy())
    cost = -(1. / args.lamb) * sol / N
    if args.rank == 0 and logger is not None:
        logger.info(f"opt took {(time.time() - tt) / 60.} min, {_counter} iters")
    return cost, newL
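

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original file: it illustrates the
# shapes and the argparse-style fields `optimize_L_sk_gpu` reads. The field
# values, the toy `PS` matrix, and the 'sk' logger name are illustrative
# assumptions only; a CUDA-capable GPU is required.
if __name__ == '__main__':
    import logging
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)

    N, K = 1024, 10  # samples x clusters
    args = SimpleNamespace(
        distribution='default',  # 'default' keeps the uniform cluster prior
        dist=None, diff_dist_every=False, diff_dist_per_head=False,
        gauss_sd=0.1, headcount=1, lamb=25.0, rank=0,
    )
    # row-stochastic cluster probabilities, e.g. softmax outputs, as float64
    PS = torch.softmax(torch.randn(N, K, device='cuda'), dim=1).double()
    cost, newL = optimize_L_sk_gpu(args, PS, hc=0,
                                   logger=logging.getLogger('sk'))
    print(cost, newL.shape)  # scalar cost and (N,) pseudo-labels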