in drqv2.py [0:0]
def update(self, replay_iter, step):
    """Run one training step: refresh critic, actor, and the critic target.

    Args:
        replay_iter: iterator over replay-buffer batches of
            (obs, action, reward, discount, next_obs).
        step: global environment step, used both for the update-interval
            gate and for schedules inside the critic/actor updates.

    Returns:
        dict of scalar training metrics (empty on skipped steps).
    """
    metrics = {}

    # Only update every `update_every_steps` environment steps.
    if step % self.update_every_steps != 0:
        return metrics

    sample = next(replay_iter)
    obs, action, reward, discount, next_obs = utils.to_torch(sample, self.device)

    if self.use_tb:
        metrics['batch_reward'] = reward.mean().item()

    # Random-shift augmentation on both current and next observations.
    aug_obs = self.aug(obs.float())
    aug_next_obs = self.aug(next_obs.float())

    # Encode; target branch carries no gradient back into the encoder.
    obs = self.encoder(aug_obs)
    with torch.no_grad():
        next_obs = self.encoder(aug_next_obs)

    # Critic update (gradients flow through the encoder here).
    metrics.update(
        self.update_critic(obs, action, reward, discount, next_obs, step))

    # Actor update on detached features so it cannot move the encoder.
    metrics.update(self.update_actor(obs.detach(), step))

    # Polyak-average the critic weights into the target critic.
    utils.soft_update_params(self.critic, self.critic_target,
                             self.critic_target_tau)

    return metrics