in utils/interpolation_base.py [0:0]
def forward(self):
    """Return the mean pairwise energy along ``self.vert_sequence``.

    ``self.energy.forward_single`` is evaluated once for every pair of
    consecutive frames ``(k, k + 1)``; the results are summed and divided
    by ``num_timesteps - 1``.

    The first frame of the first pair and the second frame of the last
    pair are passed through ``.detach()`` before being handed to the
    energy (presumably so the two end shapes stay fixed during
    optimisation -- TODO confirm).

    NOTE(review): with ``num_timesteps == 2`` both boundary terms cover the
    same pair ``(0, 1)``, so that pair is counted twice -- confirm whether
    this case is expected to occur.

    Returns:
        A 2-tuple ``(mean_energy, [mean_energy])``; the list holds the very
        same object as the first element.
    """
    steps = self.param.num_timesteps
    final = steps - 1

    def frame(k):
        # Slice out frame ``k``, keeping every trailing dimension.
        return self.vert_sequence[k, ...]

    # Boundary pairs: the outermost frame on each side is detached.
    head_term = self.energy.forward_single(
        frame(0).detach(), frame(1), self.shape_x
    )
    tail_term = self.energy.forward_single(
        frame(final - 1), frame(final).detach(), self.shape_x
    )

    # Out-of-place accumulation, in the order: head, tail, interior pairs.
    running = head_term + tail_term
    for right in range(2, final):
        interior_term = self.energy.forward_single(
            frame(right - 1), frame(right), self.shape_x
        )
        running = running + interior_term

    mean_energy = running / final
    return mean_energy, [mean_energy]