in src/dico_builder.py [0:0]
def get_candidates(emb1, emb2, params):
    """
    Get best translation pairs candidates.

    Scores every considered source word against the target vocabulary with the
    method selected by `params.dico_method` ('nn', 'invsm_beta_X' or
    'csls_knn_X'), then returns a LongTensor of (source id, target id) pairs,
    sorted by decreasing confidence (gap between the best and the second best
    candidate score) and filtered by the dico_* parameters.
    """
    batch_size = 128

    score_chunks = []
    target_chunks = []

    # number of source words to consider
    n_src = emb1.size(0)
    if params.dico_max_rank > 0 and not params.dico_method.startswith('invsm_beta_'):
        n_src = min(params.dico_max_rank, n_src)

    # nearest neighbors
    if params.dico_method == 'nn':

        # process source words by batches
        for start in range(0, n_src, batch_size):
            stop = min(n_src, start + batch_size)
            # similarity of every target word to this batch of source words
            scores = emb2.mm(emb1[start:stop].transpose(0, 1)).transpose(0, 1)
            top_scores, top_targets = scores.topk(2, dim=1, largest=True, sorted=True)
            score_chunks.append(top_scores.cpu())
            target_chunks.append(top_targets.cpu())

        all_scores = torch.cat(score_chunks, 0)
        all_targets = torch.cat(target_chunks, 0)

    # inverted softmax
    elif params.dico_method.startswith('invsm_beta_'):

        beta = float(params.dico_method[len('invsm_beta_'):])

        # process target words by batches
        for start in range(0, emb2.size(0), batch_size):
            # softmax over source words for each target word of the batch
            scores = emb1.mm(emb2[start:start + batch_size].transpose(0, 1))
            scores.mul_(beta).exp_()
            scores /= scores.sum(0, keepdim=True)
            top_scores, top_targets = scores.topk(2, dim=1, largest=True, sorted=True)
            # offset batch-local target indices to absolute ones
            score_chunks.append(top_scores.cpu())
            target_chunks.append((top_targets + start).cpu())

        all_scores = torch.cat(score_chunks, 1)
        all_targets = torch.cat(target_chunks, 1)

        # reduce the per-batch top-2 candidates to a global top-2
        all_scores, best_idx = all_scores.topk(2, dim=1, largest=True, sorted=True)
        all_targets = all_targets.gather(1, best_idx)

    # contextual dissimilarity measure
    elif params.dico_method.startswith('csls_knn_'):

        knn = params.dico_method[len('csls_knn_'):]
        assert knn.isdigit()
        knn = int(knn)

        # average similarity of each word to its k nearest neighbors
        average_dist1 = torch.from_numpy(get_nn_avg_dist(emb2, emb1, knn)).type_as(emb1)
        average_dist2 = torch.from_numpy(get_nn_avg_dist(emb1, emb2, knn)).type_as(emb2)

        # process source words by batches
        for start in range(0, n_src, batch_size):
            stop = min(n_src, start + batch_size)
            scores = emb2.mm(emb1[start:stop].transpose(0, 1)).transpose(0, 1)
            # CSLS: 2 * cos(x, y) - r_T(x) - r_S(y)
            scores.mul_(2)
            scores.sub_(average_dist1[start:stop][:, None] + average_dist2[None, :])
            top_scores, top_targets = scores.topk(2, dim=1, largest=True, sorted=True)
            score_chunks.append(top_scores.cpu())
            target_chunks.append(top_targets.cpu())

        all_scores = torch.cat(score_chunks, 0)
        all_targets = torch.cat(target_chunks, 0)

    # (source id, best target id) pairs
    all_pairs = torch.stack([
        torch.arange(0, all_targets.size(0)).long(),
        all_targets[:, 0],
    ], 1)

    # sanity check
    assert all_scores.size() == all_pairs.size() == (n_src, 2)

    # sort pairs by score confidence (gap between best and second best score)
    gap = all_scores[:, 0] - all_scores[:, 1]
    order = torch.argsort(gap, descending=True)
    all_scores = all_scores[order]
    all_pairs = all_pairs[order]

    # max dico words rank
    if params.dico_max_rank > 0:
        keep = all_pairs.max(1)[0] <= params.dico_max_rank
        all_scores = all_scores[keep]
        all_pairs = all_pairs[keep]

    # max dico size
    if params.dico_max_size > 0:
        all_scores = all_scores[:params.dico_max_size]
        all_pairs = all_pairs[:params.dico_max_size]

    # min dico size: force the first pairs through the threshold below
    gap = all_scores[:, 0] - all_scores[:, 1]
    if params.dico_min_size > 0:
        gap[:params.dico_min_size] = 1e9

    # confidence threshold
    if params.dico_threshold > 0:
        keep = gap > params.dico_threshold
        logger.info("Selected %i / %i pairs above the confidence threshold." % (keep.sum(), gap.size(0)))
        all_pairs = all_pairs[keep]

    return all_pairs