in model/interpolation_net.py [0:0]
def test(self, dataset, compute_val_loss=True):
    """Run the interpolation module over every pair in ``dataset``.

    Pairs are drawn one at a time (batch size 1, original dataset order).
    For each pair the two collated shapes (keys ``"X"`` and ``"Y"``) are
    converted with ``batch_to_shape``, passed through ``self.preprocess``
    and then through ``self.interp_module.get_pred``.

    Args:
        dataset: dataset consumed by a ``torch.utils.data.DataLoader``; each
            collated batch is indexed with the keys ``"X"`` and ``"Y"``.
        compute_val_loss: when True, additionally evaluate
            ``self.interp_module.compute_loss`` for every pair, accumulate
            ``loss / len(dataset)`` and print the total at the end.

    Returns:
        Tuple ``(shape_x_out, shape_y_out, points_out)`` of three lists with
        one entry per pair: the source shapes and target shapes after their
        ``detach_cpu()`` call, and the predictions detached and moved to CPU.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    sources, targets, predictions = [], [], []
    running_loss = 0

    # NOTE(review): this loop runs with autograd enabled. If compute_loss does
    # not need gradients, wrapping the loop in torch.no_grad() would avoid
    # building unused graphs — confirm before changing.
    for batch in loader:
        src, tgt = self.preprocess(
            batch_to_shape(batch["X"]), batch_to_shape(batch["Y"])
        )
        pred = self.interp_module.get_pred(src, tgt)

        if compute_val_loss:
            # compute_loss returns a pair; only its first element is used here.
            loss_value, _ = self.interp_module.compute_loss(src, tgt, pred)
            running_loss += loss_value.detach() / len(dataset)

        # Shape objects expose detach_cpu(); presumably this detaches their
        # tensors and moves them to CPU in place — verify against the shape class.
        src.detach_cpu()
        tgt.detach_cpu()
        predictions.append(pred.detach().cpu())
        sources.append(src)
        targets.append(tgt)

    if compute_val_loss:
        print("Validation loss = ", running_loss)
    return sources, targets, predictions