in data_measurements/embeddings/embeddings.py [0:0]
def prepare_merges(embeddings, batch_size=1000, approx_neighbors=1000, low_thres=0.5):
    """
    Prepares an initial list of merges for hierarchical clustering.

    For every point, the (at most) `approx_neighbors` highest-scoring points
    with a *smaller* index are retrieved, and a merge is proposed for every
    such pair whose dot product is greater than `low_thres`.
    Note that if a point has more than `approx_neighbors` neighbors above
    `low_thres`, this approach will miss some of those merges; a warning is
    printed for each batch in which that may have happened.

    NOTE(review): scores are treated as cosine similarities and -1 is used as
    the "excluded" score, so embeddings are presumably L2-normalized - confirm
    against the caller.

    Args:
        embeddings (torch.Tensor): Tensor of sentence embeddings - dimension NxD
        batch_size (int): compute nearest neighbors of `batch_size` points at a time
        approx_neighbors (int): only keep `approx_neighbors` nearest neighbors of a
            point (clamped to N when fewer than `approx_neighbors` points exist)
        low_thres (float): only return merges where the dot product is greater
            than `low_thres`
    Returns:
        torch.LongTensor: proposed merges ([i, j] with i>j) - dimension: Mx2
        torch.Tensor: merge scores - dimension M
    """
    n_points = embeddings.shape[0]
    device = embeddings.device
    # topk raises when k exceeds the number of candidates, so never ask for
    # more neighbors than there are points.
    n_neighbors = min(approx_neighbors, n_points)
    if n_neighbors < 1:
        # No points (or no neighbors requested): there is nothing to merge.
        return (
            torch.empty((0, 2), dtype=torch.long, device=device),
            torch.empty(0, dtype=embeddings.dtype, device=device),
        )

    point_ids = torch.arange(n_points, device=device)
    # Per-batch results are collected and concatenated once at the end to
    # avoid re-copying the growing accumulator on every iteration.
    val_chunks = []
    idx_chunks = []
    n_batches = math.ceil(n_points / batch_size)
    for b in tqdm(range(n_batches)):
        start = b * batch_size
        # TODO: batch across second dimension
        cos_scores = torch.mm(embeddings[start : start + batch_size], embeddings.t())
        # Only keep pairs (i, j) with j < i: every column at or after the
        # row's own global index is pushed to -1. This drops self-matches and
        # guarantees each unordered pair is proposed at most once.
        row_ids = point_ids[start : start + cos_scores.shape[0]]
        cos_scores.masked_fill_(point_ids[None, :] >= row_ids[:, None], -1)
        top_val_large, top_idx_large = cos_scores.topk(
            k=n_neighbors, dim=-1, largest=True
        )
        val_chunks.append(top_val_large)
        idx_chunks.append(top_idx_large)
        # If even the last retained neighbor is above the threshold, some
        # merges above the threshold may have been cut off by the top-k.
        max_neighbor_dist = top_val_large[:, -1].max().item()
        if max_neighbor_dist > low_thres:
            print(
                "WARNING: with the current set of neireast neighbor, "
                f"the farthest is {max_neighbor_dist}"
            )

    top_val_all = torch.cat(val_chunks, dim=0)
    top_idx_all = torch.cat(idx_chunks, dim=0)
    # Row-major coordinates of every retained (point, neighbor-slot) pair;
    # the row coordinate is the index of the first point of the merge.
    rows, slots = (top_val_all > low_thres).nonzero(as_tuple=True)
    all_merges = torch.stack([rows, top_idx_all[rows, slots]], dim=1)
    all_merge_scores = top_val_all[rows, slots]
    return (all_merges, all_merge_scores)