in matching_network.py
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable


def train_matching_network(model, file_handle, base_classes, m=389, n=10, initlr=0.1,
                           momentum=0.9, wd=0.001, step_after=20000, niter=60000):
    model = model.cuda()
    lr = initlr
    optimizer = torch.optim.SGD(model.parameters(), lr, momentum=momentum,
                                dampening=momentum, weight_decay=wd)
    loss_fn = nn.NLLLoss()
    all_labels = file_handle['all_labels'][...]
    total_loss = 0.0
    loss_count = 0.0
    for it in range(niter):
        optimizer.zero_grad()

        # Sample an episode: m distinct base classes, each with 1..n support
        # ("train") examples and a single query ("test") example.
        rand_labels = np.random.choice(base_classes, m, replace=False)
        num = np.random.choice(n, m) + 1
        batchsize = int(np.sum(num))
        train_feats = torch.zeros(batchsize, model.feat_dim)
        train_Y = torch.zeros(batchsize, m)
        test_feats = torch.zeros(m, model.feat_dim)
        test_labels = torch.arange(m)
        count = 0
        for j in range(m):
            idx = np.where(all_labels == rand_labels[j])[0]
            train_idx = np.sort(np.random.choice(idx, num[j], replace=False))
            test_idx = np.random.choice(idx)
            # Support features and one-hot episode labels for class j.
            F_tmp = file_handle['all_feats'][list(train_idx)]
            train_feats[count:count + num[j]] = torch.Tensor(F_tmp)
            train_Y[count:count + num[j], j] = 1
            # Query feature for class j.
            F_tmp = file_handle['all_feats'][test_idx]
            test_feats[j] = torch.Tensor(F_tmp)
            count = count + num[j]

        train_feats = Variable(train_feats.cuda())
        train_Y = Variable(train_Y.cuda())
        test_feats = Variable(test_feats.cuda())
        test_labels = Variable(test_labels.long().cuda())

        # The model returns log-probabilities of each query over the m episode classes.
        logprob = model(test_feats, train_feats, train_Y)
        loss = loss_fn(logprob, test_labels)
        loss.backward()
        optimizer.step()

        # Decay the learning rate by 10x every `step_after` iterations.
        if (it + 1) % step_after == 0:
            lr = lr / 10
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        total_loss = total_loss + loss.item()
        loss_count = loss_count + 1
        if (it + 1) % 1 == 0:
            print('{:d}:{:f}'.format(it, total_loss / loss_count))
            total_loss = 0.0
            loss_count = 0.0
    return model
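

# --- Usage sketch (not part of the original file) ---
# A minimal way to drive train_matching_network, assuming the HDF5 layout read
# above ('all_feats' as an N x D float matrix, 'all_labels' as N integer class
# ids) and a CUDA device, since the loop moves everything onto the GPU.
# SimpleMatchingHead is an illustrative stand-in for the repo's actual model:
# the training loop only assumes `feat_dim` and a
# forward(test_feats, train_feats, train_Y) that returns log-probabilities.
# The head, the file path, and the base-class ids below are all hypothetical.
class SimpleMatchingHead(nn.Module):
    def __init__(self, feat_dim):
        super().__init__()
        self.feat_dim = feat_dim
        self.fc = nn.Linear(feat_dim, feat_dim)  # learnable embedding transform

    def forward(self, test_feats, train_feats, train_Y):
        # Cosine-similarity attention of each query over the support set,
        # then a weighted sum of the support one-hot labels.
        q = nn.functional.normalize(self.fc(test_feats), dim=1)
        s = nn.functional.normalize(self.fc(train_feats), dim=1)
        attn = nn.functional.softmax(q.mm(s.t()), dim=1)   # (m, batchsize)
        probs = attn.mm(train_Y)                           # (m, m) class probabilities
        return torch.log(probs + 1e-10)                    # log-probs for nn.NLLLoss


if __name__ == '__main__':
    import h5py

    # 'features.hdf5' and the 389 base-class ids are placeholders.
    with h5py.File('features.hdf5', 'r') as f:
        feat_dim = f['all_feats'].shape[1]
        model = SimpleMatchingHead(feat_dim)
        base_classes = list(range(389))
        model = train_matching_network(model, f, base_classes, m=389, n=10)
        torch.save(model.state_dict(), 'matching_network.pth')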