in model/interpolation_net.py [0:0]
def train(self):
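    """Train the interpolation module, stepping the optimizer every
    `batch_size` samples (gradient accumulation), printing an epoch
    summary, and periodically checkpointing and validating."""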
print("start training ...")
self.interp_module.train()
while self.i_epoch < self.interp_module.param.num_it:
tot_loss = 0
tot_loss_comp = None
self.update_settings()
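        # each data item holds one pair of shapes; process them one at a
        # time and accumulate gradients to emulate mini-batches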
for i, data in enumerate(self.train_loader):
shape_x = batch_to_shape(data["X"])
shape_y = batch_to_shape(data["Y"])
shape_x, shape_y = self.preprocess(shape_x, shape_y)
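            # forward pass returns the total loss and its individual terms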
loss, loss_comp = self.interp_module(shape_x, shape_y)
loss.backward()
            # step the optimizer every `batch_size` samples; the final,
            # possibly partial batch is flushed after the epoch loop
            is_last_iter = i == len(self.train_loader) - 1
            if (i + 1) % self.interp_module.param.batch_size == 0 and not is_last_iter:
                self.optimizer.step()
                self.optimizer.zero_grad()
            # accumulate the individual loss terms (arap, reg, geo),
            # normalized by the dataset size so epoch totals are means
            loss_comp = [lc.detach() / len(self.dataset) for lc in loss_comp]
            if tot_loss_comp is None:
                tot_loss_comp = loss_comp
            else:
                tot_loss_comp = [
                    t + lc for t, lc in zip(tot_loss_comp, loss_comp)
                ]
            tot_loss += loss.detach() / len(self.dataset)
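        # epoch end: flush gradients from the last, possibly partial batch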
self.optimizer.step()
self.optimizer.zero_grad()
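        # per-epoch summary: total loss, its arap / reg / geo terms, and GPU memory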
print(
"epoch {:04d}, loss = {:.5f} (arap: {:.5f}, reg: {:.5f}, geo: {:.5f}), reserved memory={}MB".format(
self.i_epoch,
tot_loss,
tot_loss_comp[0],
tot_loss_comp[1],
tot_loss_comp[2],
torch.cuda.memory_reserved(0) // (1024 ** 2),
)
)
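        # periodic checkpointing and validation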
if self.time_stamp is not None:
if (self.i_epoch + 1) % self.interp_module.param.log_freq == 0:
self.save_self()
if (self.i_epoch + 1) % self.interp_module.param.val_freq == 0:
self.test(self.dataset_val)
self.i_epoch += 1
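
The stepping schedule above can be isolated into a minimal, self-contained sketch (the model, data, and loss below are toy placeholders, not part of this repository; only the accumulation logic mirrors train()):

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
samples = [torch.randn(8) for _ in range(10)]  # stand-ins for shape pairs
batch_size = 4  # mirrors interp_module.param.batch_size

for i, x in enumerate(samples):
    loss = model(x).pow(2).mean()
    loss.backward()  # gradients accumulate across single-sample passes
    if (i + 1) % batch_size == 0 and i < len(samples) - 1:
        optimizer.step()
        optimizer.zero_grad()
# flush the last, possibly partial batch, as train() does after its loop
optimizer.step()
optimizer.zero_grad()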