def test()

in model/interpolation_net.py [0:0]


    def test(self, dataset, compute_val_loss=True):
        """Run the interpolation module over every pair in ``dataset``.

        Items are visited one at a time (batch size 1) in dataset order. For
        each item, the ``"X"`` and ``"Y"`` entries are converted with
        ``batch_to_shape``, passed through ``self.preprocess`` and then fed to
        ``self.interp_module.get_pred``.

        Args:
            dataset: dataset whose items expose the two inputs under the keys
                ``"X"`` and ``"Y"``.
            compute_val_loss: when True, also evaluate
                ``self.interp_module.compute_loss`` on each prediction,
                accumulate the average over the dataset and print it.

        Returns:
            A tuple ``(shape_x_out, shape_y_out, points_out)`` of lists with one
            entry per dataset item: the preprocessed X shape and Y shape (after
            ``detach_cpu()`` has been called on them) and the prediction
            returned by ``get_pred``, detached and copied to the CPU.
        """
        loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
        xs, ys, preds = [], [], []
        val_loss = 0

        for batch in loader:
            src, tgt = self.preprocess(
                batch_to_shape(batch["X"]), batch_to_shape(batch["Y"])
            )
            pred = self.interp_module.get_pred(src, tgt)

            if compute_val_loss:
                pair_loss, _ = self.interp_module.compute_loss(src, tgt, pred)
                # Each term is pre-divided by the dataset size, so after the
                # loop the total is the mean loss. detach() keeps the running
                # sum from holding on to this iteration's autograd graph.
                val_loss += pair_loss.detach() / len(dataset)

            # detach_cpu() is a project helper on the shape objects;
            # presumably it detaches their tensors and moves them to the
            # CPU (TODO confirm).
            src.detach_cpu()
            tgt.detach_cpu()
            preds.append(pred.detach().cpu())
            xs.append(src)
            ys.append(tgt)

        if compute_val_loss:
            print("Validation loss = ", val_loss)

        return xs, ys, preds