in align/test.py [0:0]
def inference(simalign, probs, threshold):
    """Extend an existing set of alignment links using a score matrix.

    The links parsed from ``simalign`` are always kept.  Cells of ``probs``
    are then visited from the highest score to the lowest, stopping at the
    first cell whose score is below ``threshold``.  A visited cell ``(x, y)``
    is added as a new link when either

    * neither ``x`` nor ``y`` is linked yet, or
    * exactly one of them is linked and the new partner index is directly
      adjacent (min - 1 or max + 1) to the partners it already has.

    Cells where both ``x`` and ``y`` are already linked are skipped.

    Args:
        simalign: Whitespace-separated ``"x-y"`` pairs of integers; may be
            an empty string.
        probs: 2-D score matrix indexed as ``probs[x, y]``.  It must support
            ``.shape``, ``.reshape``, ``.argsort(descending=True)`` and
            ``.tolist()`` -- presumably a ``torch.Tensor``; confirm against
            the caller.
        threshold: Minimum score a cell needs in order to be considered.

    Returns:
        The resulting links as a space-separated string of ``"x-y"`` pairs,
        sorted by ``(x, y)``.

    Raises:
        ValueError: If a ``simalign`` token is not of the form ``int-int``
            or ``probs`` is not 2-D.
    """
    # Unpacking (rather than indexing) keeps the implicit 2-D check.
    _, m = probs.shape

    src2trg = collections.defaultdict(set)
    trg2src = collections.defaultdict(set)
    results = set()

    def _link(x, y):
        # Record the link in both lookup directions and in the output set.
        src2trg[x].add(y)
        trg2src[y].add(x)
        results.add((x, y))

    for pair in simalign.split():
        x, y = pair.split('-')
        x = int(x)
        y = int(y)
        _link(x, y)

    # reshape (not view) so non-contiguous matrices, e.g. a transposed one,
    # are accepted; tolist() yields plain Python ints once, up front.
    order = probs.reshape(-1).argsort(descending=True).tolist()
    for flat_idx in order:
        x, y = divmod(flat_idx, m)
        # Compare on the matrix element itself so the check runs in the
        # matrix's own dtype.
        if probs[x, y] < threshold:  # too low similarity
            break

        # Membership tests do not create defaultdict entries, so these stay
        # accurate for indices that were never linked.
        x_linked = x in src2trg
        y_linked = y in trg2src

        if x_linked and y_linked:  # both already have partners, skip
            continue
        if x_linked:
            # x has partners; y may only extend their range by one.
            partners = src2trg[x]
            if y != max(partners) + 1 and y != min(partners) - 1:
                continue
        elif y_linked:
            # y has partners; x may only extend their range by one.
            partners = trg2src[y]
            if x != max(partners) + 1 and x != min(partners) - 1:
                continue
        # Reaching here: either both were free, or the adjacency test passed.
        _link(x, y)

    return ' '.join(f'{x}-{y}' for x, y in sorted(results))