def inference()

in align/test.py [0:0]


def inference(simalign, probs, threshold):
    """Greedily extend a seed word alignment with high-scoring matrix cells.

    Args:
        simalign: Seed alignment as whitespace-separated ``"src-trg"`` integer
            pairs (e.g. ``"0-0 1-2"``); may be empty.
        probs: 2-D tensor of similarity scores. Row ``i`` is source position
            ``i`` and column ``j`` is target position ``j``. It does not need
            to be contiguous (a transposed matrix is fine).
        threshold: Cells are visited in descending score order and the search
            stops at the first cell whose score is below this value.

    Returns:
        The seed pairs plus every accepted cell, sorted and formatted as a
        space-separated ``"src-trg"`` string (empty string if none).

    A visited cell ``(x, y)`` is accepted when:
      * neither ``x`` nor ``y`` is aligned yet; or
      * exactly one side is aligned and the other position is directly
        adjacent to (one after the maximum of, or one before the minimum of)
        the positions that side is already linked to.
    Cells whose source and target are both already aligned are skipped.
    """
    _, m = probs.shape  # also enforces a 2-D input
    src2trg = collections.defaultdict(set)
    trg2src = collections.defaultdict(set)
    results = set()

    def link(x, y):
        # Record the pair in both lookup directions and in the output set.
        src2trg[x].add(y)
        trg2src[y].add(x)
        results.add((x, y))

    def extends_span(pos, linked):
        # True when `pos` sits directly after the max or before the min of `linked`.
        return pos == max(linked) + 1 or pos == min(linked) - 1

    for pair in simalign.split():
        x, y = pair.split('-')
        link(int(x), int(y))

    # reshape (not view) so non-contiguous inputs work. The threshold test is
    # a single tensor op, so it keeps the tensor's own comparison semantics
    # (dtype rounding, NaN). .tolist() then moves the indices and flags to
    # Python in one transfer instead of one .item()/index per cell.
    flat = probs.reshape(-1)
    order = flat.argsort(descending=True)
    below = (flat[order] < threshold).tolist()
    for idx, too_low in zip(order.tolist(), below):
        if too_low:  # too low similarity; cells are visited in descending order
            break
        x, y = divmod(idx, m)  # flat index -> (row, col)
        # Membership tests only: they must not create defaultdict entries.
        has_src = x in src2trg
        has_trg = y in trg2src
        if not has_src and not has_trg:  # both sides free: keep
            link(x, y)
        elif has_src and has_trg:  # both sides already aligned: skip
            continue
        elif has_src:  # source aligned; target may extend its span
            if extends_span(y, src2trg[x]):
                link(x, y)
        elif extends_span(x, trg2src[y]):  # target aligned; source may extend its span
            link(x, y)
    return ' '.join(f'{x}-{y}' for x, y in sorted(results))