in utils/interpolation_base.py [0:0]
def forward(self):
    """Return the mean per-segment energy of the current vertex sequence.

    The sequence from ``self.get_vert_sequence()`` is indexed along its first
    axis by time step ``0 .. num_timesteps - 1``. For every pair of
    consecutive frames ``(i, i + 1)`` an energy term is obtained from
    ``self.energy.forward_single(frame_i, frame_i_plus_1, self.shape_x)``.
    The terms are summed and then divided by the number of terms, giving
    their mean.

    In the first segment the start frame is detached. In the last segment
    the end frame is detached. As a result, those two boundary terms
    backpropagate only into the frame that is not detached.

    Assumes the vertex sequence is a tensor supporting ``[i, ...]`` indexing
    and ``.detach()``, presumably of shape (num_t, n_verts, dim) — TODO
    confirm.

    NOTE(review): ``self.shape_x`` is passed for every segment, including the
    last one. Presumably it only supplies the shared mesh/reference data;
    confirm that ``shape_y`` is not intended for the final segment.

    Returns:
        A tuple ``(E_total, [E_total])``. The list holds the same tensor
        object, presumably as the per-component breakdown that callers log.

    Raises:
        ValueError: if ``self.param.num_timesteps`` is smaller than 2, since
            at least one segment is needed.
    """
    num_t = self.param.num_timesteps
    if num_t < 2:
        raise ValueError(f"num_timesteps must be at least 2, got {num_t}")
    v_s = self.get_vert_sequence()

    # First segment (0, 1): start frame detached, so this term sends no
    # gradient into v_s[0].
    E_x = self.energy.forward_single(
        v_s[0, ...].detach(), v_s[1, ...], self.shape_x
    )
    # Last segment (num_t - 2, num_t - 1): end frame detached, so this term
    # sends no gradient into v_s[num_t - 1].
    E_y = self.energy.forward_single(
        v_s[num_t - 2, ...], v_s[num_t - 1, ...].detach(), self.shape_x
    )
    E_total = E_x + E_y

    # Interior segments (i, i + 1) for i = 1 .. num_t - 3, nothing detached.
    for i in range(1, num_t - 2):
        E_total = E_total + self.energy.forward_single(
            v_s[i, ...], v_s[i + 1, ...], self.shape_x
        )

    # Divide by the number of terms actually summed. For num_t >= 3 this is
    # num_t - 1, one term per segment. For num_t == 2, E_x and E_y both
    # measure the single segment (0, 1); dividing by 2 avoids counting it
    # twice while still letting gradients reach both frames.
    num_terms = max(num_t - 1, 2)
    E_total = E_total / num_terms
    return E_total, [E_total]