# in analogy_generation.py [0:0]
def train_classifier(filehandle, base_classes, cachefile, networkfile, total_num_classes = 1000, lr=0.1, wd=0.0001, momentum=0.9, batchsize=1000, niter=10000):
    """Return a linear classifier over stored features, loading or training it.

    The classifier is obtained from the first source that applies:

    1. ``cachefile`` exists: its contents are loaded as the state dict.
    2. ``networkfile`` exists: the ``module.classifier.weight`` (and, when
       present, ``module.classifier.bias``) entries of its ``'state'`` dict
       are copied into the layer. Without a bias entry the layer is built
       with ``bias=False``.
    3. Otherwise a fresh layer is trained with SGD on the examples whose
       label is in ``base_classes`` and its state dict is saved to
       ``cachefile``.

    Args:
        filehandle: mapping with ``'all_labels'`` and ``'all_feats'`` entries.
            NOTE(review): the ``[...]`` read and the sorted, duplicate-free
            batch indices suggest an HDF5 (h5py) file -- confirm with callers.
        base_classes: class ids whose examples are used for training.
        cachefile: path the trained state dict is read from / written to.
        networkfile: path of a checkpoint holding a ``'state'`` dict.
        total_num_classes: number of outputs of the linear layer.
        lr: SGD learning rate.
        wd: SGD weight decay.
        momentum: SGD momentum.
        batchsize: examples drawn (without replacement) per iteration; capped
            at the number of available base-class examples.
        niter: number of SGD iterations.

    Returns:
        The ``nn.Linear`` model, on the GPU when CUDA is available and on the
        CPU otherwise.

    Raises:
        ValueError: if training is required but no example belongs to
            ``base_classes``.
    """
    use_cuda = torch.cuda.is_available()

    def on_device(obj):
        # With a GPU this is exactly the original unconditional ``.cuda()``;
        # without one we stay on the CPU instead of crashing.
        return obj.cuda() if use_cuda else obj

    # On a CPU-only host, checkpoints written from a GPU must be remapped to
    # CPU storage while loading.
    load_kwargs = {} if use_cuda else {'map_location': lambda storage, loc: storage}

    all_labels = filehandle['all_labels'][...].astype(int)
    all_feats = filehandle['all_feats']
    feat_dim = all_feats[0].size
    model = on_device(nn.Linear(feat_dim, total_num_classes))

    # NOTE: torch.load unpickles its input; only point cachefile/networkfile
    # at trusted files.
    if os.path.isfile(cachefile):
        model.load_state_dict(torch.load(cachefile, **load_kwargs))
        return model

    if os.path.isfile(networkfile):
        state = torch.load(networkfile, **load_kwargs)['state']
        state_dict = {'weight': state['module.classifier.weight']}
        if 'module.classifier.bias' in state:
            state_dict['bias'] = state['module.classifier.bias']
        else:
            # The checkpoint's classifier has no bias, so the layer must not
            # have one either or load_state_dict would report a missing key.
            model = on_device(nn.Linear(feat_dim, total_num_classes, bias=False))
        model.load_state_dict(state_dict)
        return model

    # No pre-existing classifier: train one on the base-class examples.
    base_class_ids = np.where(np.in1d(all_labels, base_classes))[0]
    if niter > 0 and base_class_ids.size == 0:
        raise ValueError('no examples found for the requested base classes')
    # Sampling without replacement cannot draw more items than exist.
    sample_size = min(batchsize, base_class_ids.size)

    loss = on_device(nn.CrossEntropyLoss())
    optimizer = torch.optim.SGD(model.parameters(), lr, momentum=momentum, weight_decay=wd, dampening=0)
    for i in range(niter):
        optimizer.zero_grad()
        # Indices are kept sorted and unique; presumably required by the
        # fancy indexing of the underlying feature store -- TODO confirm.
        idx = np.sort(np.random.choice(base_class_ids, sample_size, replace=False))
        feats = on_device(Variable(torch.Tensor(all_feats[idx, :])))
        targets = on_device(Variable(torch.LongTensor(all_labels[idx])))
        loss_val = loss(model(feats), targets)
        loss_val.backward()
        optimizer.step()
        if i % 100 == 0:
            # The loss is a 0-dim tensor on current PyTorch, where ``.data[0]``
            # raises IndexError; fall back to it only on very old releases
            # that lack ``.item()``.
            loss_scalar = loss_val.item() if hasattr(loss_val, 'item') else loss_val.data[0]
            print('Classifier training {:d}: {:f}'.format(i, loss_scalar))
    torch.save(model.state_dict(), cachefile)
    return model