in src/rime/metrics/matching.py [0:0]
def assign_mtch(score_mat, topk, C,
                constraint_type='ub', argsort_ij=None, device="cpu"):
    """Assign (user, item) pairs under per-user and per-item constraints.

    Pairs are taken from ``argsort_ij`` (the output of ``_argsort``;
    presumably user-item pairs ordered by descending score -- confirm
    against ``_argsort``) and filtered by ``_assign_sorted``.

    Parameters
    ----------
    score_mat : matrix-like
        Only ``score_mat.shape`` (unpacked as ``(n_users, n_items)``) is
        read here. The matrix itself is passed to ``_argsort`` when
        ``argsort_ij`` is not supplied.
    topk : int
        Per-user budget. The ``'lb'`` path with a non-scalar ``C``
        broadcasts it to ``(n_users,)``. However, the ``topk > min_k``
        test below only works for a scalar -- TODO confirm whether a
        per-user vector is meant to be supported.
    C : int or array with ``.sum()``
        Per-item constraint. It is forwarded unchanged for ``'ub'``. For
        ``'lb'`` it may be a scalar or a per-item vector.
    constraint_type : {'ub', 'lb'}
        ``'ub'``: a single ``_assign_sorted`` pass with ``topk`` / ``C``.
        ``'lb'``: a first pass with reduced budgets (``min_k`` / ``min_C``),
        followed by a top-off pass. The top-off pass fills the remaining
        user budget or the remaining item quota. It receives the
        ``blocked`` value returned by the first pass (presumably pairs
        that must not be assigned again).
    argsort_ij : optional
        Precomputed ``_argsort(score_mat, device=device)`` result. It is
        computed on demand when ``None``.
    device : str
        Forwarded to ``_argsort`` only.

    Returns
    -------
    The assignment from ``_assign_sorted`` (``'ub'``), or the sum of both
    passes (``'lb'``). It is presumably a scipy sparse matrix of shape
    ``(n_users, n_items)`` -- verify against ``_assign_sorted``.

    Raises
    ------
    ValueError
        If ``constraint_type`` is neither ``'ub'`` nor ``'lb'``.
    AssertionError
        For ``'lb'`` with a non-scalar ``C`` when the total user budget
        does not exceed ``C.sum()``. This check is skipped under
        ``python -O``.
    """
    # Validate before the (possibly expensive) argsort. Previously any
    # unrecognized value silently ran the lower-bound algorithm.
    if constraint_type not in ('ub', 'lb'):
        raise ValueError(
            f"constraint_type must be 'ub' or 'lb', got {constraint_type!r}")

    n_users, n_items = score_mat.shape
    shape = (n_users, n_items)
    if argsort_ij is None:
        argsort_ij = _argsort(score_mat, device=device)

    if constraint_type == 'ub':
        assigned_csr, _ = _assign_sorted(shape, topk, C, argsort_ij)
        return assigned_csr

    # 'lb': choose reduced budgets for a first pass, then top off.
    if np.isscalar(C):
        # The achievable total is bounded by the smaller of the total user
        # budget and the total item quota. Spread that total over users
        # (min_k) and items (min_C), never exceeding the original limits.
        min_total_recs = min(n_users * topk, C * n_items)
        min_k = min(topk, np.ceil(min_total_recs / n_users).astype(int))
        min_C = min(C, np.ceil(min_total_recs / n_items).astype(int))
    else:
        assert np.broadcast_to(topk, (n_users,)).sum() > C.sum(), \
            "relative only on item_rec"
        # Per-user share of the total item quota, rounded to an integer.
        min_k = np.round(C.sum() / n_users).astype(int)
        min_C = C
    min_csr, blocked = _assign_sorted(shape, min_k, min_C, argsort_ij)

    # NOTE(review): assumes scalar topk; an array here would raise
    # "truth value of an array is ambiguous".
    if topk > min_k:
        # Users still have budget: top each user up to topk. n_users is
        # passed as the item limit (presumably no effective item cap).
        k_vec = topk - np.ravel(min_csr.sum(axis=1))
        c_vec = n_users
    else:
        # User budgets are exhausted: top each item up to C. n_items is
        # passed as the user limit (presumably no effective user cap).
        k_vec = n_items
        c_vec = C - np.ravel(min_csr.sum(axis=0))
    top_off, _ = _assign_sorted(shape, k_vec, c_vec, argsort_ij, blocked)
    return min_csr + top_off