def assign_mtch()

in src/rime/metrics/matching.py [0:0]


def assign_mtch(score_mat, topk, C,
                constraint_type='ub', argsort_ij=None, device="cpu"):
    """Greedily assign items to users under per-user and per-item budgets.

    The assignment itself is delegated to ``_assign_sorted``. This function
    only chooses the budgets it is called with, depending on
    ``constraint_type``.

    Parameters
    ----------
    score_mat : 2-D array-like
        Score matrix; its ``.shape`` is unpacked as ``(n_users, n_items)``.
        It is also passed to ``_argsort`` when ``argsort_ij`` is None.
    topk : int
        Per-user (row) budget. The 'lb' branch compares ``topk > min_k``
        directly, so a scalar is expected there.
    C : int or numpy array
        Per-item (column) budget, either a scalar or a per-item vector.
    constraint_type : {'ub', 'lb'}
        'ub': one ``_assign_sorted`` pass with ``topk`` and ``C`` as the
        row/column budgets.
        'lb': two-phase assignment (see the comments in the body) that
        tries to give every item at least ``C``.
    argsort_ij : optional
        Precomputed ordering as produced by ``_argsort``. It is computed
        from ``score_mat`` when None.
    device : str
        Forwarded to ``_argsort`` only.

    Returns
    -------
    assigned_csr
        The assignment returned by ``_assign_sorted``. In 'lb' mode it is
        the sum of the two phases. Presumably a scipy.sparse CSR matrix of
        shape ``(n_users, n_items)``; confirm against ``_assign_sorted``.

    Raises
    ------
    ValueError
        If ``constraint_type`` is neither 'ub' nor 'lb'.
    AssertionError
        In 'lb' mode with a vector ``C``, if the total user budget does not
        exceed ``C.sum()``.
    """
    # Validate before doing any (possibly expensive) sorting work; previously
    # any unknown value silently fell through to the lower-bound branch.
    if constraint_type not in ('ub', 'lb'):
        raise ValueError(
            f"constraint_type must be 'ub' or 'lb', got {constraint_type!r}")

    n_users, n_items = score_mat.shape

    if argsort_ij is None:
        argsort_ij = _argsort(score_mat, device=device)

    if constraint_type == 'ub':
        assigned_csr, _ = _assign_sorted((n_users, n_items), topk, C, argsort_ij)
        return assigned_csr

    # --- 'lb': phase 1 assigns a reduced "minimum" budget on both sides. ---
    if np.isscalar(C):
        # Shrink both budgets so neither side is asked for more than the
        # smaller of the two totals (users * topk vs. items * C).
        min_total_recs = min(n_users * topk, C * n_items)
        min_k = min(topk, np.ceil(min_total_recs / n_users).astype(int))
        min_C = min(C,    np.ceil(min_total_recs / n_items).astype(int))
    else:
        assert np.broadcast_to(topk, (n_users,)).sum() > C.sum(), \
            "relative only on item_rec"
        # NOTE(review): this uses round() while the scalar branch uses ceil(),
        # so min_k * n_users may fall short of C.sum(); confirm intended.
        min_k = np.round(C.sum() / n_users).astype(int)
        min_C = C
    min_csr, blocked = _assign_sorted((n_users, n_items), min_k, min_C, argsort_ij)

    # --- Phase 2: top off whichever side still has unused budget. ---
    if topk > min_k:
        # Users still have room: fill each user up to topk. The column budget
        # is n_users, presumably meaning "no effective per-item cap".
        k_vec = topk - np.ravel(min_csr.sum(axis=1))
        c_vec = n_users
    else:
        # Users are full: fill each item up to C. The row budget is n_items,
        # presumably meaning "no effective per-user cap".
        k_vec = n_items
        c_vec = C - np.ravel(min_csr.sum(axis=0))

    # `blocked` from phase 1 is passed through, presumably so phase 2 does not
    # reuse pairs that phase 1 already handled.
    top_off, _ = _assign_sorted((n_users, n_items), k_vec, c_vec, argsort_ij, blocked)
    assigned_csr = min_csr + top_off

    return assigned_csr