in rlmeta/agents/dqn/apex_dqn_agent.py [0:0]
def train(self, num_steps: int) -> Optional[StatsDict]:
    """Run ``num_steps`` learner iterations and return the collected stats.

    The controller is switched to ``Phase.TRAIN`` (with ``reset=True``) and
    the replay buffer is warmed up before the loop starts. Each iteration
    samples ``self.batch_size`` items from the replay buffer, passes the
    sampled ``(batch, weight, index)`` triple to ``self.train_step``, and
    records how long sampling and learning took, in milliseconds.

    Target-network syncs and model pushes are keyed on the iteration index
    of *this* call. They fire on the last iteration of every
    ``sync_every_n_steps`` / ``push_every_n_steps`` window. The index
    restarts at zero on every call, so a trailing partial window is not
    carried over to the next call.

    Args:
        num_steps: Number of training iterations to run.

    Returns:
        A ``StatsDict`` holding the per-step learner stats, the timing stats
        and the controller's stats gathered after the loop.
    """

    def closes_window(i, period):
        # True on the final iteration of each ``period``-sized window.
        return i % period == period - 1

    self.controller.set_phase(Phase.TRAIN, reset=True)
    self.replay_buffer.warm_up()

    result = StatsDict()
    for it in range(num_steps):
        sample_start = time.perf_counter()
        samples, weights, indices = self.replay_buffer.sample(self.batch_size)
        learn_start = time.perf_counter()
        learner_stats = self.train_step(samples, weights, indices)
        learn_end = time.perf_counter()

        result.add_dict(learner_stats)
        result.add_dict({
            "sample_data_time/ms": 1000.0 * (learn_start - sample_start),
            "batch_learn_time/ms": 1000.0 * (learn_end - learn_start),
        })

        if closes_window(it, self.sync_every_n_steps):
            self.model.sync_target_net()
        if closes_window(it, self.push_every_n_steps):
            self.model.push()

    # Merge in whatever stats the controller accumulated during this phase.
    result.update(self.controller.get_stats())
    return result