def mips()

in paq/retrievers/retrieve.py [0:0]


def _queries_per_second(n_queries, seconds):
    """Return a throughput figure, guarding against a zero elapsed time."""
    return n_queries / seconds if seconds > 0 else float('inf')


def mips(index, queries, top_k, n_queries_to_parallelize=256):
    """Search ``index`` for the ``top_k`` results of every query, in batches.

    Queries are sliced into batches of ``n_queries_to_parallelize`` and each
    batch is passed (after ``.float()``) to the search function returned by
    ``_get_mips_function(index)``. Per-batch results are concatenated along
    the first axis, and progress is logged after every batch.

    Args:
        index: The search index; passed to ``_get_mips_function`` and to the
            search function it returns.
        queries: Sliceable collection of query vectors supporting ``len()``,
            whose slices provide ``.float()`` (presumably a torch tensor —
            confirm against callers).
        top_k: Number of results requested per query.
        n_queries_to_parallelize: Number of queries sent per search call.

    Returns:
        ``(all_top_indices, all_top_scores)``: the per-batch ``top_indices``
        and ``scores`` joined with ``np.concatenate`` (a single batch is
        returned unchanged). For empty ``queries`` both are empty arrays of
        shape ``(0, top_k)``.

    Raises:
        ValueError: if ``n_queries_to_parallelize`` is not positive.
    """
    if n_queries_to_parallelize <= 0:
        raise ValueError(
            f'n_queries_to_parallelize must be positive, got {n_queries_to_parallelize}')

    n_queries = len(queries)
    if n_queries == 0:
        # No batches to search; return empty results instead of failing later.
        # NOTE(review): dtypes assumed (int64 ids, float32 scores) — confirm
        # against what the search function returns.
        return (np.empty((0, top_k), dtype=np.int64),
                np.empty((0, top_k), dtype=np.float32))

    t = time.time()
    _mips = _get_mips_function(index)

    # Collect batches and concatenate once at the end; concatenating inside
    # the loop would re-copy all previous results on every iteration.
    index_batches = []
    score_batches = []
    n_searched = 0
    for start in range(0, n_queries, n_queries_to_parallelize):
        query_batch = queries[start:start + n_queries_to_parallelize].float()
        scores, top_indices = _mips(index, query_batch, top_k)
        index_batches.append(top_indices)
        score_batches.append(scores)
        n_searched += len(top_indices)

        delta = time.time() - t
        logger.info(
            f'{n_searched}/ {n_queries} queries searched in {delta:.4f} '
            f'seconds ({_queries_per_second(n_searched, delta)} per second)')

    # A single batch is returned as produced by the search function.
    if len(index_batches) == 1:
        all_top_indices, all_top_scores = index_batches[0], score_batches[0]
    else:
        all_top_indices = np.concatenate(index_batches)
        all_top_scores = np.concatenate(score_batches)

    # Sanity check: the search must return one row per query.
    assert len(all_top_indices) == n_queries

    delta = time.time() - t
    logger.info(
        f'Index searched in {delta:.4f} seconds '
        f'({_queries_per_second(n_queries, delta)} per second)')
    return all_top_indices, all_top_scores