in data/gen_data_collisions.py [0:0]
def generate_sample(self, l):
    """Simulate ``l`` steps and return one tokenized ``(x, y)`` sequence.

    Each simulated step writes a fixed-width chunk of
    ``3 * self.num_particles + 2`` tokens into ``x``:

    * for every particle ``p``: its two grid coordinates followed by its
      color. The second coordinate is offset by ``grid_size`` and the color
      by ``2 * grid_size``, so the three token kinds land in separate
      vocabulary ranges (assumes coordinates lie in ``[0, grid_size)`` --
      TODO confirm against ``move_particles``);
    * two query tokens ``(cp, cq)`` naming a pair of colors, offset by
      ``2 * grid_size + num_colors``. With probability ``self.easy_q`` the
      pair is ``self.most_recent_crossing``; otherwise both colors are drawn
      uniformly from ``[0, num_colors)``.

    ``y`` is zero everywhere except possibly the last token of each chunk.
    When ``crossing_history[cp, cq, 2] > 0`` that position receives
    ``bin_location(crossing_history[cp, cq, 0], crossing_history[cp, cq, 1]) + 1``.
    The ``+ 1`` keeps label 0 free, presumably as a "no target" value (verify
    against the loss code).

    Side effects: calls ``self.move_particles()`` and
    ``self.check_crossings()`` once per step, advancing the simulator state.

    Args:
        l: Number of simulation steps to generate.

    Returns:
        Tuple ``(x, y)`` of 1-D ``torch.long`` tensors, each of length
        ``l * (3 * self.num_particles + 2)``.
    """
    dlength = self.num_particles * 3 + 2
    total_len = dlength * l
    x = torch.zeros(total_len, dtype=torch.long)
    y = torch.zeros(total_len, dtype=torch.long)

    # Vocabulary offsets are loop-invariant; build them once rather than
    # allocating a new offset tensor for every particle at every step.
    coord_offset = torch.tensor((0, self.grid_size), dtype=torch.long)
    color_offset = 2 * self.grid_size
    query_offset = 2 * self.grid_size + self.num_colors

    for i in tqdm(range(l)):
        self.move_particles()
        self.check_crossings()

        start = i * dlength
        end = start + dlength

        # Particle tokens: (coord0, coord1 + grid_size, color + 2*grid_size).
        for p in range(self.num_particles):
            pos = start + 3 * p
            x[pos : pos + 2] = self.particle_grid_locations[p] + coord_offset
            x[pos + 2] = self.particle_colors[p] + color_offset

        # Query pair: the most recent crossing ("easy") or a random color pair.
        if torch.rand(1) < self.easy_q:
            cp, cq = self.most_recent_crossing
        else:
            cp, cq = torch.randint(0, self.num_colors, (2,))
        x[end - 2 : end] = torch.tensor((cp, cq), dtype=torch.long) + query_offset

        # Early in a run the queried pair may have no recorded history yet
        # (crossing_history[cp, cq, 2] is not positive). Those positions would
        # carry junk labels, so their target stays 0 and they are not counted.
        if self.crossing_history[cp, cq, 2] > 0:
            y[end - 1] = (
                self.bin_location(
                    self.crossing_history[cp, cq, 0],
                    self.crossing_history[cp, cq, 1],
                )
                + 1
            )
    return x, y