def forward()

in utils/interpolation_base.py [0:0]


    def forward(self):
        """Return the mean energy between consecutive frames of the vertex sequence.

        The sequence from ``self.get_vert_sequence()`` is indexed along its
        first axis by timestep (``0 .. num_timesteps - 1``). Each neighbouring
        pair of frames is scored with ``self.energy.forward_single(prev, next,
        self.shape_x)``, the scores are summed, and the sum is divided by
        ``num_timesteps - 1``.

        The scoring order is fixed:

        1. The first pair is scored, with frame 0 detached from autograd.
        2. The last pair is scored, with the final frame detached.
        3. The interior pairs are scored in increasing order.

        Returns:
            A tuple ``(mean_energy, [mean_energy])``. The list holds the same
            object as the first element.
        """
        n_steps = self.param.num_timesteps
        verts = self.get_vert_sequence()

        energy_fn = self.energy.forward_single
        # NOTE(review): every pair is scored against shape_x, presumably because
        # all frames share shape_x's connectivity — confirm.
        mesh = self.shape_x

        # Boundary pairs: the outer frame of each is detached, so gradients from
        # these terms only reach the neighbouring interior frame.
        # NOTE(review): with n_steps == 2 both boundary terms score the same
        # pair (0, 1), so that pair is counted twice — confirm this is intended.
        e_first = energy_fn(verts[0, ...].detach(), verts[1, ...], mesh)
        e_last = energy_fn(
            verts[n_steps - 2, ...],
            verts[n_steps - 1, ...].detach(),
            mesh,
        )

        # Interior pairs (k, k + 1) for k = 1 .. n_steps - 3. sum() adds them
        # left-to-right onto the boundary total, in the same order as a loop.
        interior = (
            energy_fn(verts[k, ...], verts[k + 1, ...], mesh)
            for k in range(1, n_steps - 2)
        )
        e_sum = sum(interior, e_first + e_last)

        # For n_steps >= 3 exactly n_steps - 1 pairs were scored above.
        e_mean = e_sum / (n_steps - 1)
        return e_mean, [e_mean]