in src/sk_utils.py [0:0]
def match_order(args, emb1, emb2_in, W2, steps=50000, restarts=2, logger=None,
                patience=1000):
    """Permute the outputs of ``W2`` so the columns of ``emb2_in`` line up with ``emb1``.

    Rank 0 runs a randomized greedy search: it repeatedly picks two distinct
    columns of a working copy of ``emb2_in`` and swaps them whenever the swap
    lowers the summed absolute difference to the same columns of ``emb1``.
    The best permutation over all restarts (kept only if it beats the cost of
    the unpermuted input) is broadcast from rank 0 to every rank, and
    ``W2.bias`` / ``W2.weight`` are re-indexed along their first dimension
    with it, in place.

    Args:
        args: Object exposing ``rank``; only rank 0 performs the search.
        emb1: Reference tensor. Assumed 2-D with columns as the units being
            matched -- TODO confirm against callers.
        emb2_in: Tensor to align with ``emb1``; assumed to have the same
            shape. It is not modified (the search works on a copy).
        W2: Linear layer whose bias and weight rows are reordered. Presumably
            ``len(W2.bias) == emb1.shape[1]``, otherwise the broadcast tensors
            would differ in size between ranks -- verify against callers.
        steps: Maximum number of swap proposals per restart.
        restarts: Number of independent searches, each starting from the
            identity permutation.
        logger: Logger used on rank 0; falls back to this module's logger
            when ``None``.
        patience: A restart stops early once more than this many consecutive
            proposals have gone by without an accepted swap.

    Returns:
        None. ``W2.bias.data`` and ``W2.weight.data`` are replaced.

    Note:
        Every rank must call this function (it contains a collective
        ``dist.broadcast``), and a CUDA device is required because the
        permutation is moved to the GPU before broadcasting.
    """
    import logging

    if logger is None:
        # The signature advertises logger=None, so it must not crash rank 0.
        logger = logging.getLogger(__name__)

    # Identity permutation: what every rank applies if rank 0 finds no improvement.
    fin_perm = torch.arange(0, len(W2.bias.data)).cuda()
    best_cost = None

    if args.rank == 0:
        assert isinstance(W2, torch.nn.Linear), "W2 must be a torch.nn.Linear"
        num_cols = emb1.shape[1]

        def l1_cost(a, b):
            # Two chained sum(0) calls reduce both full matrices and single columns to a scalar.
            return torch.abs(a - b).sum(0).sum(0)

        cost = l1_cost(emb1, emb2_in)
        best_cost = cost
        logger.info(f'initial cost: {cost:.1f}')

        # A swap needs two distinct columns; with fewer there is nothing to search.
        n_steps = steps if num_cols >= 2 else 0

        for _ in range(restarts):
            perm = torch.arange(0, num_cols)
            emb2 = emb2_in.clone().detach()
            # Reset per restart so early stopping is not measured against the
            # previous restart's last accepted swap.
            last_improvement = 0
            for step in range(n_steps):
                # What would happen if we switched column i with column j in emb2?
                i, j = (int(k) for k in np.random.choice(num_cols, 2, replace=False))
                current = l1_cost(emb1[:, i], emb2[:, i]) + l1_cost(emb1[:, j], emb2[:, j])
                future = l1_cost(emb1[:, i], emb2[:, j]) + l1_cost(emb1[:, j], emb2[:, i])
                if current - future > 0:
                    # Accept: swap the two columns and record the swap in perm.
                    col_i = emb2[:, i].clone()
                    emb2[:, i] = emb2[:, j]
                    emb2[:, j] = col_i
                    perm_i = int(perm[i])
                    perm[i] = int(perm[j])
                    perm[j] = perm_i
                    last_improvement = step
                if step - last_improvement > patience:
                    break

            # Recompute from the untouched input so the reported cost matches
            # exactly what applying perm will produce.
            cost_try = l1_cost(emb1, emb2_in[:, perm])
            logger.info(f"cost of this try: {cost_try:.2f}")
            if cost_try < best_cost:
                best_cost = cost_try
                fin_perm = perm.cuda()

    # Collective call: rank 0 sends its result, all other ranks receive it.
    dist.broadcast(fin_perm, 0)
    fin_perm = fin_perm.cpu()
    if args.rank == 0:
        logger.info(f"final cost: {best_cost:.2f}")
    W2.bias.data = W2.bias.data[fin_perm]
    W2.weight.data = W2.weight.data[fin_perm]