def train()

in train_func.py [0:0]


def train(args, extr, clf, loss_fn, device, train_loader, optimizer, epoch, verbose=True):
    """Run one training epoch over ``train_loader``.

    The prediction is ``clf(extr(data))`` when a feature extractor is given,
    otherwise ``clf(data)``. Each batch minimises ``loss_fn(output, target)``
    plus, when ``args.lam > 0``, the penalty
    ``args.lam * params_to_vec(module.parameters()).pow(2).sum() / 2`` for
    ``clf`` and (if present) ``extr``. ``params_to_vec`` is defined elsewhere;
    presumably it flattens the parameters into one vector — confirm.

    Args:
        args: Namespace-like config; reads ``args.lam`` (penalty coefficient)
            and ``args.log_interval`` (batches between progress prints).
        extr: Optional module applied to ``data`` before ``clf``; may be
            ``None``.
        clf: Classifier module. When ``extr`` is given, ``clf`` may return a
            3-element tuple/list, in which case only its first element is
            used as the prediction.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.
        device: Device every batch is moved to.
        train_loader: Iterable of ``(data, target)`` batches supporting
            ``len()`` and exposing ``.dataset`` (e.g. a ``DataLoader``).
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number; used only in progress messages.
        verbose: If true, print progress every ``args.log_interval`` batches.
    """
    if extr is not None:
        extr.train()
    clf.train()
    seen = 0  # samples processed so far; the final batch may be smaller
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        seen += len(data)
        optimizer.zero_grad()
        if extr is not None:
            output = clf(extr(data))
            # Check the type, not only the length: a plain output tensor with
            # a batch of 3 samples also has len() == 3 and must not be indexed.
            if isinstance(output, (tuple, list)) and len(output) == 3:
                output = output[0]
        else:
            output = clf(data)
        loss = loss_fn(output, target)
        if args.lam > 0:
            # Out-of-place additions avoid mutating an autograd output in place.
            if extr is not None:
                loss = loss + args.lam * params_to_vec(extr.parameters()).pow(2).sum() / 2
            loss = loss + args.lam * params_to_vec(clf.parameters()).pow(2).sum() / 2
        loss.backward()
        optimizer.step()
        if verbose and (batch_idx + 1) % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))