# train.py

import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# The following helpers are not standard library; they are assumed to be
# defined elsewhere in this repo: repeat, forward_pass, get_nearestneighbors,
# pairwise_NNs_inner, StraightThroughQuantizer, Zn.
def triplet_optimize(xt, gt_nn, net, args, val_func):
    """
    Train net with a triplet loss on the training set xt (a numpy array).

    gt_nn: ground-truth nearest neighbors of each training point in input space
    net: network to optimize
    args: various runtime arguments
    val_func: callback called periodically to evaluate the network
    """
    lr_schedule = [float(x.strip()) for x in args.lr_schedule.split(",")]
    assert args.epochs % len(lr_schedule) == 0
    lr_schedule = repeat(lr_schedule, args.epochs // len(lr_schedule))
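    # `repeat` (repo-local helper) is assumed to stretch the schedule so that
    # each listed rate covers an equal share of epochs, e.g. "0.1,0.01" over
    # 10 epochs gives 5 epochs at 0.1 followed by 5 at 0.01.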
print("Lr schedule", lr_schedule)
N, kpos = gt_nn.shape
    if args.quantizer_train != "":
        assert args.quantizer_train.startswith("zn_")
        r2 = int(args.quantizer_train.split("_")[1])
        qt = StraightThroughQuantizer(Zn(r2))
    else:
        qt = lambda x: x
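    # StraightThroughQuantizer (repo-local) quantizes on the forward pass
    # while letting gradients pass through unchanged on the backward pass, so
    # the net can be trained end to end against the quantized representation;
    # Zn(r2) is presumably the Z_n spherical lattice quantizer of squared
    # radius r2.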
    xt_var = torch.from_numpy(xt).to(args.device)

    # prepare optimizer
    optimizer = optim.SGD(net.parameters(), lr_schedule[0], momentum=args.momentum)
    pdist = nn.PairwiseDistance(2)

    all_logs = []
    for epoch in range(args.epochs):
        # Update learning rate
        args.lr = lr_schedule[epoch]
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr

        t0 = time.time()

        # Sample positives for triplet
        rank_pos = np.random.choice(kpos, size=N)
        positive_idx = gt_nn[np.arange(N), rank_pos]
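        # The positive is drawn uniformly among each point's kpos ground-truth
        # nearest neighbors, so easier and harder positives are mixed.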
        # Sample negatives for triplet
        net.eval()
        print("  Forward pass")
        xl_net = forward_pass(net, xt, 1024)
        print("  Distances")
        I = get_nearestneighbors(xl_net, qt(xl_net), args.rank_negative,
                                 args.device, needs_exact=False)
        negative_idx = I[:, -1]
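        # Hard negative mining: the negative is the rank_negative-th nearest
        # neighbor of each point in the current embedding (with the database
        # side optionally quantized). An approximate search suffices here
        # (needs_exact=False), since any hard-ish negative will do.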
        # training pass
        print("  Train")
        net.train()
        avg_triplet, avg_uniform, avg_loss = 0, 0, 0
        offending = idx_batch = 0

        # process dataset in a random order
        perm = np.random.permutation(N)

        t1 = time.time()
        for i0 in range(0, N, args.batch_size):
            i1 = min(i0 + args.batch_size, N)
            n = i1 - i0
            data_idx = perm[i0:i1]

            # anchor, positives, negatives
            ins = xt_var[data_idx]
            pos = xt_var[positive_idx[data_idx]]
            neg = xt_var[negative_idx[data_idx]]

            # do the forward pass (+ record gradients)
            ins, pos, neg = net(ins), net(pos), net(neg)
            pos, neg = qt(pos), qt(neg)
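            # Only pos and neg go through the quantizer: the anchor plays the
            # role of the query, which is left unquantized, as in asymmetric
            # distance computation.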
            # triplet loss
            per_point_loss = pdist(ins, pos) - pdist(ins, neg)
            per_point_loss = F.relu(per_point_loss)
            loss_triplet = per_point_loss.mean()
            offending += torch.sum(per_point_loss.data > 0).item()
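            # "offending" counts the triplets that are still violated, i.e.
            # where the anchor is at least as close to the negative as to the
            # positive (margin-free hinge: loss > 0).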
            # entropy loss
            I = pairwise_NNs_inner(ins.data)
            distances = pdist(ins, ins[I])
            loss_uniform = - torch.log(n * distances).mean()
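            # This is a nearest-neighbor (Kozachenko-Leonenko style) estimate
            # of the differential entropy of the batch, up to constants:
            # maximizing the log-distance of each point to its nearest
            # in-batch neighbor pushes the embeddings to spread out uniformly.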
            # combined loss
            loss = loss_triplet + args.lambda_uniform * loss_uniform
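            # lambda_uniform trades off ranking fidelity (triplet term)
            # against how uniformly the embeddings cover the output space.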
            # collect some stats
            avg_triplet += loss_triplet.item()
            avg_uniform += loss_uniform.item()
            avg_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            idx_batch += 1

        avg_triplet /= idx_batch
        avg_uniform /= idx_batch
        avg_loss /= idx_batch
        logs = {
            'epoch': epoch,
            'loss_triplet': avg_triplet,
            'loss_uniform': avg_uniform,
            'loss': avg_loss,
            'offending': offending,
            'lr': args.lr
        }
        all_logs.append(logs)

        t2 = time.time()
        # maybe perform a validation run
        if (epoch + 1) % args.val_freq == 0:
            logs['val'] = val_func(net, epoch, args, all_logs)

        t3 = time.time()
        # summary logging: hn = negative-mining time
        print('epoch %d, times: [hn %.2f s epoch %.2f s val %.2f s]'
              ' lr = %f'
              ' loss = %g = %g + lam * %g, offending %d' % (
                  epoch, t1 - t0, t2 - t1, t3 - t2,
                  args.lr,
                  avg_loss, avg_triplet, avg_uniform, offending
              ))

        logs['times'] = (t1 - t0, t2 - t1, t3 - t2)

    return all_logs
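
# A minimal usage sketch, kept as a comment so the module stays import-safe.
# All names below are hypothetical: SimpleNamespace stands in for the argparse
# namespace, `net` for any torch.nn.Module with matching input dimension, and
# gt_nn would normally come from a nearest-neighbor search on xt:
#
#     import numpy as np
#     from types import SimpleNamespace
#
#     xt = np.random.rand(10000, 128).astype('float32')
#     gt_nn = get_nearestneighbors(xt, xt, 10, 'cpu')  # (N, 10) neighbor ids
#     args = SimpleNamespace(
#         lr_schedule="0.1,0.01", epochs=10, momentum=0.9, device='cpu',
#         quantizer_train="", rank_negative=50, batch_size=64,
#         lambda_uniform=0.05, val_freq=5)
#     logs = triplet_optimize(xt, gt_nn, net, args, lambda *a: {})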