def generate_sample()

in data/gen_data_collisions.py [0:0]


    def generate_sample(self, l):
        """Simulate ``l`` steps and encode them as flat token sequences.

        Every step occupies ``3 * self.num_particles + 2`` consecutive slots
        of ``x``:

        * three slots per particle: ``self.particle_grid_locations[p]``
          (presumably an (a, b) pair -- confirm) with ``self.grid_size``
          added to its second entry, followed by
          ``self.particle_colors[p] + 2 * self.grid_size``;
        * two trailing slots holding a query pair ``(cp, cq)`` shifted by
          ``2 * self.grid_size + self.num_colors``.

        The query pair is ``self.most_recent_crossing`` when a uniform draw
        falls below ``self.easy_q``; otherwise it is two independent draws
        from ``[0, self.num_colors)``.

        ``y`` stays zero except at the last slot of a step, which receives
        ``self.bin_location(...) + 1`` whenever
        ``self.crossing_history[cp, cq, 2]`` is positive.

        Args:
            l: Number of simulation steps to generate.

        Returns:
            Tuple ``(x, y)`` of 1-D ``torch.long`` tensors, each of length
            ``l * (3 * self.num_particles + 2)``.
        """
        step_len = self.num_particles * 3 + 2
        total_len = step_len * l
        x = torch.zeros(total_len, dtype=torch.long)
        y = torch.zeros(total_len, dtype=torch.long)

        for step in tqdm(range(l)):
            # Advance the simulation state before recording this step.
            self.move_particles()
            self.check_crossings()

            step_start = step * step_len
            step_end = step_start + step_len

            # Per-particle tokens: two location slots, then one colour slot.
            for p in range(self.num_particles):
                slot = step_start + 3 * p
                location_tokens = self.particle_grid_locations[
                    p
                ] + torch.LongTensor((0, self.grid_size))
                x[slot : slot + 2] = location_tokens
                x[slot + 2] = self.particle_colors[p] + 2 * self.grid_size

            # Pick the query pair; the random draws happen in this exact order.
            use_recent = torch.rand(1) < self.easy_q
            if use_recent:
                cp, cq = self.most_recent_crossing
            else:
                cp, cq = torch.randint(0, self.num_colors, (2,))

            query_tokens = (
                torch.LongTensor((cp, cq))
                + 2 * self.grid_size
                + self.num_colors
            )
            x[step_end - 2 : step_end] = query_tokens

            # Per the original note, history is junk at the very beginning;
            # such steps are not counted and keep a label of 0.
            history = self.crossing_history
            if not (history[cp, cq, 2] > 0):
                continue
            y[step_end - 1] = (
                self.bin_location(history[cp, cq, 0], history[cp, cq, 1]) + 1
            )

        return x, y