in svg/agent.py [0:0]
def update(self, replay_buffer, logger, step):
    self.last_step = step
    if step % self.update_freq != 0:
        return
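
    # Dynamics model update: fit self.dx on multi-step sequences when the
    # learned model is actually used, i.e. the rollout horizon exceeds one
    # step (or there is no critic) and an MVE-style expansion is enabled for
    # the actor or the critic target.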
    if (self.horizon > 1 or not self.critic) and \
            (step % self.model_update_freq == 0) and \
            (self.actor_mve or self.critic_target_mve):
        for i in range(self.model_update_repeat):
            obses, actions, rewards = replay_buffer.sample_multistep(
                self.seq_batch_size, self.seq_train_length)
            assert obses.ndimension() == 3
            dx_loss = self.dx.update_step(obses, actions, rewards, logger, step)
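            # Track an exponential moving average of the dynamics loss
            # (only maintained when actor_dx_threshold is set).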
            if self.actor_dx_threshold is not None:
                if self.rolling_dx_loss is None:
                    self.rolling_dx_loss = dx_loss
                else:
                    factor = 0.9
                    self.rolling_dx_loss = factor*self.rolling_dx_loss + \
                        (1.-factor)*dx_loss
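
    # Model-free updates on single-step transitions: a single update per call
    # during warmup, model_free_update_repeat updates afterwards.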
    n_updates = 1 if step < self.warmup_steps else self.model_free_update_repeat
    for i in range(n_updates):
        obs, action, reward, next_obs, not_done, not_done_no_max = \
            replay_buffer.sample(self.step_batch_size)
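
        # Critic update, using the full MVE target when configured.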
        if self.critic is not None:
            if self.full_target_mve:
                self.update_critic_mve(obs, action, reward, next_obs, not_done_no_max, logger, step)
            else:
                self.update_critic(
                    obs, next_obs,
                    action, reward, not_done_no_max, logger, step
                )
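
        # Delayed actor and temperature (alpha) update.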
        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, logger, step)
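
        # Update the learned reward model (if enabled) and the termination
        # ("done") model.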
        if self.rew_opt is not None:
            self.update_rew_step(obs, action, reward, logger, step)
        self.update_done_step(obs, action, not_done_no_max, logger, step)
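
        # Polyak-average the target critic toward the online critic.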
        if self.critic is not None and step % self.critic_target_update_freq == 0:
            utils.soft_update_params(
                self.critic, self.critic_target, self.critic_tau)