in distractors/n_body_problem.py [0:0]
def step(self):
    """Advance the N-body system by one time step of length ``self.dt``.

    Positions and velocities (each shaped (N, D), with N = ``self.num_bodies``
    and D = ``self.num_dimensions``) are packed into a flat state vector,
    integrated with ``odeint`` over ``[0, self.dt]``, and unpacked back onto
    ``self.body_positions`` / ``self.body_velocities``. Every body is treated
    as having unit mass, and the pairwise pull is ``G * r / |r|**D``.

    When ``self.contained_in_a_box`` is truthy, coordinates that ended up
    outside ``[self.MIN_POS, self.MAX_POS]`` are mirrored back across the wall
    they crossed. The matching velocity components are negated, and then
    ``self.assert_bodies_in_box()`` is called.
    """
    n_bodies = self.num_bodies
    n_dims = self.num_dimensions
    g_const = self.GRAVITATIONAL_CONSTANT

    def pack(pos, vel):
        # odeint works on 1-D state vectors. Each body's row becomes
        # [pos..., vel...] of width 2D, and the result is flattened to (N*2D,).
        return np.concatenate((pos, vel), axis=1).flatten()

    def unpack(state_flat):
        # Inverse of pack: (N*2D,) -> (N, 2D) -> two (N, D) column views.
        rows = state_flat.reshape(n_bodies, 2 * n_dims)
        pos, vel = np.split(rows, [n_dims], axis=1)
        return pos, vel

    def derivative(state_flat, _t):
        # First-order form of Newton's law: d/dt [x, v] = [v, a].
        pos, vel = unpack(state_flat)
        acc = np.zeros_like(vel)
        for k in range(n_bodies):
            offsets = pos - pos[k]  # vectors from body k to every body, (N, D)
            dist = np.linalg.norm(offsets, axis=1, keepdims=True)  # (N, 1)
            # offsets[k] is the zero vector, so any non-zero placeholder drops
            # the self-interaction while avoiding a 0/0 below.
            dist[k] = 1.
            # Gravity in D dimensions (see
            # https://en.wikipedia.org/wiki/Numerical_model_of_the_Solar_System);
            # with unit mass, force equals acceleration.
            acc[k] = np.sum(g_const * offsets / dist ** n_dims, axis=0)
        return pack(vel, acc)

    initial_state = pack(self.body_positions, self.body_velocities)  # (N*2D,)
    # odeint returns one row per requested time, i.e. shape (2, N*2D) here.
    trajectory = odeint(derivative, initial_state, [0., self.dt])
    self.body_positions, self.body_velocities = unpack(trajectory[-1])

    if not self.contained_in_a_box:
        return

    pos = self.body_positions
    vel = self.body_velocities
    # Both masks are taken before any coordinate is moved.
    below = pos < self.MIN_POS
    above = pos > self.MAX_POS
    # Mirror escaped coordinates back across the wall they crossed...
    pos[below] += 2. * (self.MIN_POS - pos[below])
    pos[above] += 2. * (self.MAX_POS - pos[above])
    # ...and reverse the velocity component normal to that wall.
    vel[below] *= -1.
    vel[above] *= -1.
    self.assert_bodies_in_box()  # sanity check that no body is still outside