def train()

in rlmeta/agents/dqn/apex_dqn_agent.py [0:0]


    def train(self, num_steps: int) -> Optional[StatsDict]:
        """Run ``num_steps`` learner updates and return the collected stats.

        The controller is switched to ``Phase.TRAIN`` (with ``reset=True``)
        and ``replay_buffer.warm_up()`` is called before the loop starts.
        Each iteration samples ``self.batch_size`` items from the replay
        buffer, hands them to ``train_step`` and records the per-step stats
        together with the wall-clock time of both calls in milliseconds.

        Args:
            num_steps: Number of sample/train iterations to execute.

        Returns:
            The ``StatsDict`` accumulated over the loop, updated with
            whatever ``self.controller.get_stats()`` returns. Every path
            visible in this method returns that object.
        """
        self.controller.set_phase(Phase.TRAIN, reset=True)
        self.replay_buffer.warm_up()

        stats = StatsDict()
        for step in range(num_steps):
            sample_start = time.perf_counter()
            batch, weight, index = self.replay_buffer.sample(self.batch_size)
            learn_start = time.perf_counter()
            step_stats = self.train_step(batch, weight, index)
            learn_end = time.perf_counter()

            stats.add_dict(step_stats)
            # perf_counter() reports seconds; scale to milliseconds.
            stats.add_dict({
                "sample_data_time/ms": (learn_start - sample_start) * 1000.0,
                "batch_learn_time/ms": (learn_end - learn_start) * 1000.0,
            })

            # Fires on the last step of every window of `sync_period` steps
            # (step is 0-based, hence the `period - 1` comparison).
            sync_period = self.sync_every_n_steps
            if step % sync_period == sync_period - 1:
                self.model.sync_target_net()

            # Same windowing rule, applied to model.push().
            push_period = self.push_every_n_steps
            if step % push_period == push_period - 1:
                self.model.push()

        # Merge the controller-side stats into the training stats.
        stats.update(self.controller.get_stats())
        return stats