def update()

in svg/agent.py [0:0]


    def update(self, replay_buffer, logger, step):
        """Run the agent's scheduled updates for step ``step``.

        ``step`` is always recorded in ``self.last_step``. Nothing else
        happens unless ``step`` is a multiple of ``self.update_freq``.

        When updates are due, two phases run in order:

        1. Model phase (``self.dx``): executed ``self.model_update_repeat``
           times, but only if all of the following hold: the horizon is
           greater than 1 or there is no critic; ``step`` is a multiple of
           ``self.model_update_freq``; and ``self.actor_mve`` or
           ``self.critic_target_mve`` is set. If ``self.actor_dx_threshold``
           is not None, an exponential moving average of the returned loss
           (decay 0.9) is kept in ``self.rolling_dx_loss``.
        2. Per-batch phase: executed once while ``step`` is below
           ``self.warmup_steps`` and ``self.model_free_update_repeat`` times
           afterwards. Each pass samples a fresh batch and updates the
           critic (if any), the actor and alpha (every
           ``self.actor_update_freq`` steps), the reward optimisation step
           (if ``self.rew_opt`` is set), the done step, and finally
           soft-updates the critic target (every
           ``self.critic_target_update_freq`` steps).

        Args:
            replay_buffer: Object providing
                ``sample_multistep(batch_size, length)`` returning
                ``(obses, actions, rewards)`` and ``sample(batch_size)``
                returning ``(obs, action, reward, next_obs, not_done,
                not_done_no_max)``.
            logger: Forwarded unchanged to every update routine.
            step: Current step counter; gates every update frequency.

        Returns:
            None.

        Raises:
            AssertionError: If the multistep observations are not
                3-dimensional.
        """
        self.last_step = step
        if step % self.update_freq != 0:
            return

        # Use an explicit ``is None`` test, matching the critic checks
        # further down. A truthiness test would treat a critic object that
        # happens to evaluate falsy (e.g. one defining ``__len__``) as
        # missing here while the rest of this method treats it as present.
        if (self.horizon > 1 or self.critic is None) and \
              (step % self.model_update_freq == 0) and \
              (self.actor_mve or self.critic_target_mve):
            for _ in range(self.model_update_repeat):
                obses, actions, rewards = replay_buffer.sample_multistep(
                    self.seq_batch_size, self.seq_train_length)
                # Sanity check on the sampled sequence batch.
                # NOTE(review): presumably (time, batch, dim) or
                # (batch, time, dim) -- confirm against the replay buffer.
                assert obses.ndimension() == 3
                dx_loss = self.dx.update_step(
                    obses, actions, rewards, logger, step)
                if self.actor_dx_threshold is not None:
                    if self.rolling_dx_loss is None:
                        # First observation seeds the moving average.
                        self.rolling_dx_loss = dx_loss
                    else:
                        factor = 0.9
                        self.rolling_dx_loss = factor*self.rolling_dx_loss + \
                          (1.-factor)*dx_loss

        # Only a single pass per call while still inside the warmup period.
        n_updates = 1 if step < self.warmup_steps else self.model_free_update_repeat
        for _ in range(n_updates):
            # ``_not_done`` is sampled but unused; only the ``no_max``
            # variant is passed to the update routines below.
            obs, action, reward, next_obs, _not_done, not_done_no_max = \
              replay_buffer.sample(self.step_batch_size)

            if self.critic is not None:
                if self.full_target_mve:
                    self.update_critic_mve(
                        obs, action, reward, next_obs, not_done_no_max,
                        logger, step)
                else:
                    self.update_critic(
                        obs, next_obs,
                        action, reward, not_done_no_max, logger, step
                    )

            if step % self.actor_update_freq == 0:
                self.update_actor_and_alpha(obs, logger, step)

            if self.rew_opt is not None:
                self.update_rew_step(obs, action, reward, logger, step)

            self.update_done_step(obs, action, not_done_no_max, logger, step)

            if self.critic is not None and step % self.critic_target_update_freq == 0:
                utils.soft_update_params(
                    self.critic, self.critic_target, self.critic_tau)