in main_swav.py [0:0]
def distributed_sinkhorn(out):
    """Turn prototype scores into soft assignments via Sinkhorn-Knopp.

    Runs the iterative row/column normalization on the transposed score
    matrix (prototypes x samples, matching the paper's notation), with
    sums all-reduced so every distributed worker normalizes against the
    global batch. Returns the assignment matrix transposed back to
    (samples x prototypes).

    NOTE(review): `args` and `dist` (torch.distributed) are module-level
    names defined elsewhere in this file.
    """
    # K x B_local matrix of exponentiated, temperature-scaled scores.
    scores = torch.exp(out / args.epsilon).t()
    num_prototypes, local_batch = scores.shape
    # Number of samples to assign across the whole distributed batch.
    total_samples = local_batch * args.world_size

    # Normalize so the matrix sums to 1 globally (across all workers).
    grand_total = torch.sum(scores)
    dist.all_reduce(grand_total)
    scores /= grand_total

    for _ in range(args.sinkhorn_iterations):
        # Row step: each prototype's total weight becomes 1/K.
        # Row sums are all-reduced since samples are sharded over workers.
        row_totals = torch.sum(scores, dim=1, keepdim=True)
        dist.all_reduce(row_totals)
        scores /= row_totals
        scores /= num_prototypes

        # Column step: each sample's total weight becomes 1/B.
        scores /= torch.sum(scores, dim=0, keepdim=True)
        scores /= total_samples

    # Rescale so each column sums to 1, making it a proper assignment.
    scores *= total_samples
    return scores.t()