in part_generator.py [0:0]
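# NOTE: excerpt of the trainer's train() method; torch / F / np / partial / floor and helpers such as
# mixed_list, noise_list, image_noise, latent_to_w, styles_def_to_tensor, gradient_penalty,
# calc_pl_lengths, ema_inplace, is_empty, raise_if_nan, loss_backwards and NanException are assumed
# to be imported or defined elsewhere in part_generator.py.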
def train(self):
    assert self.loader_G is not None, 'You must first initialize the data source with `.set_data_src(<folder of images>)`'
    self.init_folders()

    if self.GAN is None:
        self.init_GAN()

    self.GAN.train()
    total_disc_loss = torch.tensor(0.).cuda()
    total_gen_loss = torch.tensor(0.).cuda()

    batch_size = self.batch_size
    image_size = self.GAN.G.image_size
    latent_dim = self.GAN.G.latent_dim
    num_layers = self.GAN.G.num_layers

    apply_gradient_penalty = self.steps % 4 == 0
    apply_path_penalty = self.steps % 32 == 0

    backwards = partial(loss_backwards)

    avg_pl_length = self.pl_mean
    self.GAN.D_opt.zero_grad()
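    # ---- discriminator step: accumulate gradients over `gradient_accumulate_every` micro-batches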
    for i in range(self.gradient_accumulate_every):
        image_batch, image_cond_batch, part_only_batch = [item.cuda() for item in next(self.loader_D)]
        image_partial_batch = image_cond_batch[:, -1:, :, :]  # the last channel holds the aggregated input partial sketch
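        # sample per-layer latent codes, mixing two codes with probability `mixed_prob` (style mixing),
        # plus the per-pixel noise input for the synthesis network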
        get_latents_fn = mixed_list if np.random.random() < self.mixed_prob else noise_list
        style = get_latents_fn(batch_size, num_layers, latent_dim)
        noise = image_noise(batch_size, image_size)
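        # encode the conditioning stack (one channel per part plus a final channel with the full partial
        # sketch) into spatial features, and map the latent codes to W space via the style network S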
        bitmap_feats = self.GAN.Enc(image_cond_batch)
        w_space = latent_to_w(self.GAN.S, style)
        w_styles = styles_def_to_tensor(w_space)
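        # generate the new part and composite it onto the partial sketch; torch.max acts as a union of
        # strokes, assuming strokes are encoded as the higher pixel value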
        generated_partial_images = self.GAN.G(w_styles, noise, bitmap_feats)
        generated_images = torch.max(generated_partial_images, image_partial_batch)
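        # assemble the discriminator input: the per-part channels with the current part's channel
        # updated by the generated strokes, followed by the completed sketch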
        generated_image_stack_batch = torch.cat([
            image_cond_batch[:, :self.partid],
            torch.max(generated_partial_images, image_cond_batch[:, self.partid:self.partid+1]),
            image_cond_batch[:, self.partid+1:-1],
            generated_images], 1)
        fake_output = self.GAN.D(generated_image_stack_batch.clone().detach())

        image_batch.requires_grad_()
        real_image_stack_batch = torch.cat([
            image_cond_batch[:, :self.partid],
            torch.max(part_only_batch, image_cond_batch[:, self.partid:self.partid+1]),
            image_cond_batch[:, self.partid+1:-1],
            image_batch], 1)
        real_image_stack_batch.requires_grad_()
        real_output = self.GAN.D(real_image_stack_batch)
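        # hinge loss: under this sign convention the discriminator pushes real scores below -1 and
        # fake scores above +1 (lower score = more "real")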
        disc_loss = (F.relu(1 + real_output) + F.relu(1 - fake_output)).mean()
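        # gradient penalty on the real batch (gradients of real_output w.r.t. its inputs),
        # applied lazily every 4th step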
        if apply_gradient_penalty:
            gp = gradient_penalty(real_image_stack_batch, real_output)
            self.last_gp_loss = gp.clone().detach().item()
            disc_loss = disc_loss + gp

        disc_loss = disc_loss / self.gradient_accumulate_every
        disc_loss.register_hook(raise_if_nan)
        backwards(disc_loss, self.GAN.D_opt)

        # disc_loss has already been divided by gradient_accumulate_every above, so accumulate it directly
        total_disc_loss += disc_loss.detach().item()

    self.d_loss = float(total_disc_loss)
    self.GAN.D_opt.step()
    # train generator
    self.GAN.G_opt.zero_grad()
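    # ---- generator step: same accumulation scheme, but fake_output is not detached, so gradients flow
    # back through the discriminator into the generator; `get_latents_fn` is reused from the loop above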
    for i in range(self.gradient_accumulate_every):
        image_batch, image_cond_batch, part_only_batch = [item.cuda() for item in next(self.loader_G)]
        image_partial_batch = image_cond_batch[:, -1:, :, :]  # the last channel holds the aggregated input partial sketch

        style = get_latents_fn(batch_size, num_layers, latent_dim)
        noise = image_noise(batch_size, image_size)

        bitmap_feats = self.GAN.Enc(image_cond_batch)
        w_space = latent_to_w(self.GAN.S, style)
        w_styles = styles_def_to_tensor(w_space)

        generated_partial_images = self.GAN.G(w_styles, noise, bitmap_feats)
        generated_images = torch.max(generated_partial_images, image_partial_batch)
        generated_image_stack_batch = torch.cat([
            image_cond_batch[:, :self.partid],
            torch.max(generated_partial_images, image_cond_batch[:, self.partid:self.partid+1]),
            image_cond_batch[:, self.partid+1:-1],
            generated_images], 1)
        fake_output = self.GAN.D(generated_image_stack_batch)
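        # adversarial generator loss: minimizing fake_output pulls fake scores toward the "real" side
        # of the hinge objective above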
        loss = fake_output.mean()
        gen_loss = loss
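        # StyleGAN2-style path length regularization, applied every 32nd step: keep per-sample path
        # lengths close to their running mean (self.pl_mean)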
        if apply_path_penalty:
            pl_lengths = calc_pl_lengths(w_styles, generated_images)
            avg_pl_length = pl_lengths.detach().mean()

            if not is_empty(self.pl_mean):
                pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
                if not torch.isnan(pl_loss):
                    gen_loss = gen_loss + pl_loss
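            # optional penalty that rewards larger path lengths, presumably to push different latents
            # toward more dissimilar outputs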
            if self.similarity_penalty:
                gen_loss = gen_loss - self.similarity_penalty * (pl_lengths ** 2).mean()
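        # optional sparsity penalty: match the total "ink" (pixel mass) of the generated part to that
        # of the ground-truth part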
        if self.sparsity_penalty:
            generated_density = generated_partial_images.reshape(self.batch_size, -1).sum(1)
            target_density = part_only_batch.reshape(self.batch_size, -1).sum(1)  # per-part target, since the sketch is divided into parts
            self.sparsity_loss = ((generated_density - target_density) ** 2).mean()
            gen_loss = gen_loss + self.sparsity_loss * self.sparsity_penalty
        gen_loss = gen_loss / self.gradient_accumulate_every
        gen_loss.register_hook(raise_if_nan)
        backwards(gen_loss, self.GAN.G_opt)

        total_gen_loss += loss.detach().item() / self.gradient_accumulate_every

    self.g_loss = float(total_gen_loss)
    self.GAN.G_opt.step()
    # calculate moving averages
    if apply_path_penalty and not torch.isnan(avg_pl_length):
        ema_inplace(self.pl_mean, avg_pl_length, self.pl_ema_decay)
        self.pl_loss = self.pl_mean.item()

    # recover from NaN losses by reloading the last checkpoint
    checkpoint_num = floor(self.steps / self.save_every)
    if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
        print(f'NaN detected for generator or discriminator. Loading from checkpoint #{checkpoint_num}')
        self.load(checkpoint_num)
        raise NanException

    # periodically save results
    if self.steps % self.save_every == 0:
        self.save(checkpoint_num)

    if self.steps % 1000 == 0 or (self.steps % 100 == 0 and self.steps < 2500):
        self.evaluate(floor(self.steps / 1000))

    self.steps += 1
    self.av = None