def prepare_merges()

in data_measurements/embeddings/embeddings.py [0:0]


def prepare_merges(embeddings, batch_size=1000, approx_neighbors=1000, low_thres=0.5):
    """
    Prepares an initial list of merges for hierarchical
    clustering. First compute the `approx_neighbors` nearest neighbors,
    then propose a merge for any two points that are closer than `low_thres`

    Note that if a point has more than `approx_neighbors` neighbors
    closer than `low_thres`, this approach will miss some of those merges
    (a warning is printed for each batch where this may have happened).

    Each point is only compared against points with a strictly smaller index,
    so every proposed pair [i, j] satisfies i > j and no pair is listed twice.
    If there are fewer than `approx_neighbors` points, all of them are
    considered as neighbors.

    Args:
        embeddings (torch.Tensor): Tensor of sentence embeddings - dimension NxD
        batch_size (int): compute nearest neighbors of `batch_size` points at a time
        approx_neighbors (int): only keep `approx_neighbors` nearest neighbors of a point
        low_thres (float): only return merges where the dot product is greater than `low_thres`
    Returns:
        torch.LongTensor: proposed merges ([i, j] with i>j) - dimension: Mx2
        torch.Tensor: merge scores - dimension M
    """
    n_points = embeddings.shape[0]
    # topk raises if k exceeds the number of candidates, so never ask for more
    # neighbors than there are points.
    n_neighbors = min(approx_neighbors, n_points)
    if n_neighbors == 0:
        # No points (or no neighbors requested): nothing can be merged.
        return (
            torch.empty((0, 2), dtype=torch.long, device=embeddings.device),
            embeddings.new_empty((0,)),
        )

    col_ids = torch.arange(n_points, device=embeddings.device)
    # Collect per-batch results and concatenate once at the end, instead of
    # re-copying the growing accumulator on every iteration.
    val_chunks = []
    idx_chunks = []
    n_batches = math.ceil(n_points / batch_size)
    for b in tqdm(range(n_batches)):
        # TODO: batch across second dimension
        start = b * batch_size
        cos_scores = torch.mm(embeddings[start : start + batch_size], embeddings.t())
        # Mask out the point itself and every later point (column >= global row
        # index) so that only pairs with i > j survive.
        # NOTE(review): -1 is only a safe "lowest" value if the embeddings are
        # normalized (dot products in [-1, 1]) - confirm against callers.
        row_ids = col_ids[start : start + cos_scores.shape[0]]
        cos_scores.masked_fill_(col_ids[None, :] >= row_ids[:, None], -1)
        top_val_large, top_idx_large = cos_scores.topk(
            k=n_neighbors, dim=-1, largest=True
        )
        val_chunks.append(top_val_large)
        idx_chunks.append(top_idx_large)
        # If even the farthest retained neighbor of some point is above the
        # threshold, that point may have more qualifying neighbors than we kept.
        max_neighbor_dist = top_val_large[:, -1].max().item()
        if max_neighbor_dist > low_thres:
            print(
                f"WARNING: with the current set of neireast neighbor, the farthest is {max_neighbor_dist}"
            )

    top_val_all = torch.cat(val_chunks, dim=0)
    top_idx_all = torch.cat(idx_chunks, dim=0)

    keep = top_val_all > low_thres
    # Row index of every retained entry, in the same row-major order that
    # boolean-mask indexing uses for `top_idx_all[keep]`. Deriving it from the
    # mask itself removes the need for an index matrix whose width must match
    # `approx_neighbors`.
    row_idx = torch.nonzero(keep, as_tuple=True)[0]
    all_merges = torch.stack([row_idx, top_idx_all[keep]], dim=1)
    all_merge_scores = top_val_all[keep]

    return (all_merges, all_merge_scores)