in alignment/unsup_multialign.py [0:0]
def align(EMB, TRANS, lglist, args):
    """Refine the per-language mapping matrices in ``TRANS`` and return them.

    Training runs in two stages, both driven by stochastic steps over randomly
    sampled ordered language pairs ``(i, j)``:

    1. an L2 stage, where a soft matching between two random mini-batches is
       computed with ``ot.sinkhorn`` and used to build the gradient;
    2. an RCSLS stage, where each source row is matched to its best-scoring
       target row (argmax) and ``rcsls`` supplies the loss and gradient.

    The mapping at index 0 (the pivot language) is never modified; every other
    mapping is passed through ``proj_ortho`` after each step.

    Args:
        EMB: per-language embedding matrices, indexed like ``lglist``.
            Rows are indexed with values in ``[0, args.maxload)``, so each
            matrix is assumed to have at least that many rows -- TODO confirm
            against the caller.
        TRANS: per-language mapping matrices, indexed like ``lglist``.
            Updated in place.
        lglist: list of languages; only its length is used here.
        args: namespace providing ``maxload``, ``uniform``, ``lr``, ``bsz``,
            ``epoch``, ``niter``, ``altlr``, ``altepoch`` and ``altbsz``.

    Returns:
        The same ``TRANS`` object that was passed in.
    """
    nmax, nlang = args.maxload, len(lglist)

    # create a list of language pairs to sample from
    # (default == higher probability to pick a language pair containing the pivot)
    # if --uniform: uniform probability of picking a language pair
    favor_pivot = args.uniform == False  # noqa: E712 -- keep the original comparison semantics
    samples = []
    for i in range(nlang):
        for j in range(nlang):
            if j == i:
                continue
            if favor_pivot and j > 0:
                samples.append((0, j))
            if favor_pivot and i > 0:
                samples.append((i, 0))
            samples.append((i, j))

    # optimization of the l2 loss
    print('start optimizing L2 loss')
    lr0, bsz, nepoch, niter = args.lr, args.bsz, args.epoch, args.niter
    l2_steps = 100 * (nlang - 1)
    for epoch in range(nepoch):
        print("start epoch %d / %d"%(epoch+1, nepoch))
        # uniform marginals for the Sinkhorn matching of two bsz-sized batches
        ones = np.ones(bsz)
        f, fold, nb, lr = 0.0, 0.0, 0.0, lr0
        for it in range(niter):
            # halve the step size when the accumulated loss went up; the first
            # two iterations are skipped because `fold` is not meaningful yet
            if it > 1 and f > fold + 1e-3:
                lr /= 2
            if lr < .05:
                break
            fold = f
            f, nb = 0.0, 0.0
            for _ in range(l2_steps):
                (i, j) = random.choice(samples)
                embi = EMB[i][np.random.permutation(nmax)[:bsz], :]
                embj = EMB[j][np.random.permutation(nmax)[:bsz], :]
                # cost = negated dot products between the two mapped batches
                cost = np.linalg.multi_dot([embi, -TRANS[i], TRANS[j].T, embj.T])
                perm = ot.sinkhorn(ones, ones, cost, reg=0.025, stopThr=1e-3)
                grad = np.linalg.multi_dot([embi.T, perm, embj])
                f -= np.trace(np.linalg.multi_dot([TRANS[i].T, grad, TRANS[j]])) / embi.shape[0]
                nb += 1
                # index 0 is the pivot and stays fixed; note that the update of
                # TRANS[j] deliberately sees the freshly updated TRANS[i]
                if i > 0:
                    TRANS[i] = proj_ortho(TRANS[i] + lr * np.dot(grad, TRANS[j]))
                if j > 0:
                    TRANS[j] = proj_ortho(TRANS[j] + lr * np.dot(grad.transpose(), TRANS[i]))
            # max(nb, 1): no step is taken when there is a single language, so
            # guard the average exactly like the end-of-epoch report does
            print("iter %d / %d - epoch %d - loss: %.5f lr: %.4f" % (it, niter, epoch+1, f / max(nb, 1), lr))
        print("end of epoch %d - loss: %.5f - lr: %.4f" % (epoch+1, f / max(nb,1), lr))
        # next epoch: half as many iterations (at least 2), batch size doubled (capped at 1000)
        niter, bsz = max(int(niter/2),2), min(1000, bsz * 2)
    #end for epoch in range(nepoch):

    # optimization of the RCSLS loss
    print('start optimizing RCSLS loss')
    f, fold, nb, lr = 0.0, 0.0, 0.0, args.altlr
    rcsls_steps = round(nmax / args.altbsz) * 10 * (nlang - 1)
    for epoch in range(args.altepoch):
        # halve the step size unless the loss dropped by a relative 1e-4
        if epoch > 1 and f-fold > -1e-4 * abs(fold):
            lr /= 2
        if lr < 1e-1:
            break
        fold = f
        f, nb = 0.0, 0.0
        for _ in range(rcsls_steps):
            (i, j) = random.choice(samples)
            sgdidx = np.random.choice(nmax, size=args.altbsz, replace=False)
            embi = EMB[i][sgdidx, :]
            embj = EMB[j][:nmax, :]
            # crude alignment approximation:
            # pair every sampled source row with its highest-scoring target row
            T = np.dot(TRANS[i], TRANS[j].T)
            scores = np.linalg.multi_dot([embi, T, embj.T])
            perm = np.zeros_like(scores)
            perm[np.arange(len(scores)), scores.argmax(1)] = 1
            embj = np.dot(perm, embj)
            # normalization over a subset of embeddings for speed up
            fi, grad = rcsls(embi, embj, embi, embj, T.T)
            f += fi
            nb += 1
            # same pivot rule and update order as in the L2 stage
            if i > 0:
                TRANS[i] = proj_ortho(TRANS[i] - lr * np.dot(grad, TRANS[j]))
            if j > 0:
                TRANS[j] = proj_ortho(TRANS[j] - lr * np.dot(grad.transpose(), TRANS[i]))
        print("epoch %d - loss: %.5f - lr: %.4f" % (epoch+1, f / max(nb,1), lr))
    #end for epoch in range(args.altepoch):
    return TRANS