def _update()

in hucc/agents/sacse.py [0:0]
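
Runs self._num_updates SAC update steps on a batch that is split evenly across the self._n_models models, with each sample routed to one model through a one-hot 'task' input: the twin Q-functions are fit to a soft Bellman target, the policy is updated against the minimum of the two critics, the per-model entropy temperature is optionally adapted, and the target critics are soft-updated. Statistics from the last iteration are written via the tbw_* helpers and logged.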


    def _update(self):
        # Infer the device the model lives on from its first parameter
        for p in self._model.parameters():
            mdevice = p.device
            break

        def act_logp(obs):
            # Sample an action from the current policy via the
            # reparameterization trick and return it together with its
            # log-probability under the policy.
            dist = self._model.pi(obs)
            action = dist.rsample()
            log_prob = dist.log_prob(action).sum(dim=-1)
            action = action * self._action_factor
            return action, log_prob

        # Feed an equal share of the batch to every model; each sample's model
        # assignment is passed to the networks as a one-hot 'task' input.
        idx = th.arange(self._bsz * self._n_models) // self._bsz
        idx_in = F.one_hot(idx, self._n_models).to(
            dtype=th.float32, device=mdevice
        )
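        # E.g. (hypothetical sizes): with bsz=2 and n_models=3, idx would be
        # [0, 0, 1, 1, 2, 2] and idx_in would stack the matching one-hot rows.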

        for _ in range(self._num_updates):
            batch = self._buffer.get_batch(
                self._bsz * self._n_models,
                device=mdevice,
            )
            reward = batch['reward']
            not_done = th.logical_not(batch['terminal'])
            obs = {k: batch[f'obs_{k}'] for k in self._obs_keys}
            obs_p = {k: batch[f'next_obs_{k}'] for k in self._obs_keys}
            obs['task'] = idx_in
            obs_p['task'] = idx_in
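            # Per-model temperature, gathered with idx so that each sample
            # uses the alpha of the model it is assigned to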
            alpha = (
                self._log_alpha.detach()
                .exp()
                .to(dtype=th.float32, device=mdevice)[idx]
            )

            # Backup for Q-Function
            with th.no_grad():
                a_p, log_prob_p = act_logp(obs_p)
                q_in = dict(action=a_p, **obs_p)
                q_tgt = th.min(self._q_tgt(q_in), dim=-1).values
                backup = reward + self._gamma * not_done * (
                    q_tgt - alpha * log_prob_p
                )
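                # backup is the soft Bellman target:
                # r + gamma * (1 - done) * (min_i Q_tgt_i(s', a') - alpha * log pi(a'|s'))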

            # Q-Function update
            q_in = dict(action=batch['action'], **obs)
            q = self._q(q_in)
            q1 = q[:, 0]
            q2 = q[:, 1]
            q1_loss = F.mse_loss(q1, backup)
            q2_loss = F.mse_loss(q2, backup)
            q_loss = q1_loss + q2_loss
            self._optim.q.zero_grad()
            q_loss.backward()
            if self._clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(
                    self._model.q.parameters(), self._clip_grad_norm
                )
            self._optim.q.step()

            # Policy update
            # Freeze the critics so the policy loss only updates the actor
            for param in self._model.q.parameters():
                param.requires_grad_(False)

            a, log_prob = act_logp(obs)
            q_in = dict(action=a, **obs)
            q = th.min(self._q(q_in), dim=-1).values
            pi_loss = (alpha * log_prob - q).mean()
            self._optim.pi.zero_grad()
            pi_loss.backward()
            if self._clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(
                    self._model.pi[0].parameters(), self._clip_grad_norm
                )
            self._optim.pi.step()

            for param in self._model.q.parameters():
                param.requires_grad_(True)

            # Optional temperature update
            if self._optim_alpha:
                # Adapt each model's temperature so that the policy entropy
                # tracks the target entropy
                alpha_loss = -(
                    self._log_alpha.exp()[idx]
                    * (log_prob.cpu() + self._target_entropy).detach()
                )
                self._optim_alpha.zero_grad()
                alpha_loss.mean().backward()
                self._optim_alpha.step()

            # Update target network
            with th.no_grad():
                for tp, p in zip(
                    self._target.q.parameters(), self._model.q.parameters()
                ):
                    tp.data.lerp_(p.data, 1.0 - self._polyak)
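                    # lerp_ performs tp <- tp + (1 - polyak) * (p - tp),
                    # i.e. tp <- polyak * tp + (1 - polyak) * p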

        # These are the stats for the last update
        self.tbw_add_scalar('Loss/Policy', pi_loss.item())
        self.tbw_add_scalar('Loss/QValue', q_loss.item())
        self.tbw_add_scalar('Health/Entropy', -log_prob.mean())
        if self._optim_alpha:
            self.tbw_add_scalar(
                'Health/Alpha', self._log_alpha.exp().mean().item()
            )
        if self._n_updates % 100 == 1:
            self.tbw.add_scalars(
                'Health/GradNorms',
                {
                    k: v.grad.norm().item()
                    for k, v in self._model.named_parameters()
                    if v.grad is not None
                },
                self.n_samples,
            )

        avg_cr = th.cat(self._cur_rewards).mean().item()
        log.info(
            f'Sample {self._n_samples}, up {self._n_updates*self._num_updates}, '
            f'avg cur reward {avg_cr:+0.3f}, pi loss {pi_loss.item():+.03f}, '
            f'q loss {q_loss.item():+.03f}, entropy {-log_prob.mean().item():+.03f}, '
            f'alpha {self._log_alpha.exp().mean().item():.03f}'
        )
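
For reference, here is a minimal, self-contained sketch of the reparameterized sampling pattern that act_logp relies on. It uses a stand-in diagonal Gaussian in place of self._model.pi; the batch size, action dimension, and the action_factor value are made up for illustration.

    import torch as th
    from torch.distributions import Normal

    action_factor = 2.0                            # stand-in for self._action_factor
    dist = Normal(th.zeros(4, 3), th.ones(4, 3))   # toy policy: batch of 4, 3-dim actions
    action = dist.rsample()                        # reparameterized sample, keeps gradients
    log_prob = dist.log_prob(action).sum(dim=-1)   # joint log-density per action
    action = action * action_factor                # scale to the environment's action range
    print(action.shape, log_prob.shape)            # torch.Size([4, 3]) torch.Size([4])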