def eval_reconstruction()

in hype/graph.py [0:0]


def eval_reconstruction(adj, model, workers=1, progress=False):
    '''
    Reconstruction evaluation.  For each object, rank its neighbors by distance

    Every key of ``adj`` is evaluated by ``reconstruction_worker``.  With
    ``workers > 1`` the keys are split into contiguous chunks that are
    processed concurrently on a thread pool, and the per-chunk statistics
    are summed element-wise before the final ratios are taken.

    Args:
        adj (dict[int, set[int]]): Adjacency list mapping objects to its neighbors
        model: Embedding model handed unchanged to ``reconstruction_worker``
            (presumably exposes the embedding table / distance function —
            confirm against that helper)
        workers (int): number of threads to use; capped at the number of
            objects so that no thread is given an empty chunk
        progress (bool): forwarded to ``reconstruction_worker`` in the
            single-worker path only; ignored when ``workers > 1``

    Returns:
        tuple[float, float]: ``stats[0] / stats[1]`` and
        ``stats[2] / stats[3]``, where ``stats`` is the (summed) 4-element
        result of ``reconstruction_worker``.  NOTE(review): presumably
        (mean rank, mean average precision) — confirm against the worker.
    '''
    objects = np.array(list(adj.keys()))
    if workers > 1:
        # More threads than objects would only produce empty chunks and idle
        # threads; keep at least one chunk so np.array_split stays valid (and
        # the aggregation path is unchanged) when ``adj`` is empty.
        n_chunks = max(1, min(workers, len(objects)))
        f = partial(reconstruction_worker, adj, model)
        with ThreadPool(n_chunks) as pool:
            results = pool.map(f, np.array_split(objects, n_chunks))
        # Element-wise sum of the per-chunk statistic tuples.
        results = np.array(results).sum(axis=0).astype(float)
    else:
        results = reconstruction_worker(adj, model, objects, progress)
    return float(results[0]) / results[1], float(results[2]) / results[3]