def match_order()

in src/sk_utils.py [0:0]


import numpy as np
import torch
import torch.distributed as dist


def match_order(args, emb1, emb2_in, W2, steps=50000, restarts=2, logger=None):
    """Find the column permutation of emb2_in that best matches emb1 (total L1
    cost) via random greedy swaps, then apply it to the output units of the
    linear head W2. The search runs on rank 0 and the result is broadcast."""
    fin_perm = torch.arange(0, len(W2.bias.data)).cuda()
    if args.rank == 0:
        assert type(W2) == torch.nn.Linear
        K = emb1.shape[1]
        def c(a, b):
            # total absolute (L1) distance between the two assignment matrices
            return (torch.abs(a - b)).sum(0).sum(0)
        cost = c(emb1, emb2_in)
        best_cost = cost
        logger.info(f'initial cost: {cost:.1f}')
        for retries in range(restarts):
            cost_try = cost.item()
            perm = torch.arange(0, K)
            emb2 = emb2_in.clone().detach()
            # reset the early-stopping counter for this restart
            last_iter = 0
            for _iter in range(steps):
                # what would the cost be if we swapped cluster i with cluster j in emb2?
                [i, j] = np.random.choice(K, 2, replace=False)
                current = c(emb1[:, i], emb2[:, i]) + c(emb1[:, j], emb2[:, j])
                future = c(emb1[:, i], emb2[:, j]) + c(emb1[:, j], emb2[:, i])
                delta = current - future
                if delta > 0:
                    # swapping columns i and j lowers the cost, so accept the swap
                    emb2[:, j], emb2[:, i] = emb2[:, i].clone().detach(), emb2[:, j].clone().detach()
                    cost_try -= delta.item()
                    _i = int(perm[i])
                    perm[i] = int(perm[j])
                    perm[j] = _i
                    last_iter = _iter
                # stop early if no improving swap has been found for 1000 iterations
                if _iter - last_iter > 1000:
                    break

            # recompute the exact cost of this permutation from scratch
            cost_try = c(emb1, emb2_in[:, perm])
            logger.info(f"cost of this try: {cost_try:.2f}")
            if cost_try < best_cost:
                best_cost = cost_try
                fin_perm = perm.cuda()
    # share the permutation found on rank 0 with all other ranks
    dist.broadcast(fin_perm, 0)
    fin_perm = fin_perm.cpu()
    if args.rank == 0:
        logger.info(f"final cost: {best_cost:.2f}")
    # reorder the output units (rows of weight and bias) of the linear head
    W2.bias.data = W2.bias.data[fin_perm]
    W2.weight.data = W2.weight.data[fin_perm]
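
A minimal, single-process usage sketch follows. Everything in it is illustrative rather than taken from the repository: it assumes one GPU, a one-process gloo group, that src/ is on PYTHONPATH, and it fabricates random assignment matrices just to exercise the swap search. Because emb2 below is an exact column shuffle of emb1, the search should drive the matching cost toward zero and permute the head's output units accordingly.

import argparse
import logging

import torch
import torch.distributed as dist

from sk_utils import match_order  # assumes src/ is on PYTHONPATH

# One-process "distributed" group so that dist.broadcast has a group to use.
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("match_order_demo")
args = argparse.Namespace(rank=0)          # match_order only reads args.rank

N, K = 512, 16
emb1 = torch.rand(N, K).cuda()             # reference soft assignments (illustrative)
true_perm = torch.randperm(K)
emb2 = emb1[:, true_perm]                  # same assignments, columns shuffled
head = torch.nn.Linear(128, K).cuda()      # hypothetical linear head behind emb2

match_order(args, emb1, emb2, head, steps=20000, restarts=2, logger=logger)
# head.weight / head.bias rows are now reordered by the best permutation found,
# so the head's outputs line up cluster-by-cluster with the columns of emb1.

dist.destroy_process_group()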