def train()

in part_selector.py [0:0]


    def train(self):
        """Run one training step of the classifier ``self.clf.D``.

        Gradients are accumulated over ``self.gradient_accumulate_every``
        micro-batches drawn from ``self.loader``, followed by a single
        optimizer step on ``self.clf.D_opt``. The accumulated loss and
        accuracy are stored as floats on ``self.d_loss`` and ``self.d_acc``.

        Besides the parameter update, this method:
          * saves checkpoint ``floor(self.steps / self.save_every)`` whenever
            ``self.steps`` is a multiple of ``self.save_every``;
          * calls ``self.evaluate`` every 1000 steps, and every 100 steps
            while ``self.steps < 2500``;
          * increments ``self.steps``.

        Raises:
            NanException: if the accumulated loss is NaN. In that case the
                optimizer step is skipped and checkpoint
                ``floor(self.steps / self.save_every)`` is loaded before
                raising (``self.steps`` is not incremented).
        """
        self.init_folders()
        if self.clf is None:
            self.init_clf()

        self.clf.train()
        total_disc_loss = torch.tensor(0.).cuda()
        total_acc = torch.tensor(0.).cuda()

        self.clf.D_opt.zero_grad()

        for _micro_step in range(self.gradient_accumulate_every):
            part_id_batch, image_cond_batch, _ = [item.cuda() for item in next(self.loader)]
            outputs = self.clf.D(image_cond_batch)
            _, predicts = torch.max(outputs, 1)
            # Both metrics are divided by the accumulation count so that
            # summing them over all micro-batches yields the mean per-batch
            # accuracy / loss.
            acc = (
                (predicts == part_id_batch).sum().float()
                / part_id_batch.size(0)
                / self.gradient_accumulate_every
            )
            disc_loss = self.criterion(outputs, part_id_batch)
            disc_loss = disc_loss / self.gradient_accumulate_every
            # NOTE(review): raise_if_nan is defined elsewhere; presumably it
            # raises when the gradient reaching disc_loss is NaN — confirm.
            disc_loss.register_hook(raise_if_nan)
            # NOTE(review): loss_backwards is defined elsewhere; presumably it
            # performs disc_loss.backward() (the optimizer is passed along).
            loss_backwards(disc_loss, self.clf.D_opt)
            total_disc_loss += disc_loss.detach().item()
            total_acc += acc.detach().item()

        self.d_loss = float(total_disc_loss)
        self.d_acc = float(total_acc)

        checkpoint_num = floor(self.steps / self.save_every)

        # Check for NaN *before* stepping the optimizer. Applying an update
        # computed from NaN gradients would corrupt the weights and any
        # optimizer state (e.g. momentum buffers) that the checkpoint load
        # may not restore.
        if torch.isnan(total_disc_loss):
            print(f'NaN detected. Loading from checkpoint #{checkpoint_num}')
            # NOTE(review): when self.steps is an exact multiple of
            # self.save_every, checkpoint #checkpoint_num is only written
            # further below, so loading it here relies on it existing from an
            # earlier run — confirm against self.load.
            self.load(checkpoint_num)
            raise NanException

        self.clf.D_opt.step()

        # Periodically save results.
        if self.steps % self.save_every == 0:
            self.save(checkpoint_num)

        # Evaluate densely early in training, then every 1000 steps.
        if self.steps % 1000 == 0 or (self.steps % 100 == 0 and self.steps < 2500):
            self.evaluate(floor(self.steps / 1000))

        self.steps += 1