def train_classifier()

in analogy_generation.py [0:0]


def _cuda_if_available(obj):
    """Return ``obj.cuda()`` when CUDA is usable, otherwise ``obj`` unchanged."""
    if torch.cuda.is_available():
        return obj.cuda()
    return obj


def _load_checkpoint_on_cpu(path):
    """Load a ``torch.save`` file onto the CPU, wherever it was written.

    ``load_state_dict`` later copies the values onto the model's own device,
    so a checkpoint saved on a GPU can still be read on a CPU-only host.

    NOTE(review): ``torch.load`` unpickles the file; only point this at
    checkpoints that come from a trusted source.
    """
    return torch.load(path, map_location=lambda storage, loc: storage)


def train_classifier(filehandle, base_classes, cachefile, networkfile, total_num_classes = 1000, lr=0.1, wd=0.0001, momentum=0.9, batchsize=1000, niter=10000):
    """Return a linear classifier over the stored features.

    The classifier comes from the first source that is available:

    1. ``cachefile`` -- a state dict written by an earlier call of this
       function.
    2. ``networkfile`` -- a checkpoint whose ``'state'`` entry holds
       ``'module.classifier.weight'`` and, optionally,
       ``'module.classifier.bias'``. Without the bias entry a bias-free
       linear layer is returned.
    3. Otherwise a fresh ``nn.Linear`` is trained with SGD and cross-entropy
       on the samples whose label is in ``base_classes``, and its state dict
       is saved to ``cachefile``.

    Args:
        filehandle: Mapping with ``'all_labels'`` and ``'all_feats'`` entries
            (presumably an open HDF5 file -- confirm against the caller).
        base_classes: Labels whose samples are used for training.
        cachefile: Path the trained state dict is read from / written to.
        networkfile: Path of a network checkpoint to take the classifier from.
        total_num_classes: Output dimension of the linear layer.
        lr: SGD learning rate.
        wd: SGD weight decay.
        momentum: SGD momentum.
        batchsize: Samples drawn (without replacement) per iteration; clamped
            to the number of available base-class samples.
        niter: Number of SGD iterations.

    Returns:
        The ``nn.Linear`` model, on the GPU when CUDA is available and on the
        CPU otherwise.

    Raises:
        ValueError: If training is required but no sample belongs to
            ``base_classes``.
    """
    all_labels = filehandle['all_labels'][...]
    all_labels = all_labels.astype(int)
    all_feats = filehandle['all_feats']
    feat_dim = all_feats[0].size
    model = _cuda_if_available(nn.Linear(feat_dim, total_num_classes))

    # 1. Re-use a classifier cached by a previous call.
    if os.path.isfile(cachefile):
        model.load_state_dict(_load_checkpoint_on_cpu(cachefile))
        return model

    # 2. Lift the classifier layer out of a full network checkpoint.
    if os.path.isfile(networkfile):
        net_state = _load_checkpoint_on_cpu(networkfile)['state']
        if 'module.classifier.bias' in net_state:
            state_dict = {'weight': net_state['module.classifier.weight'],
                          'bias': net_state['module.classifier.bias']}
        else:
            # The checkpointed classifier has no bias, so the layer must not either.
            model = _cuda_if_available(nn.Linear(feat_dim, total_num_classes, bias=False))
            state_dict = {'weight': net_state['module.classifier.weight']}
        model.load_state_dict(state_dict)
        return model

    # 3. Train from scratch on the base classes and cache the result.
    base_class_ids = np.where(np.isin(all_labels, base_classes))[0]
    if base_class_ids.size == 0:
        raise ValueError('no sample in filehandle has a label in base_classes')
    # Sampling without replacement cannot draw more than the population.
    effective_batchsize = min(batchsize, base_class_ids.size)

    loss = _cuda_if_available(nn.CrossEntropyLoss())
    optimizer = torch.optim.SGD(model.parameters(), lr, momentum=momentum, weight_decay=wd, dampening=0)
    for i in range(niter):
        optimizer.zero_grad()
        # Indices are sorted before the fancy-index read; presumably the
        # feature store (e.g. an HDF5 dataset) requires increasing indices.
        idx = np.sort(np.random.choice(base_class_ids, effective_batchsize, replace=False))
        F = _cuda_if_available(Variable(torch.Tensor(all_feats[idx, :])))
        L = _cuda_if_available(Variable(torch.LongTensor(all_labels[idx])))
        loss_val = loss(model(F), L)
        loss_val.backward()
        optimizer.step()
        if i % 100 == 0:
            # A 0-dim loss tensor cannot be indexed with [0] on current
            # PyTorch; .item() is the supported accessor. The legacy form is
            # kept only for builds that predate .item().
            if hasattr(loss_val, 'item'):
                loss_scalar = loss_val.item()
            else:
                loss_scalar = loss_val.data[0]
            print('Classifier training {:d}: {:f}'.format(i, loss_scalar))
    torch.save(model.state_dict(), cachefile)

    return model