def train()

in part_generator.py


    def train(self):
        assert self.loader_G is not None, 'You must first initialize the data source with `.set_data_src(<folder of images>)`'

        self.init_folders()

        if self.GAN is None:
            self.init_GAN()

        self.GAN.train()
        total_disc_loss = torch.tensor(0.).cuda()
        total_gen_loss = torch.tensor(0.).cuda()

        batch_size = self.batch_size

        image_size = self.GAN.G.image_size
        latent_dim = self.GAN.G.latent_dim
        num_layers = self.GAN.G.num_layers

        apply_gradient_penalty = self.steps % 4 == 0   # lazy regularization: gradient penalty every 4 steps
        apply_path_penalty = self.steps % 32 == 0      # lazy regularization: path length penalty every 32 steps

        backwards = partial(loss_backwards)

        avg_pl_length = self.pl_mean
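
        # train discriminator
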
        self.GAN.D_opt.zero_grad()

        for i in range(self.gradient_accumulate_every):
            image_batch, image_cond_batch, part_only_batch = [item.cuda() for item in next(self.loader_D)]
            image_partial_batch = image_cond_batch[:, -1:, :, :] # the last channel holds the entire input partial sketch
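            # style mixing regularization: with probability mixed_prob, split two latent codes across the generator layers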
            get_latents_fn = mixed_list if np.random.random() < self.mixed_prob else noise_list
            style = get_latents_fn(batch_size, num_layers, latent_dim)
            noise = image_noise(batch_size, image_size)

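            # encode the conditioning stack (the other part channels plus the full partial sketch) into spatial features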
            bitmap_feats = self.GAN.Enc(image_cond_batch)

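            # map z to w through the style network, then expand to one style vector per generator layer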
            w_space = latent_to_w(self.GAN.S, style)
            w_styles = styles_def_to_tensor(w_space)

            generated_partial_images = self.GAN.G(w_styles, noise, bitmap_feats)
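            # overlay the generated part on the input partial sketch; pixelwise max keeps every existing stroke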
            generated_images = torch.max(generated_partial_images, image_partial_batch)

            generated_image_stack_batch = torch.cat([image_cond_batch[:, :self.partid],
                                                     torch.max(generated_partial_images, image_cond_batch[:, self.partid:self.partid+1]),
                                                     image_cond_batch[:, self.partid+1:-1],
                                                     generated_images], 1)
            fake_output = self.GAN.D(generated_image_stack_batch.clone().detach())

            image_batch.requires_grad_()  # track gradients w.r.t. the real inputs for the gradient penalty
            real_image_stack_batch = torch.cat([image_cond_batch[:, :self.partid],
                                                torch.max(part_only_batch, image_cond_batch[:, self.partid:self.partid+1]),
                                                image_cond_batch[:, self.partid+1:-1],
                                                image_batch], 1)
            real_image_stack_batch.requires_grad_()
            real_output = self.GAN.D(real_image_stack_batch)

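            # hinge loss: real outputs are pushed toward -1 and fake outputs toward +1 (G minimizes fake_output below)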
            disc_loss = (F.relu(1 + real_output) + F.relu(1 - fake_output)).mean()

            if apply_gradient_penalty:
                gp = gradient_penalty(real_image_stack_batch, real_output)
                self.last_gp_loss = gp.clone().detach().item()
                disc_loss = disc_loss + gp

            disc_loss = disc_loss / self.gradient_accumulate_every
            disc_loss.register_hook(raise_if_nan)
            backwards(disc_loss, self.GAN.D_opt)

            total_disc_loss += disc_loss.detach().item()  # disc_loss already carries the 1/gradient_accumulate_every factor

        self.d_loss = float(total_disc_loss)
        self.GAN.D_opt.step()

        # train generator

        self.GAN.G_opt.zero_grad()
        for i in range(self.gradient_accumulate_every):
            image_batch, image_cond_batch, part_only_batch = [item.cuda() for item in next(self.loader_G)]
            image_partial_batch = image_cond_batch[:, -1:, :, :] # the last channel holds the entire input partial sketch
            
            style = get_latents_fn(batch_size, num_layers, latent_dim)  # get_latents_fn persists from the discriminator loop above
            noise = image_noise(batch_size, image_size)

            bitmap_feats = self.GAN.Enc(image_cond_batch)

            w_space = latent_to_w(self.GAN.S, style)
            w_styles = styles_def_to_tensor(w_space)

            generated_partial_images = self.GAN.G(w_styles, noise, bitmap_feats)
            generated_images = torch.max(generated_partial_images, image_partial_batch)
            
            generated_image_stack_batch = torch.cat([image_cond_batch[:, :self.partid],
                                                     torch.max(generated_partial_images, image_cond_batch[:, self.partid:self.partid+1]),
                                                     image_cond_batch[:, self.partid+1:-1],
                                                     generated_images], 1)
            fake_output = self.GAN.D(generated_image_stack_batch)

            loss = fake_output.mean()
            gen_loss = loss

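            # StyleGAN2 path length regularization: nudge per-sample path lengths toward their running mean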
            if apply_path_penalty:
                pl_lengths = calc_pl_lengths(w_styles, generated_images)
                avg_pl_length = pl_lengths.detach().mean()

                if not is_empty(self.pl_mean):
                    pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
                    if not torch.isnan(pl_loss):
                        gen_loss = gen_loss + pl_loss
                        if self.similarity_penalty:
                            gen_loss = gen_loss - self.similarity_penalty * (pl_lengths ** 2).mean()

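            # sparsity penalty: match the generated part's stroke density to that of the ground-truth part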
            if self.sparsity_penalty:
                generated_density = generated_partial_images.reshape(self.batch_size, -1).sum(1)
                target_density = part_only_batch.reshape(self.batch_size, -1).sum(1) # if we divide the sketch by parts
                self.sparsity_loss = ((generated_density - target_density) ** 2).mean()
                gen_loss = gen_loss + self.sparsity_loss * self.sparsity_penalty

            gen_loss = gen_loss / self.gradient_accumulate_every
            gen_loss.register_hook(raise_if_nan)
            backwards(gen_loss, self.GAN.G_opt)

            total_gen_loss += loss.detach().item() / self.gradient_accumulate_every

        self.g_loss = float(total_gen_loss)
        self.GAN.G_opt.step()

        # calculate moving averages

        if apply_path_penalty and not torch.isnan(avg_pl_length):
            ema_inplace(self.pl_mean, avg_pl_length, self.pl_ema_decay)
            self.pl_loss = self.pl_mean.item()

        # save from NaN errors

        checkpoint_num = floor(self.steps / self.save_every)

        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(f'NaN detected for generator or discriminator. Loading from checkpoint #{checkpoint_num}')
            self.load(checkpoint_num)
            raise NanException

        # periodically save results

        if self.steps % self.save_every == 0:
            self.save(checkpoint_num)

        if self.steps % 1000 == 0 or (self.steps % 100 == 0 and self.steps < 2500):
            self.evaluate(floor(self.steps / 1000))

        self.steps += 1
        self.av = None
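
The gradient penalty and path length helpers used above are defined elsewhere in the repo. As a rough sketch of their shape, following the lucidrains stylegan2_pytorch conventions this trainer is modeled on (the default weight of 10 and the exact signatures here are assumptions, so defer to the repo's own definitions):

    import math
    import torch
    from torch.autograd import grad as torch_grad

    def gradient_penalty(images, output, weight=10):
        # gradient of the discriminator output w.r.t. its (real) input batch,
        # with the per-sample gradient norm penalized toward 1
        batch_size = images.shape[0]
        gradients = torch_grad(outputs=output, inputs=images,
                               grad_outputs=torch.ones(output.size(), device=images.device),
                               create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.reshape(batch_size, -1)
        return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def calc_pl_lengths(styles, images):
        # StyleGAN2 path length: per-sample norm of the Jacobian of the images
        # w.r.t. the per-layer style codes, probed with scaled random noise
        num_pixels = images.shape[2] * images.shape[3]
        pl_noise = torch.randn(images.shape, device=images.device) / math.sqrt(num_pixels)
        outputs = (images * pl_noise).sum()
        pl_grads = torch_grad(outputs=outputs, inputs=styles,
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
        return (pl_grads ** 2).sum(dim=2).mean(dim=1).sqrt()

Both helpers build their gradients with create_graph=True, so the penalties themselves can be backpropagated through during the discriminator and generator passes.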