def train()

in BigGAN_PyTorch/train_fns.py [0:0]


    def train(x, y=None, features=None):
        """Run one training iteration: the D update step(s), then one G update.

        Args:
            x: Real inputs. Split with ``torch.split(x, batch_size)``; one
                chunk is consumed per D accumulation, so ``x`` must provide at
                least ``num_D_steps * num_D_accumulations`` chunks.
            y: Optional labels matching ``x``; split the same way as ``x``.
            features: Optional conditioning features matching ``x``; split the
                same way as ``x``.

        Returns:
            dict with float entries ``"G_loss"``, ``"D_loss_real"`` and
            ``"D_loss_fake"``, each taken from the last accumulation of its
            phase. ``G_loss`` is already divided by ``num_G_accumulations``;
            the two D components are reported undivided.

        Raises:
            ValueError: if both ``y`` and ``features`` are None. Previously
                this surfaced as an UnboundLocalError on ``z_`` halfway
                through the step.
        """
        has_labels = y is not None
        has_features = features is not None
        # Fail before touching optimizers or the sampler: without either kind
        # of conditioning there is no rule for unpacking the sampler output.
        if not (has_labels or has_features):
            raise ValueError(
                "train() requires at least one of `y` or `features`; the "
                "output of sample_conditionings() cannot be unpacked otherwise."
            )

        def sample_generator_inputs(truncate):
            """Sample G's inputs and move them to ``device``.

            Returns ``(z_, labels_g, f_g)``; ``labels_g`` / ``f_g`` are None
            when the matching conditioning is not in use. With ``truncate``
            set, every tensor is cut to its first ``batch_size`` entries
            before the transfer (the D phase does this, the G phase does not).
            """
            sampled_cond = sample_conditionings()
            labels_g, f_g = None, None
            # The sampler's tuple layout depends on which conditionings are on.
            if has_labels and has_features:
                z_, labels_g, f_g = sampled_cond
            elif has_labels:
                z_, labels_g = sampled_cond
            else:
                z_, f_g = sampled_cond
            # Tensors to device
            if labels_g is not None:
                if truncate:
                    labels_g = labels_g[:batch_size]
                labels_g = labels_g.to(device, non_blocking=True).long()
            if f_g is not None:
                if truncate:
                    f_g = f_g[:batch_size]
                f_g = f_g.to(device, non_blocking=True)
            if truncate:
                z_ = z_[:batch_size]
            z_ = z_.to(device, non_blocking=True)
            return z_, labels_g, f_g

        if embedded_optimizers:
            G.optim.zero_grad()
            D.optim.zero_grad()
        else:
            GD.optimizer_D.zero_grad()
            GD.optimizer_G.zero_grad()

        # Split the real data into per-accumulation chunks of batch_size.
        x = torch.split(x, batch_size)
        if has_labels:
            y = torch.split(y, batch_size)
        f_ = torch.split(features, batch_size) if has_features else None
        # Index of the next real-data chunk to feed to D.
        counter = 0

        # Optionally toggle D and G's "require_grad"
        if config["toggle_grads"]:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        # NOTE: the loss names reported at the end are only bound inside the
        # loops below, so num_D_steps, num_D_accumulations and
        # num_G_accumulations are all expected to be >= 1.
        for _ in range(config["num_D_steps"]):
            # If accumulating gradients, loop multiple times before an optimizer step
            if embedded_optimizers:
                D.optim.zero_grad()
            else:
                GD.optimizer_D.zero_grad()
            for _ in range(config["num_D_accumulations"]):
                z_, labels_g, f_g = sample_generator_inputs(truncate=True)
                # Obtain discriminator scores
                D_fake, D_real = GD(
                    z_,
                    labels_g,
                    f_g,
                    x[counter],
                    y[counter] if has_labels else None,
                    f_[counter] if f_ is not None else None,
                    train_G=False,
                    split_D=config["split_D"],
                    policy=config["DiffAugment"],
                    DA=config["DA"],
                )

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
                D_loss = (D_loss_real + D_loss_fake) / float(
                    config["num_D_accumulations"]
                )
                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config["D_ortho"] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print("using modified ortho reg in D")
                utils.ortho(D, config["D_ortho"])

            if embedded_optimizers:
                D.optim.step()
            else:
                GD.optimizer_D.step()

        # Optionally toggle "requires_grad"
        if config["toggle_grads"]:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        if embedded_optimizers:
            G.optim.zero_grad()
        else:
            GD.optimizer_G.zero_grad()

        # If accumulating gradients, loop multiple times
        for _ in range(config["num_G_accumulations"]):
            z_, labels_g, f_g = sample_generator_inputs(truncate=False)
            # Obtain discriminator scores
            D_fake = GD(
                z_,
                labels_g,
                f_g,
                train_G=True,
                split_D=config["split_D"],
                policy=config["DiffAugment"],
                DA=config["DA"],
            )
            G_loss = losses.generator_loss(D_fake) / float(
                config["num_G_accumulations"]
            )
            G_loss.backward()

        # Optionally apply modified ortho reg in G
        if config["G_ortho"] > 0.0:
            print(
                "using modified ortho reg in G"
            )  # Debug print to indicate we're using ortho reg in G
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(
                G,
                config["G_ortho"],
                blacklist=[param for param in G.shared.parameters()],
            )
        if embedded_optimizers:
            G.optim.step()
        else:
            GD.optimizer_G.step()

        # If we have an ema, update it, regardless of if we test with it or not
        if config["ema"]:
            ema.update(state_dict["itr"])

        # Return G's loss and the components of D's loss.
        return {
            "G_loss": float(G_loss.item()),
            "D_loss_real": float(D_loss_real.item()),
            "D_loss_fake": float(D_loss_fake.item()),
        }