def train()

in model/interpolation_net.py [0:0]


    def train(self):
        """Run the training loop from ``self.i_epoch`` up to ``param.num_it`` epochs.

        Each epoch calls ``self.update_settings()`` and then iterates
        ``self.train_loader`` once. Gradients are accumulated across
        ``param.batch_size`` loader iterations before each optimizer step. Any
        remainder at the end of the epoch is flushed by one final step.

        Losses are summed over the epoch and divided by ``len(self.dataset)``.
        The total loss and its three components (printed as arap, reg, geo)
        are reported once per epoch.

        When ``self.time_stamp`` is not None:
          * ``self.save_self()`` is called every ``param.log_freq`` epochs.
          * ``self.test(self.dataset_val)`` is called every ``param.val_freq``
            epochs.

        Raises:
            RuntimeError: if ``self.train_loader`` yields no batches in an
                epoch, so there are no losses to report.
        """
        print("start training ...")

        self.interp_module.train()

        while self.i_epoch < self.interp_module.param.num_it:
            tot_loss = 0
            tot_loss_comp = None

            self.update_settings()

            # Read once per epoch, after update_settings(), in case that call
            # rebinds the dataset / loader / params.
            num_samples = len(self.dataset)
            num_batches = len(self.train_loader)
            accum_steps = self.interp_module.param.batch_size

            for i, data in enumerate(self.train_loader):
                shape_x = batch_to_shape(data["X"])
                shape_y = batch_to_shape(data["Y"])

                shape_x, shape_y = self.preprocess(shape_x, shape_y)

                loss, loss_comp = self.interp_module(shape_x, shape_y)

                loss.backward()

                # Gradient accumulation: step every `accum_steps` iterations.
                # The final iteration is excluded because the unconditional
                # step after the loop handles it; otherwise it would be
                # followed by a second step on freshly zeroed gradients.
                if (i + 1) % accum_steps == 0 and i < num_batches - 1:
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                # Detached so the running totals do not keep autograd graphs alive.
                comp_contrib = [comp.detach() / num_samples for comp in loss_comp]
                if tot_loss_comp is None:
                    tot_loss_comp = comp_contrib
                else:
                    tot_loss_comp = [
                        acc + contrib
                        for acc, contrib in zip(tot_loss_comp, comp_contrib)
                    ]

                tot_loss += loss.detach() / num_samples

            if tot_loss_comp is None:
                raise RuntimeError(
                    "train_loader yielded no batches; cannot report epoch losses"
                )

            # Flush gradients left over from a trailing partial accumulation window.
            self.optimizer.step()
            self.optimizer.zero_grad()

            print(
                "epoch {:04d}, loss = {:.5f} (arap: {:.5f}, reg: {:.5f}, geo: {:.5f}), reserved memory={}MB".format(
                    self.i_epoch,
                    tot_loss,
                    tot_loss_comp[0],
                    tot_loss_comp[1],
                    tot_loss_comp[2],
                    torch.cuda.memory_reserved(0) // (1024 ** 2),
                )
            )

            if self.time_stamp is not None:
                if (self.i_epoch + 1) % self.interp_module.param.log_freq == 0:
                    self.save_self()
                if (self.i_epoch + 1) % self.interp_module.param.val_freq == 0:
                    self.test(self.dataset_val)
                    # NOTE(review): self.test() is defined elsewhere. If it
                    # switches the module to eval mode, restore training mode
                    # so later epochs train correctly. This is harmless if it
                    # is already training (assuming interp_module is a
                    # torch.nn.Module).
                    self.interp_module.train()

            self.i_epoch += 1