def kmeans()

in torch_kmeans.py [0:0]


def kmeans(x, k, niter=1, batchsize=1000, *, return_loss=False):
    """Cluster the rows of ``x`` into ``k`` groups with batched Lloyd's k-means.

    Centroids are initialised as random Gaussian vectors scaled to unit
    Euclidean norm (they are not drawn from the data). Each iteration assigns
    every row of ``x`` to its nearest centroid, scoring at most ``batchsize``
    rows at a time to bound the size of the ``(k, batchsize)`` score matrix.
    It then moves each centroid to the mean of the rows assigned to it.

    Empty clusters: a centroid that receives no rows in an iteration keeps its
    previous position. A centroid that has not received a single row in any
    iteration so far is re-seeded with a uniformly chosen row of ``x``.

    Randomness comes from NumPy's global ``np.random`` state; call
    ``np.random.seed`` beforehand for reproducible results.

    Args:
        x: Array-like of shape ``(nsamples, ndims)``, one sample per row.
        k: Number of clusters (must be >= 1).
        niter: Number of assignment/update iterations. With ``niter <= 0``
            the random initial centroids are returned unchanged.
        batchsize: Maximum number of rows scored at once (must be >= 1).
        return_loss: Keyword-only. If True, also return the loss of the final
            assignment pass. The default (False) keeps the original return value.

    Returns:
        ``centroids``, a float64 array of shape ``(k, ndims)``. When
        ``return_loss`` is True, returns ``(centroids, loss)`` instead.
        ``loss`` is half the sum of squared distances from each row to its
        nearest centroid, measured against the centroids used in the last
        assignment pass (i.e. *before* the final update). It is ``None`` if
        ``niter <= 0``.

    Raises:
        ValueError: If ``x`` is not 2-D or has no rows, or if ``k`` or
            ``batchsize`` is less than 1.
    """
    x = np.asarray(x)
    if x.ndim != 2:
        raise ValueError(
            "x must be 2-D (nsamples, ndims), got shape {}".format(x.shape))
    nsamples, ndims = x.shape
    if nsamples == 0:
        raise ValueError("x must contain at least one sample")
    if k < 1:
        raise ValueError("k must be >= 1, got {}".format(k))
    if batchsize < 1:
        raise ValueError("batchsize must be >= 1, got {}".format(batchsize))
    batchsize = min(batchsize, nsamples)

    # Half squared norm of every sample; only needed for the loss.
    half_x2 = 0.5 * np.sum(x**2, axis=1)

    centroids = np.random.randn(k, ndims)
    centroids /= np.linalg.norm(centroids, axis=1, keepdims=True)
    # Rows ever assigned to each centroid, accumulated over all iterations.
    totalcounts = np.zeros(k)
    loss = None

    for _ in range(niter):
        # For a fixed sample, argmax_c (c.x - 0.5|c|^2) == argmin_c |x - c|^2,
        # so the nearest centroid comes out of one matrix product per batch.
        half_c2 = 0.5 * np.sum(centroids**2, axis=1, keepdims=True)
        summation = np.zeros((k, ndims))
        counts = np.zeros(k)
        loss = 0.0

        for start in range(0, nsamples, batchsize):
            stop = min(start + batchsize, nsamples)
            batch = x[start:stop]
            m = batch.shape[0]
            cols = np.arange(m)

            scores = np.dot(centroids, batch.T) - half_c2   # shape (k, m)
            labels = np.argmax(scores, axis=0)
            best = scores[labels, cols]
            # Per sample: 0.5|x|^2 - (c.x - 0.5|c|^2) == 0.5|x - c|^2.
            loss += float(np.sum(half_x2[start:stop] - best))

            # One-hot assignment matrix: row c of (onehot . batch) is the sum
            # of the batch rows labelled c.
            onehot = np.zeros((k, m))
            onehot[labels, cols] = 1
            summation += np.dot(onehot, batch)
            counts += np.bincount(labels, minlength=k)

        # Move only centroids that received rows; empty ones stay put.
        nonempty = counts > 0
        centroids[nonempty] = summation[nonempty] / counts[nonempty, None]

        # Re-seed centroids that have never been assigned any row.
        totalcounts += counts
        dead = np.flatnonzero(totalcounts == 0)
        if dead.size:
            centroids[dead] = x[np.random.choice(nsamples, size=dead.size)]

    if return_loss:
        return centroids, loss
    return centroids