in BigGAN_PyTorch/train_fns.py [0:0]
def train(x, y=None, features=None):
    """Run one GAN iteration: ``num_D_steps`` D updates, then one G update.

    Relies on names from the enclosing scope (``G``, ``D``, ``GD``, ``config``,
    ``sample_conditionings``, ``device``, ``batch_size``, ``ema``,
    ``state_dict``, ``embedded_optimizers``, ``losses``, ``utils``).
    NOTE(review): presumably bound by an enclosing training-function factory
    that is not visible here.

    Args:
        x: Real images. Split along dim 0 into chunks of ``batch_size``. One
            chunk is consumed per D accumulation, across all D steps, so ``x``
            must yield at least ``num_D_steps * num_D_accumulations`` chunks.
        y: Optional real labels, split the same way as ``x``.
        features: Optional real conditioning features, split the same way.

    The layout of ``sample_conditionings()``'s result is inferred from which
    of ``y`` / ``features`` is given: ``(z, labels, feats)``, ``(z, labels)``
    or ``(z, feats)``.

    Returns:
        dict with floats ``"G_loss"`` (from the last G accumulation, already
        divided by ``num_G_accumulations``) and ``"D_loss_real"`` /
        ``"D_loss_fake"`` (from the last D accumulation, unscaled).

    Raises:
        ValueError: if both ``y`` and ``features`` are None, because the
            conditioning sample cannot then be unpacked. Previously this
            surfaced as an opaque ``UnboundLocalError``.
    """
    # Resolve once which optimizers to drive: those attached to the networks,
    # or those held by the GD wrapper.
    if embedded_optimizers:
        optim_G, optim_D = G.optim, D.optim
    else:
        optim_G, optim_D = GD.optimizer_G, GD.optimizer_D

    def _sample_generator_inputs(limit=None):
        """Draw one conditioning sample for G and move it to ``device``.

        If ``limit`` is given, each tensor is truncated to its first
        ``limit`` rows before the transfer. Returns ``(z, labels, feats)``;
        ``labels`` / ``feats`` are None when not used.
        """
        sampled_cond = sample_conditionings()
        labels_g, f_g = None, None
        if features is not None and y is not None:
            z_, labels_g, f_g = sampled_cond
        elif y is not None:
            z_, labels_g = sampled_cond
        elif features is not None:
            z_, f_g = sampled_cond
        else:
            raise ValueError(
                "train() needs y and/or features to know how to unpack "
                "sample_conditionings(); both were None."
            )

        def _to_device(tensor):
            if limit is not None:
                tensor = tensor[:limit]
            return tensor.to(device, non_blocking=True)

        if labels_g is not None:
            labels_g = _to_device(labels_g).long()
        if f_g is not None:
            f_g = _to_device(f_g)
        return _to_device(z_), labels_g, f_g

    optim_G.zero_grad()
    optim_D.zero_grad()

    # Split the real batch into per-accumulation chunks of batch_size.
    x = torch.split(x, batch_size)
    if y is not None:
        y = torch.split(y, batch_size)
    f_ = torch.split(features, batch_size) if features is not None else None
    # Index of the next real chunk; runs across all D steps and accumulations.
    counter = 0

    # Optionally toggle D and G's "requires_grad".
    if config["toggle_grads"]:
        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)

    for _ in range(config["num_D_steps"]):
        optim_D.zero_grad()
        # Accumulate gradients over several chunks before one optimizer step.
        for _ in range(config["num_D_accumulations"]):
            z_, labels_g, f_g = _sample_generator_inputs(limit=batch_size)
            D_fake, D_real = GD(
                z_,
                labels_g,
                f_g,
                x[counter],
                y[counter] if y is not None else None,
                f_[counter] if f_ is not None else None,
                train_G=False,
                split_D=config["split_D"],
                policy=config["DiffAugment"],
                DA=config["DA"],
            )
            # Average D's loss components over the number of accumulations.
            D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
            D_loss = (D_loss_real + D_loss_fake) / float(
                config["num_D_accumulations"]
            )
            D_loss.backward()
            counter += 1
        # Optionally apply ortho reg in D.
        if config["D_ortho"] > 0.0:
            # Debug print to indicate we're using ortho reg in D.
            print("using modified ortho reg in D")
            utils.ortho(D, config["D_ortho"])
        optim_D.step()

    # Optionally toggle "requires_grad" back for the G update.
    if config["toggle_grads"]:
        utils.toggle_grad(D, False)
        utils.toggle_grad(G, True)

    # Zero G's gradients before training G, for safety.
    optim_G.zero_grad()
    # Accumulate G gradients over several samples before one optimizer step.
    for _ in range(config["num_G_accumulations"]):
        z_, labels_g, f_g = _sample_generator_inputs()
        D_fake = GD(
            z_,
            labels_g,
            f_g,
            train_G=True,
            split_D=config["split_D"],
            policy=config["DiffAugment"],
            DA=config["DA"],
        )
        G_loss = losses.generator_loss(D_fake) / float(
            config["num_G_accumulations"]
        )
        G_loss.backward()

    # Optionally apply modified ortho reg in G.
    if config["G_ortho"] > 0.0:
        # Debug print to indicate we're using ortho reg in G.
        print("using modified ortho reg in G")
        # Don't ortho-reg the shared embedding; it makes no sense there.
        # Ideally every embedding would be blacklisted.
        utils.ortho(
            G,
            config["G_ortho"],
            blacklist=list(G.shared.parameters()),
        )
    optim_G.step()

    # If we have an EMA, update it regardless of whether we test with it.
    if config["ema"]:
        ema.update(state_dict["itr"])

    # Return G's loss and the components of D's loss.
    return {
        "G_loss": float(G_loss.item()),
        "D_loss_real": float(D_loss_real.item()),
        "D_loss_fake": float(D_loss_fake.item()),
    }