# align() — from alignment/unsup_multialign.py


def align(EMB, TRANS, lglist, args):
    """Jointly align the embeddings of several languages into a common space.

    Optimizes one orthogonal mapping per language in two phases: first an
    L2 loss with soft correspondences obtained via entropic optimal
    transport (Sinkhorn), then a refinement with the RCSLS loss using hard
    nearest-neighbor correspondences.  Language 0 acts as the pivot: its
    mapping TRANS[0] is never updated.

    Args:
        EMB: list of embedding matrices, one per language; each must have
            at least ``args.maxload`` rows (rows are sampled by index).
        TRANS: list of mapping matrices, one per language; updated in
            place (except index 0) and also returned.
        lglist: list of language identifiers; only its length is used.
        args: namespace providing maxload, uniform, lr, bsz, epoch, niter,
            altlr, altepoch and altbsz.

    Returns:
        The list TRANS of optimized mapping matrices.
    """
    nmax, l = args.maxload, len(lglist)
    # Build the pool of language pairs to sample from.  By default, pairs
    # containing the pivot (language 0) are appended extra times, so they
    # are drawn with higher probability; with --uniform every ordered pair
    # is equally likely.
    samples = []
    for i in range(l):
        for j in range(l):
            if j == i:
                continue
            if j > 0 and not args.uniform:
                samples.append((0, j))
            if i > 0 and not args.uniform:
                samples.append((i, 0))
            samples.append((i, j))

    # Phase 1: optimization of the L2 loss with Sinkhorn soft matchings.
    print('start optimizing L2 loss')
    lr0, bsz, nepoch, niter = args.lr, args.bsz, args.epoch, args.niter
    for epoch in range(nepoch):
        print("start epoch %d / %d" % (epoch + 1, nepoch))
        ones = np.ones(bsz)
        f, fold, nb, lr = 0.0, 0.0, 0.0, lr0
        for it in range(niter):
            # Halve the learning rate when the loss stops improving; give
            # up on this epoch once the rate becomes too small.
            if it > 1 and f > fold + 1e-3:
                lr /= 2
            if lr < .05:
                break
            fold = f
            f, nb = 0.0, 0.0
            for k in range(100 * (l - 1)):
                (i, j) = random.choice(samples)
                # Random minibatch of bsz rows from each language.
                embi = EMB[i][np.random.permutation(nmax)[:bsz], :]
                embj = EMB[j][np.random.permutation(nmax)[:bsz], :]
                # Soft correspondence between the two batches: Sinkhorn on
                # the negated similarity in the common space.
                perm = ot.sinkhorn(ones, ones, np.linalg.multi_dot([embi, -TRANS[i], TRANS[j].T, embj.T]), reg=0.025, stopThr=1e-3)
                grad = np.linalg.multi_dot([embi.T, perm, embj])
                f -= np.trace(np.linalg.multi_dot([TRANS[i].T, grad, TRANS[j]])) / embi.shape[0]
                nb += 1
                # Gradient ascent step followed by projection back onto
                # the orthogonal group; the pivot (index 0) stays fixed.
                # Note: when both i > 0 and j > 0, the update of TRANS[j]
                # deliberately uses the freshly updated TRANS[i].
                if i > 0:
                    TRANS[i] = proj_ortho(TRANS[i] + lr * np.dot(grad, TRANS[j]))
                if j > 0:
                    TRANS[j] = proj_ortho(TRANS[j] + lr * np.dot(grad.transpose(), TRANS[i]))
            # max(nb, 1) guards against l == 1 (empty inner loop), matching
            # the end-of-epoch report below.
            print("iter %d / %d - epoch %d - loss: %.5f  lr: %.4f" % (it, niter, epoch + 1, f / max(nb, 1), lr))
        print("end of epoch %d - loss: %.5f - lr: %.4f" % (epoch + 1, f / max(nb, 1), lr))
        # Anneal the schedule: fewer iterations, larger batches each epoch.
        niter, bsz = max(int(niter / 2), 2), min(1000, bsz * 2)
    #end for epoch in range(nepoch):

    # Phase 2: refinement with the RCSLS loss and hard (argmax) matchings.
    print('start optimizing RCSLS loss')
    f, fold, nb, lr = 0.0, 0.0, 0.0, args.altlr
    for epoch in range(args.altepoch):
        # Halve the learning rate when the relative improvement stalls.
        if epoch > 1 and f - fold > -1e-4 * abs(fold):
            lr /= 2
        if lr < 1e-1:
            break
        fold = f
        f, nb = 0.0, 0.0
        for k in range(round(nmax / args.altbsz) * 10 * (l - 1)):
            (i, j) = random.choice(samples)
            sgdidx = np.random.choice(nmax, size=args.altbsz, replace=False)
            embi = EMB[i][sgdidx, :]
            embj = EMB[j][:nmax, :]
            # Crude alignment approximation: pair each sampled source row
            # with its current nearest neighbor via a one-hot matrix.
            T = np.dot(TRANS[i], TRANS[j].T)
            scores = np.linalg.multi_dot([embi, T, embj.T])
            perm = np.zeros_like(scores)
            perm[np.arange(len(scores)), scores.argmax(1)] = 1
            embj = np.dot(perm, embj)
            # Normalization over a subset of embeddings for speed up.
            fi, grad = rcsls(embi, embj, embi, embj, T.T)
            f += fi
            nb += 1
            # Descent step with projection onto the orthogonal group; the
            # pivot mapping (index 0) stays fixed.
            if i > 0:
                TRANS[i] = proj_ortho(TRANS[i] - lr * np.dot(grad, TRANS[j]))
            if j > 0:
                TRANS[j] = proj_ortho(TRANS[j] - lr * np.dot(grad.transpose(), TRANS[i]))
        print("epoch %d - loss: %.5f - lr: %.4f" % (epoch + 1, f / max(nb, 1), lr))
    #end for epoch in range(args.altepoch):
    return TRANS