def train_analogy_regressor()

in analogy_generation.py


import numpy as np
import torch
import torch.nn as nn


def train_analogy_regressor(analogies, centroids, base_classes, trained_classifier, lr=0.1, wt=10, niter=120000, step_after=40000, batchsize=128, momentum=0.9, wd=0.0001):
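    """Train a regressor that predicts B from (A, C, D) for analogies A:B :: C:D.

    analogies: (N, 4) array of indices into the concatenated centroids,
        one row of (A, B, C, D) indices per analogy.
    centroids: list with one (num_clusters_per_class, featdim) array of
        cluster centroids per base class.
    base_classes: class labels corresponding to the entries of centroids.
    trained_classifier: pre-trained classifier used for the classification
        loss on the predicted B (its parameters are not updated here).
    Returns a dict with the trained model state and the centroid metadata.
    """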
    # pre-permute analogies
    permuted_analogies = analogies[np.random.permutation(analogies.shape[0])]

    # create model and init
    featdim = centroids[0].shape[1]
    model = AnalogyRegressor(featdim)
    model = model.cuda()
    trained_classifier = trained_classifier.cuda()
    # note that dampening is set equal to momentum here
    optimizer = torch.optim.SGD(model.parameters(), lr, momentum=momentum, weight_decay=wd, dampening=momentum)
    loss_1 = nn.CrossEntropyLoss().cuda()  # classification loss on the predicted B
    loss_2 = nn.MSELoss().cuda()           # regression loss between predicted and true B


    num_clusters_per_class = centroids[0].shape[0]
    centroid_labels = (np.array(base_classes).reshape((-1,1)) * np.ones((1, num_clusters_per_class))).reshape(-1)
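    # e.g. base_classes=[3, 7] with 2 clusters per class -> centroid_labels = [3, 3, 7, 7]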
    concatenated_centroids = np.concatenate(centroids, axis=0)


    start = 0
    avg_loss_1 = avg_loss_2 = count = 0.0
    for i in range(niter):
        # get current batch of analogies
        stop = min(start+batchsize, permuted_analogies.shape[0])
        to_train = permuted_analogies[start:stop,:]
        optimizer.zero_grad()

        # analogy is A:B :: C:D, goal is to predict B from A, C, D
        # Y is the class label of B (and A)
        A = concatenated_centroids[to_train[:,0]]
        B = concatenated_centroids[to_train[:,1]]
        C = concatenated_centroids[to_train[:,2]]
        D = concatenated_centroids[to_train[:,3]]
        Y = centroid_labels[to_train[:,1]]

        # torch.autograd.Variable is deprecated; build CUDA tensors directly
        A = torch.from_numpy(A).float().cuda()
        B = torch.from_numpy(B).float().cuda()
        C = torch.from_numpy(C).float().cuda()
        D = torch.from_numpy(D).float().cuda()
        Y = torch.from_numpy(Y.astype(np.int64)).cuda()

        Bhat = model(A,C,D)

        lossval_2 = loss_2(Bhat, B) # simple mean squared error loss

        # classification loss
        predicted_classprobs = trained_classifier(Bhat)
        lossval_1 = loss_1(predicted_classprobs, Y)
        loss = lossval_1 + wt * lossval_2  # weighted sum: classification + wt * regression

        loss.backward()
        optimizer.step()

        avg_loss_1 = avg_loss_1 + lossval_1.item()  # .item() replaces the deprecated .data[0]
        avg_loss_2 = avg_loss_2 + lossval_2.item()
        count = count + 1.0


        if i % 100 == 0:
            print('{:d} : cls loss {:f}, mse loss {:f} ({:f} batches)'.format(i, avg_loss_1/count, avg_loss_2/count, count))
            avg_loss_1 = avg_loss_2 = count = 0.0

        if (i+1) % step_after == 0:
            lr = lr / 10.0
            print('Dropping learning rate to {:f}'.format(lr))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        start = stop
        if start == permuted_analogies.shape[0]:
            start = 0  # wrap around and reuse the shuffled analogy list

    return dict(model_state=model.state_dict(), concatenated_centroids=torch.Tensor(concatenated_centroids),
            num_base_classes=len(centroids), num_clusters_per_class=num_clusters_per_class)
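

A minimal usage sketch. It assumes the AnalogyRegressor module defined elsewhere in analogy_generation.py and a CUDA device (the function moves everything to the GPU); the shapes, class labels, random analogy indices, and the linear stand-in classifier below are illustrative only, not the real pipeline.

import numpy as np
import torch
import torch.nn as nn

featdim, num_clusters_per_class = 512, 10
base_classes = [0, 1, 2, 3]

# one (num_clusters_per_class, featdim) array of cluster centroids per base class
centroids = [np.random.randn(num_clusters_per_class, featdim).astype(np.float32)
             for _ in base_classes]

# each row holds four indices (A, B, C, D) into the concatenated centroids;
# random indices here, purely to exercise the interface
num_centroids = len(base_classes) * num_clusters_per_class
analogies = np.random.randint(0, num_centroids, size=(10000, 4))

# stand-in for a classifier pre-trained on the base classes
classifier = nn.Linear(featdim, len(base_classes))

result = train_analogy_regressor(analogies, centroids, base_classes, classifier,
                                 niter=1000, step_after=400)
torch.save(result, 'analogy_regressor.pth')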