def train_matching_network()

in matching_network.py [0:0]


def _sample_episode(all_feats, all_labels, class_indices, base_classes, m, n, feat_dim):
    """Draw one training episode of ``m`` distinct base classes.

    Each sampled class gets between 1 and ``n`` support rows plus one query row.
    The query is disjoint from the supports whenever the class has enough samples.

    Returns CPU tensors ``(train_feats, train_Y, test_feats, test_labels)``:
    ``train_Y`` one-hot encodes each support row's episode class (0..m-1), and
    ``test_labels[j] == j`` is the episode class of query row ``j``.
    ``class_indices`` memoizes the row indices of each label and is filled in place.
    """
    rand_labels = np.random.choice(base_classes, m, replace=False)
    num = np.random.choice(n, m) + 1  # support-set size per class, uniform in [1, n]
    batchsize = int(np.sum(num))

    train_feats = torch.zeros(batchsize, feat_dim)
    train_Y = torch.zeros(batchsize, m)
    test_feats = torch.zeros(m, feat_dim)
    test_labels = torch.arange(m, dtype=torch.long)

    count = 0
    for j, label in enumerate(rand_labels):
        idx = class_indices.get(label)
        if idx is None:
            # Computed once per label instead of scanning all_labels every iteration.
            idx = np.where(all_labels == label)[0]
            class_indices[label] = idx

        k = int(num[j])
        if len(idx) > k:
            # Draw supports and query together so the query never duplicates a support.
            picked = np.random.choice(idx, k + 1, replace=False)
            train_idx, test_idx = np.sort(picked[:k]), picked[k]
        else:
            # Too few samples for a disjoint query: fall back to the original draw
            # (np.random.choice raises ValueError if the class has fewer than k rows).
            train_idx = np.sort(np.random.choice(idx, k, replace=False))
            test_idx = np.random.choice(idx)

        # Indices are sorted, presumably because h5py fancy indexing requires
        # an increasing index list.
        train_feats[count:count + k] = torch.as_tensor(
            all_feats[list(train_idx)], dtype=torch.float32)
        train_Y[count:count + k, j] = 1
        test_feats[j] = torch.as_tensor(all_feats[test_idx], dtype=torch.float32)
        count += k

    return train_feats, train_Y, test_feats, test_labels


def train_matching_network(model, file_handle, base_classes, m=389, n=10, initlr=0.1,
                           momentum=0.9, wd=0.001, step_after=20000, niter=60000,
                           device='cuda', print_every=1):
    """Episodically train ``model`` on precomputed features and return it.

    Each iteration samples ``m`` distinct classes from ``base_classes``. For each
    class it draws 1..``n`` support rows and one query row from
    ``file_handle['all_feats']``, using the labels in ``file_handle['all_labels']``.
    It then calls ``model(test_feats, train_feats, train_Y)`` and minimizes the NLL
    of each query's own episode class.

    Args:
        model: module with a ``feat_dim`` attribute. Its output is fed to
            ``nn.NLLLoss`` with targets ``0..m-1``, so it is expected to return
            (m, >=m) log-probabilities.
        file_handle: mapping (e.g. an h5py file) with ``'all_feats'`` and
            ``'all_labels'`` entries indexed by row.
        base_classes: labels eligible for sampling (anything np.random.choice accepts).
        m: classes per episode. n: maximum support examples per class.
        initlr, momentum, wd: SGD hyper-parameters.
        step_after: the learning rate is divided by 10 every ``step_after`` iterations.
        niter: number of training iterations.
        device: device the model and batches are moved to (default ``'cuda'``).
        print_every: print the running mean loss every this many iterations.

    Returns:
        The trained model (the same module, moved to ``device``).
    """
    model = model.to(device)
    lr = initlr
    # NOTE: dampening=momentum scales each new gradient by (1 - momentum) inside the
    # momentum buffer (0.1 with the defaults). This is unusual but kept as in the original.
    optimizer = torch.optim.SGD(model.parameters(), lr, momentum=momentum,
                                dampening=momentum, weight_decay=wd)
    loss_fn = nn.NLLLoss()

    all_labels = file_handle['all_labels'][...]
    all_feats = file_handle['all_feats']
    class_indices = {}

    total_loss = 0.0
    loss_count = 0
    for it in range(niter):
        optimizer.zero_grad()

        train_feats, train_Y, test_feats, test_labels = _sample_episode(
            all_feats, all_labels, class_indices, base_classes, m, n, model.feat_dim)

        logprob = model(test_feats.to(device), train_feats.to(device), train_Y.to(device))
        loss = loss_fn(logprob, test_labels.to(device))
        loss.backward()
        optimizer.step()

        # Step schedule: divide the learning rate by 10 every step_after iterations.
        if (it + 1) % step_after == 0:
            lr = lr / 10
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        total_loss += loss.item()
        loss_count += 1

        if (it + 1) % print_every == 0:
            print('{:d}:{:f}'.format(it, total_loss / loss_count))
            total_loss = 0.0
            loss_count = 0

    return model