def train()

in trainers/train_one_dim_subspaces.py
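
# The excerpt below assumes the module-level context of
# trainers/train_one_dim_subspaces.py: numpy as np, torch.nn as nn, random,
# the parsed `args` namespace, the project's `utils` module, and a
# `get_weight(module, index)` helper returning the weight tensor of subspace
# endpoint `index` for a conv module.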


def train(models, writer, data_loader, optimizers, criterion, epoch):

    # We consider only a single model here; multiple models are used for the
    # ensemble and SWA baselines.
    model = models[0]
    optimizer = optimizers[0]

    # When multiple subspace samples are drawn per batch, the model also
    # returns intermediate features, used by the feature-cosine penalty below.
    if args.num_samples > 1:
        model.apply(lambda m: setattr(m, "return_feats", True))

    model.zero_grad()
    model.train()
    avg_loss = 0.0
    train_loader = data_loader.train_loader

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(args.device), target.to(args.device)

        # num_samples is the number of points sampled from the subspace per
        # batch; in all experiments in the main paper it is 1.
        if args.num_samples == 1:

            # alpha is the interpolation coefficient within the subspace; the
            # subspace conv/bn layers read it to form their effective weights.
            if args.layerwise:
                # Draw an independent alpha for every layer.
                for m in model.modules():
                    if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                        alpha = np.random.uniform(0, 1)
                        setattr(m, "alpha", alpha)
            else:
                # Draw a single alpha shared by all layers.
                alpha = np.random.uniform(0, 1)
                for m in model.modules():
                    if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                        setattr(m, "alpha", alpha)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)

        else:
            # Rarely used: this corresponds to Section B of the appendix, where
            # multiple points are sampled from the subspace for each batch.
            div = data.size(0) // args.num_samples
            feats = []
            ts = []
            optimizer.zero_grad()

            for sample in range(args.num_samples):

                if args.layerwise:
                    for m in model.modules():
                        if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                            alpha = np.random.uniform(0, 1)
                            setattr(m, "alpha", alpha)
                else:
                    alpha = np.random.uniform(0, 1)
                    for m in model.modules():
                        if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                            setattr(m, "alpha", alpha)

                # Record the sampled interpolation coefficient for this pass
                # (the last layer's alpha when sampling layerwise).
                ts.append(alpha)

                # Each sample sees its own slice of the batch; the model also
                # returns features f because return_feats was enabled above.
                output, f = model(data[sample * div : (sample + 1) * div])
                feats.append(f)

                # Average the cross-entropy over the subspace samples.
                if sample == 0:
                    loss = (
                        criterion(
                            output, target[sample * div : (sample + 1) * div]
                        )
                        / args.num_samples
                    )
                else:
                    loss += (
                        criterion(
                            output, target[sample * div : (sample + 1) * div]
                        )
                        / args.num_samples
                    )

            # Feature-cosine penalty between two of the sampled points:
            # args.lamb gates the term, args.fcos_weight scales it, and the
            # penalty is weighted by how far apart the two coefficients are.
            if args.lamb > 0:
                i, j = random.sample(range(args.num_samples), 2)
                fi, fj = feats[i], feats[j]
                ti, tj = ts[i], ts[j]
                loss += (
                    args.fcos_weight
                    * abs(ti - tj)
                    * (
                        (fi * fj).sum().pow(2)
                        / (fi.pow(2).sum() * fj.pow(2).sum())
                    )
                )

        # Regularization term, Equation 3: penalize the squared cosine
        # similarity between the conv weights of two randomly chosen subspace
        # endpoints so the endpoints stay diverse. Lines have two endpoints;
        # the other parameterizations here use three.
        num_points = 2 if args.conv_type == "LinesConv" else 3
        if args.beta > 0:
            i, j = random.sample(range(num_points), 2)

            num = 0.0
            normi = 0.0
            normj = 0.0
            for m in model.modules():
                if isinstance(m, nn.Conv2d):
                    vi = get_weight(m, i)
                    vj = get_weight(m, j)
                    num += (vi * vj).sum()
                    normi += vi.pow(2).sum()
                    normj += vj.pow(2).sum()
            loss += args.beta * (num.pow(2) / (normi * normj))

        loss.backward()

        optimizer.step()

        avg_loss += loss.item()

        # Global iteration index used for logging and checkpointing.
        it = len(train_loader) * epoch + batch_idx
        if batch_idx % args.log_interval == 0:
            examples_seen = batch_idx * len(data)
            dataset_size = len(train_loader.dataset)
            percent_complete = 100.0 * batch_idx / len(train_loader)
            print(
                f"Train Epoch: {epoch} [{examples_seen}/{dataset_size} ({percent_complete:.0f}%)]\t"
                f"Loss: {loss.item():.6f}"
            )

            if args.save:
                writer.add_scalar("train/loss", loss.item(), it)
        if args.save and it in args.save_iters:
            utils.save_cpt(epoch, it, models, optimizers, -1, -1)

    # Restore the default forward pass (no feature outputs) before evaluation.
    model.apply(lambda m: setattr(m, "return_feats", False))

    avg_loss = avg_loss / len(train_loader)
    return avg_loss, optimizers
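
For reference, below is a minimal sketch of how the per-module alpha attribute and the get_weight helper used above are typically consumed by a line-subspace conv layer. It is an illustrative assumption, not the repository's actual LinesConv implementation: the class name SketchLinesConv, the weight1 attribute, and this get_weight signature are hypothetical stand-ins that are merely consistent with the calls made in train().

import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchLinesConv(nn.Conv2d):
    # Illustrative line-subspace conv (hypothetical, not the repo's LinesConv):
    # holds two weight endpoints and interpolates between them with the
    # `alpha` attribute that train() sets on every Conv2d/BatchNorm2d module.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # `self.weight` acts as endpoint 0; `weight1` is endpoint 1.
        self.weight1 = nn.Parameter(torch.empty_like(self.weight))
        nn.init.kaiming_normal_(self.weight1)
        self.alpha = 0.5  # overwritten each batch by train()

    def forward(self, x):
        # The effective weight is a point on the line between the two endpoints.
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        return F.conv2d(
            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
        )


def get_weight(m, i):
    # Return the flattened weight tensor of endpoint i, as used by the
    # Equation-3 regularizer in train().
    return (m.weight if i == 0 else m.weight1).reshape(-1)

With such a layer, setattr(m, "alpha", alpha) in train() selects which point on the line the forward pass uses, and the Equation-3 term compares get_weight(m, 0) with get_weight(m, 1) to keep the endpoints from collapsing onto each other.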