in main_deepclusterv2.py
import torch
import torch.distributed as dist
import torch.nn as nn

# NOTE: `args` (the parsed command-line options) and the helper `get_indices_sparse`
# are assumed to be defined at module level elsewhere in this file.


def cluster_memory(model, local_memory_index, local_memory_embeddings, size_dataset, nmb_kmeans_iters=10):
    j = 0
    # default to -100 for samples that never receive an assignment
    assignments = -100 * torch.ones(len(args.nmb_prototypes), size_dataset).long()
    with torch.no_grad():
        for i_K, K in enumerate(args.nmb_prototypes):
            # run distributed k-means

            # init centroids with elements from memory bank of rank 0
            centroids = torch.empty(K, args.feat_dim).cuda(non_blocking=True)
            if args.rank == 0:
                random_idx = torch.randperm(len(local_memory_embeddings[j]))[:K]
                assert len(random_idx) >= K, "please reduce the number of centroids"
                centroids = local_memory_embeddings[j][random_idx]
            dist.broadcast(centroids, 0)

            for n_iter in range(nmb_kmeans_iters + 1):

                # E step: assign each local embedding to the centroid with the highest dot product
                dot_products = torch.mm(local_memory_embeddings[j], centroids.t())
                _, local_assignments = dot_products.max(dim=1)

                # finish: the extra last pass only recomputes assignments with the final centroids
                if n_iter == nmb_kmeans_iters:
                    break

                # M step: sum the embeddings of each cluster locally, then all-reduce the
                # per-cluster sums and counts to recompute centroids as global means
                where_helper = get_indices_sparse(local_assignments.cpu().numpy())
                counts = torch.zeros(K).cuda(non_blocking=True).int()
                emb_sums = torch.zeros(K, args.feat_dim).cuda(non_blocking=True)
                for k in range(len(where_helper)):
                    if len(where_helper[k][0]) > 0:
                        emb_sums[k] = torch.sum(
                            local_memory_embeddings[j][where_helper[k][0]],
                            dim=0,
                        )
                        counts[k] = len(where_helper[k][0])
                dist.all_reduce(counts)
                mask = counts > 0
                dist.all_reduce(emb_sums)
                centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1)

                # normalize centroids
                centroids = nn.functional.normalize(centroids, dim=1, p=2)

            # copy the final centroids into the prototype layer of this clustering head
            getattr(model.module.prototypes, "prototypes" + str(i_K)).weight.copy_(centroids)

            # gather the assignments
            assignments_all = torch.empty(args.world_size, local_assignments.size(0),
                                          dtype=local_assignments.dtype, device=local_assignments.device)
            assignments_all = list(assignments_all.unbind(0))
            dist_process = dist.all_gather(assignments_all, local_assignments, async_op=True)
            dist_process.wait()
            assignments_all = torch.cat(assignments_all).cpu()

            # gather the indexes
            indexes_all = torch.empty(args.world_size, local_memory_index.size(0),
                                      dtype=local_memory_index.dtype, device=local_memory_index.device)
            indexes_all = list(indexes_all.unbind(0))
            dist_process = dist.all_gather(indexes_all, local_memory_index, async_op=True)
            dist_process.wait()
            indexes_all = torch.cat(indexes_all).cpu()

            # log assignments
            assignments[i_K][indexes_all] = assignments_all

            # next memory bank to use
            j = (j + 1) % len(args.crops_for_assign)

    return assignments
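
For reference, `get_indices_sparse` (used in the M step above) groups local sample indices by their assigned cluster id. It is not shown in this excerpt; a minimal sketch of such a helper, assuming numpy and scipy are available, could look like this:

import numpy as np
from scipy.sparse import csr_matrix

def get_indices_sparse(data):
    # data: 1-D integer array of cluster ids, one entry per local sample.
    # Returns a list whose k-th element is a tuple of index arrays (np.where style)
    # holding the positions of the samples assigned to cluster k.
    cols = np.arange(data.size)
    M = csr_matrix((cols, (data.ravel(), cols)), shape=(int(data.max()) + 1, data.size))
    return [np.unravel_index(row.data, data.shape) for row in M]

A typical call site (variable names here are illustrative) runs the clustering once per epoch, before training on the resulting pseudo-labels:

assignments = cluster_memory(model, local_memory_index, local_memory_embeddings, len(train_dataset))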