def mine_analogies()

in analogy_generation.py [0:0]


def mine_analogies(centroids):
    """Mine cross-class analogies between pairs of cluster centroids.

    For every ordered pair of clusters (a, b) inside one class, find the
    ordered pair (c, d) inside any *other* class whose difference vector has
    the largest dot product with the difference vector of (a, b).

    Args:
        centroids: sequence with one entry per class. Each entry must expose
            ``.shape[0]`` (the number of clusters) and be accepted by
            ``get_difference_vectors``.
            NOTE(review): presumably each entry is an
            ``(n_clusters, feature_dim)`` array and every class has the same
            number of clusters -- confirm against the caller.

    Returns:
        A tuple ``(analogies, analogy_scores)``:

        * ``analogies`` -- int array of shape ``(K, 4)``. Columns 0 and 1 are
          the source pair, columns 2 and 3 the best matching pair from
          another class. Every id is encoded as
          ``class_index * n_clusters + cluster_index``.
        * ``analogy_scores`` -- float array of shape ``(K,)`` holding the
          winning dot product for each row.

        Rows are dropped when the score is not strictly positive or when
        either pair consists of the same cluster twice. Empty ``centroids``
        yields empty arrays of shape ``(0, 4)`` and ``(0,)``.

    The matrix products run on the GPU when CUDA is available and fall back
    to the CPU otherwise.
    """
    n_classes = len(centroids)
    if n_classes == 0:
        # Nothing to compare; return empty results with the usual dtypes.
        return np.zeros((0, 4), dtype=int), np.zeros(0)

    n_clusters = centroids[0].shape[0]

    # Pick the device once so the function also works on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    analogies = np.zeros((n_clusters * n_clusters * n_classes, 4), dtype=int)
    analogy_scores = np.zeros(analogies.shape[0])
    start = 0

    # Map a flat pair index back to its (first, second) cluster indices.
    # NOTE(review): assumes get_difference_vectors returns n_clusters**2 rows
    # in this same row-major pair order -- confirm against that helper.
    I, J = np.unravel_index(np.arange(n_clusters**2), (n_clusters, n_clusters))

    # for every class
    for i, c_i in enumerate(centroids):

        # get normalized difference vectors between cluster centers
        diff_i = get_difference_vectors(c_i)
        diff_i_t = torch.Tensor(diff_i).to(device)

        # Best score seen so far for each source pair, and the pair that
        # produced it. A score of 0 means "nothing found yet"; such rows are
        # pruned at the end.
        bestdots = np.zeros(diff_i.shape[0])
        bestdotidx = np.zeros((diff_i.shape[0], 2), dtype=int)

        # for every other class
        for j, c_j in enumerate(centroids):
            if i == j:
                continue
            print(i, j)

            # Recomputed on every pass rather than cached for all classes:
            # each matrix has one row per cluster pair, so caching them all
            # could take a lot of memory.
            diff_j = get_difference_vectors(c_j)
            diff_j = torch.Tensor(diff_j).to(device)

            # dot product of every source pair with every candidate pair,
            # keeping the best candidate per source pair
            dots = diff_i_t.mm(diff_j.transpose(0, 1))
            maxdots, argmaxdots = dots.max(1)
            maxdots = maxdots.cpu().numpy().reshape(-1)
            argmaxdots = argmaxdots.cpu().numpy().reshape(-1)

            # Strict comparison: on ties the earliest class keeps the match.
            betteridx = maxdots > bestdots
            bestdots[betteridx] = maxdots[betteridx]
            bestdotidx[betteridx, 0] = j * n_clusters + I[argmaxdots[betteridx]]
            bestdotidx[betteridx, 1] = j * n_clusters + J[argmaxdots[betteridx]]

        # store discovered analogies
        stop = start + diff_i.shape[0]
        analogies[start:stop, 0] = i * n_clusters + I
        analogies[start:stop, 1] = i * n_clusters + J
        analogies[start:stop, 2:] = bestdotidx
        analogy_scores[start:stop] = bestdots
        start = stop

    # prune away trivial analogies: no positive match, or a degenerate pair
    good_analogies = (
        (analogy_scores > 0)
        & (analogies[:, 0] != analogies[:, 1])
        & (analogies[:, 2] != analogies[:, 3])
    )
    return analogies[good_analogies, :], analogy_scores[good_analogies]