def step()

in distractors/n_body_problem.py [0:0]


    def step(self):
        """Advance the system by one time step of length ``self.dt``.

        The state (positions and velocities, each of shape (N, D)) is
        integrated over ``[0, self.dt]`` with ``odeint``. Every body is
        treated as having unit mass. Body ``j`` accelerates body ``i`` by
        ``G * (x_j - x_i) / |x_j - x_i| ** D``, where ``G`` is
        ``self.GRAVITATIONAL_CONSTANT`` and ``D`` is ``self.num_dimensions``
        (see https://en.wikipedia.org/wiki/Numerical_model_of_the_Solar_System).

        If ``self.contained_in_a_box`` is set, any coordinate that ends the
        step outside ``[self.MIN_POS, self.MAX_POS]`` is mirrored back across
        the wall it crossed, and the matching velocity component is negated.
        Only one reflection is applied: a coordinate that overshoots by more
        than the box width would remain out of range. Presumably that is what
        ``self.assert_bodies_in_box()`` guards against (TODO confirm).

        Rebinds ``self.body_positions`` and ``self.body_velocities`` to new
        (N, D) arrays, which are column views of one shared buffer. Returns
        ``None``.

        Note: no softening is applied, so two distinct bodies at exactly the
        same position cause a division by zero (inf/nan accelerations).
        """
        # Loop-invariant quantities, read once instead of on every RHS call.
        num_bodies = self.num_bodies
        num_dims = self.num_dimensions
        grav_const = self.GRAVITATIONAL_CONSTANT
        diag = np.arange(num_bodies)  # indices of the (i, i) self-interaction pairs

        # Helper functions since the ODE solver requires flattened inputs.
        def flatten(positions, velocities):  # (N, D), (N, D) -> (N*2D,)
            return np.concatenate((positions, velocities), axis=1).flatten()

        def unflatten(system_state_flat):  # (N*2D,) -> (N, D), (N, D)
            system_state = system_state_flat.reshape(num_bodies, 2 * num_dims)  # (N, 2D)
            return system_state[:, :num_dims], system_state[:, num_dims:]

        # ODE right-hand side: d/dt (x, v) = (v, a).
        # odeint evaluates this repeatedly within one step, so all N*N
        # pairwise interactions are computed in one vectorized pass rather
        # than a Python loop over bodies. Memory cost is O(N*N*D) per call.
        def system_first_order_ode(system_state_flat, _):
            positions, velocities = unflatten(system_state_flat)

            # relative_positions[i, j] = positions[j] - positions[i]
            relative_positions = positions[np.newaxis, :, :] - positions[:, np.newaxis, :]  # (N, N, D)
            distances = np.linalg.norm(relative_positions, axis=2, keepdims=True)  # (N, N, 1)
            # Bodies don't affect themselves: the relative position on the
            # diagonal is 0. Use distance 1 there to avoid 0/0.
            distances[diag, diag] = 1.

            force_vectors = grav_const * relative_positions / (distances ** num_dims)  # (N, N, D)
            accelerations = np.sum(force_vectors, axis=1)  # (N, D), assuming unit masses

            return flatten(velocities, accelerations)

        # Integrate and update.
        current_system_state_flat = flatten(self.body_positions, self.body_velocities)  # (N*2D,)
        # odeint returns one row per requested time; keep the state at t = dt.
        _, next_system_state_flat = odeint(system_first_order_ode, current_system_state_flat, [0., self.dt])
        self.body_positions, self.body_velocities = unflatten(next_system_state_flat)  # (N, D), (N, D)

        # Bounce off the boundaries of the box.
        if self.contained_in_a_box:
            ind_below_min = self.body_positions < self.MIN_POS
            ind_above_max = self.body_positions > self.MAX_POS
            # Mirror across the crossed wall: x -> x + 2 * (wall - x) = 2 * wall - x.
            self.body_positions[ind_below_min] += 2. * (self.MIN_POS - self.body_positions[ind_below_min])
            self.body_positions[ind_above_max] += 2. * (self.MAX_POS - self.body_positions[ind_above_max])
            # Reverse the velocity component normal to the crossed wall.
            self.body_velocities[ind_below_min] *= -1.
            self.body_velocities[ind_above_max] *= -1.
            self.assert_bodies_in_box()  # check for bugs