in part_selector.py
def train(self):
    self.init_folders()
    if self.clf is None:
        self.init_clf()
    self.clf.train()  # put the classifier network in training mode (nn.Module.train), not a recursive call
    total_disc_loss = torch.tensor(0.).cuda()
    total_acc = torch.tensor(0.).cuda()
    batch_size = self.batch_size
    backwards = partial(loss_backwards)  # note: partial() binds no extra args here, so this is just loss_backwards
    self.clf.D_opt.zero_grad()
    # accumulate gradients over several mini-batches before taking an optimizer step
    for i in range(self.gradient_accumulate_every):
        part_id_batch, image_cond_batch, _ = [item.cuda() for item in next(self.loader)]
        outputs = self.clf.D(image_cond_batch)
        _, predicts = torch.max(outputs, 1)
        # batch accuracy, scaled so the accumulated total averages over the accumulation steps
        acc = (predicts == part_id_batch).sum().float() / part_id_batch.size(0) / self.gradient_accumulate_every
        disc_loss = self.criterion(outputs, part_id_batch)
        # scale the loss the same way so the summed gradients match a single large batch
        disc_loss = disc_loss / self.gradient_accumulate_every
        disc_loss.register_hook(raise_if_nan)  # abort the backward pass if the gradient goes NaN
        backwards(disc_loss, self.clf.D_opt)
        total_disc_loss += disc_loss.detach().item()
        total_acc += acc.detach().item()
    self.d_loss = float(total_disc_loss)
    self.d_acc = float(total_acc)
    self.clf.D_opt.step()
    # recover from NaN errors by reloading the last checkpoint
    checkpoint_num = floor(self.steps / self.save_every)
    if torch.isnan(total_disc_loss):
        print(f'NaN detected. Loading from checkpoint #{checkpoint_num}')
        self.load(checkpoint_num)
        raise NanException
    # periodically save results
    if self.steps % self.save_every == 0:
        self.save(checkpoint_num)
    # evaluate every 1000 steps, and more often (every 100) early in training
    if self.steps % 1000 == 0 or (self.steps % 100 == 0 and self.steps < 2500):
        self.evaluate(floor(self.steps / 1000))
    self.steps += 1
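
This method assumes a few helpers defined elsewhere in part_selector.py (along with imports of torch, functools.partial, and math.floor): loss_backwards, the raise_if_nan gradient hook, and the NanException it raises. A minimal sketch of what those utilities could look like, inferred only from how they are called above; the real definitions may differ (for example, loss_backwards variants often also take an fp16 flag for mixed-precision training):

import torch

class NanException(Exception):
    # signals the training loop to reload the last good checkpoint
    pass

def raise_if_nan(t):
    # tensor hook: called with the gradient during backward();
    # abort the pass as soon as any NaN appears
    if torch.isnan(t).any():
        raise NanException

def loss_backwards(loss, optimizer, **kwargs):
    # plain backward pass; an AMP-scaled variant would replace this
    # if mixed precision were enabled
    loss.backward(**kwargs)

With these in place, backwards(disc_loss, self.clf.D_opt) above resolves to loss_backwards(disc_loss, self.clf.D_opt), and the register_hook call ensures a NanException fires before a corrupted gradient ever reaches the optimizer step.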