in train_func.py [0:0]
def train(args, extr, clf, loss_fn, device, train_loader, optimizer, epoch, verbose=True):
    # Put the optional feature extractor and the classifier into training mode.
    if extr is not None:
        extr.train()
    clf.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Forward pass through the extractor (if any), then the classifier.
        if extr is not None:
            output = clf(extr(data))
            # If the classifier returns a 3-element output (e.g. logits plus
            # auxiliary values), keep only the first element as the prediction.
            if len(output) == 3:
                output = output[0]
        else:
            output = clf(data)
        loss = loss_fn(output, target)
        # Explicit L2 regularization: add lam/2 * ||theta||^2 over all parameters.
        if args.lam > 0:
            if extr is not None:
                loss += args.lam * params_to_vec(extr.parameters()).pow(2).sum() / 2
            loss += args.lam * params_to_vec(clf.parameters()).pow(2).sum() / 2
        loss.backward()
        optimizer.step()
        # Periodic progress logging every args.log_interval batches.
        if verbose and (batch_idx + 1) % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))
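
Below is a minimal usage sketch, not taken from the repository: the dataset, models, and hyperparameter values are hypothetical stand-ins, and params_to_vec is assumed here to flatten an iterable of parameters into a single 1-D tensor (similar to torch.nn.utils.parameters_to_vector); the repo's actual helper may differ.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from types import SimpleNamespace

def params_to_vec(parameters):
    # Assumed helper: concatenate all parameters into one flat vector,
    # keeping the autograd graph so the L2 term in train() stays differentiable.
    return torch.cat([p.reshape(-1) for p in parameters])

if __name__ == '__main__':
    # Hypothetical hyperparameters covering the fields that train() reads from args.
    args = SimpleNamespace(lam=1e-4, log_interval=5)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Toy data and models standing in for the real feature extractor / classifier.
    X, y = torch.randn(256, 20), torch.randint(0, 2, (256,))
    train_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)
    extr = nn.Sequential(nn.Linear(20, 16), nn.ReLU()).to(device)
    clf = nn.Linear(16, 2).to(device)

    optimizer = torch.optim.SGD(list(extr.parameters()) + list(clf.parameters()), lr=0.1)
    for epoch in range(1, 3):
        train(args, extr, clf, F.cross_entropy, device, train_loader, optimizer, epoch)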