def train()

in trainers/train_simplexes.py


import random

import numpy as np
import torch.nn as nn

# `args`, `get_weight`, and `utils` are assumed to be module-level names
# elsewhere in the repo: parsed CLI flags, the accessor that returns
# vertex i's weight tensor for a layer, and checkpoint helpers.


def train(models, writer, data_loader, optimizers, criterion, epoch):

    # Only the first model/optimizer pair is used: a single network
    # stores all n simplex vertices (accessed via get_weight below).
    model = models[0]
    optimizer = optimizers[0]

    model.zero_grad()
    model.train()
    avg_loss = 0.0
    train_loader = data_loader.train_loader

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(args.device), target.to(args.device)

        # To sample a point uniformly from the simplex, draw from an
        # exponential distribution and renormalize (equivalently, a draw
        # from Dirichlet(1, ..., 1)). When args.layerwise is set, each
        # conv/BN layer gets its own point; otherwise one point is
        # shared across the whole network.
        if args.layerwise:
            for m in model.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    Z = np.random.exponential(scale=1.0, size=args.n)
                    Z = Z / Z.sum()
                    for i in range(1, args.n):
                        setattr(m, f"t{i}", Z[i])
        else:
            Z = np.random.exponential(scale=1.0, size=args.n)
            Z = Z / Z.sum()
            for m in model.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for i in range(1, args.n):
                        setattr(m, f"t{i}", Z[i])
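        # Only t1 .. t{n-1} are assigned; Z[0] is never used directly.
        # Presumably the subspace layers recover vertex 0's coefficient
        # as 1 - sum(t1 .. t{n-1}) (which equals Z[0], since Z sums to 1)
        # when mixing the vertex weights in their forward pass.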

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        # Optional diversity regularizer: pick two simplex vertices at
        # random and penalize the squared cosine similarity between
        # their concatenated conv weights, scaled by args.beta. This
        # pushes the vertices toward orthogonality so they stay distinct.
        if args.beta > 0:
            i, j = random.sample(range(args.n), 2)
            num = 0.0
            normi = 0.0
            normj = 0.0
            for m in model.modules():
                if isinstance(m, nn.Conv2d):
                    vi = get_weight(m, i)
                    vj = get_weight(m, j)
                    num += (vi * vj).sum()
                    normi += vi.pow(2).sum()
                    normj += vj.pow(2).sum()
            loss += args.beta * (num.pow(2) / (normi * normj))

        loss.backward()

        optimizer.step()

        avg_loss += loss.item()

        it = len(train_loader) * epoch + batch_idx
        if batch_idx % args.log_interval == 0:
            num_samples = batch_idx * len(data)
            dataset_size = len(train_loader.dataset)
            percent_complete = 100.0 * batch_idx / len(train_loader)
            print(
                f"Train Epoch: {epoch} [{num_samples}/{dataset_size} ({percent_complete:.0f}%)]\t"
                f"Loss: {loss.item():.6f}"
            )

            if args.save:
                writer.add_scalar("train/loss", loss.item(), it)
        # Checkpoint at the globally counted iterations requested via
        # args.save_iters; the trailing -1s look like placeholders for
        # metrics that are not tracked inside the training loop.
        if args.save and it in args.save_iters:
            utils.save_cpt(epoch, it, models, optimizers, -1, -1)

    avg_loss = avg_loss / len(train_loader)
    return avg_loss, optimizers
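
The two non-obvious ingredients above are the simplex sampling and the
pairwise cosine penalty. Below is a minimal, self-contained sketch of
both, using hypothetical helper names (sample_simplex, cosine_sq_penalty)
and only numpy/torch, independent of the repo's args and subspace layers:

import numpy as np
import torch

def sample_simplex(n):
    # Exponential draws renormalized to sum to 1 are a sample from
    # Dirichlet(1, ..., 1), i.e. uniform over the (n-1)-simplex.
    z = np.random.exponential(scale=1.0, size=n)
    return z / z.sum()

def cosine_sq_penalty(vi, vj):
    # Squared cosine similarity between two weight tensors; adding
    # beta * this to the loss discourages the two vertices from
    # pointing in the same direction.
    num = (vi * vj).sum()
    return num.pow(2) / (vi.pow(2).sum() * vj.pow(2).sum())

coeffs = sample_simplex(4)
print(coeffs, coeffs.sum())       # 4 nonnegative coefficients, summing to 1 (up to float error)

vi = torch.randn(16, 3, 3, 3)
vj = torch.randn(16, 3, 3, 3)
print(cosine_sq_penalty(vi, vj))  # scalar in [0, 1]

Note one design detail of train() itself: the numerator and the two norms
are accumulated across all conv layers before the division, so the cosine
is taken over the concatenation of every layer's weights rather than
averaged per layer.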