def _update()

in hucc/agents/diayn.py
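
Runs self._num_updates SAC-style gradient steps per call: each step draws a replay batch, computes DIAYN pseudo-rewards from the skill discriminator, updates the twin Q-functions and the policy, optionally tunes the entropy temperature, trains the discriminator, and Polyak-averages the target critic.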


    def _update(self):
        def act_logp(obs):
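            # Reparameterized sample from the current policy and its
            # log-probability; the action is scaled to the environment's
            # range after the log-probability is computed.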
            dist = self._model.pi(obs)
            action = dist.rsample()
            log_prob = dist.log_prob(action).sum(dim=-1)
            action = action * self._action_factor
            return action, log_prob

        rewards = []
        for _ in range(self._num_updates):
            batch = self._buffer.get_batch(self._bsz)
            # Ensure that action has a batch dimension
            action = batch['action'].view(batch['obs'].shape[0], -1)
            z = batch['z']
            z_one_hot = F.one_hot(z, self._n_skills).float()
            phi_obs = batch.get('phi_obs', batch['obs'])
            if self._phi_obs_feats is not None:
                phi_obs = phi_obs[:, self._phi_obs_feats]
            not_done = th.logical_not(batch['terminal'])

            # Compute pseudo-reward with discriminator
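            # -cross_entropy(phi(s), z) equals log q_phi(z|s); subtracting the
            # skill prior log p(z) (when enabled) yields the DIAYN pseudo-reward
            # log q_phi(z|s) - log p(z).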
            with th.no_grad():
                reward = -F.cross_entropy(
                    self._model.phi(phi_obs), z, reduction='none'
                )
                # Subtract baseline
                if self._add_p_z:
                    reward = reward - self._log_p_z
                rewards.append(reward.mean().item())

            # Backup for Q-Function
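            # Soft Bellman target: r + gamma * (1 - done) *
            # (min_i Q_target_i(s', a') - alpha * log pi(a'|s')), with the
            # skill one-hot appended to the observation.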
            with th.no_grad():
                obs_p = th.cat([batch['next_obs'], z_one_hot], dim=1)
                a_p, log_prob_p = act_logp(obs_p)

                q_in = th.cat([obs_p, a_p], dim=1)
                q_tgt = th.min(self._target.q(q_in), dim=-1).values
                backup = reward + self._gamma * not_done * (
                    q_tgt - self._log_alpha.detach().exp() * log_prob_p
                )

            # Q-Function update
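            # Both critic heads are regressed onto the shared soft target.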
            obs = th.cat([batch['obs'], z_one_hot], dim=1)
            q_in = th.cat([obs, action], dim=1)
            q = self._model.q(q_in)
            q1 = q[:, 0]
            q2 = q[:, 1]
            q1_loss = F.mse_loss(q1, backup)
            q2_loss = F.mse_loss(q2, backup)
            q_loss = q1_loss + q2_loss
            self._optim.q.zero_grad()
            q_loss.backward()
            self._optim.q.step()

            # Policy update
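            # Freeze the critic so the actor loss only updates the policy;
            # minimize E[alpha * log pi(a|s) - min_i Q_i(s, a)].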
            for param in self._model.q.parameters():
                param.requires_grad_(False)

            a, log_prob = act_logp(obs)
            q_in = th.cat([obs, a], dim=1)
            q = th.min(self._model.q(q_in), dim=-1).values
            pi_loss = (self._log_alpha.detach().exp() * log_prob - q).mean()
            self._optim.pi.zero_grad()
            pi_loss.backward()
            self._optim.pi.step()

            for param in self._model.q.parameters():
                param.requires_grad_(True)

            # Optional temperature update
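            # Adjust alpha so that the policy entropy tracks the target
            # entropy; the loss is -alpha * (E[log pi(a|s)] + target_entropy).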
            if self._optim_alpha:
                # This is a slight reordering of the formulation in
                # https://github.com/rail-berkeley/softlearning, mostly so we
                # don't need to create temporary tensors. log_prob is the only
                # non-scalar tensor, so we can compute its mean first.
                alpha_loss = -(
                    self._log_alpha.exp()
                    * (log_prob.mean().cpu() + self._target_entropy).detach()
                )
                self._optim_alpha.zero_grad()
                alpha_loss.backward()
                self._optim_alpha.step()

            # Update discriminator
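            # Train the skill discriminator phi to predict the active skill z
            # from the (optionally feature-restricted) observation.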
            self._optim.phi.zero_grad()
            phi_loss = F.cross_entropy(self._model.phi(phi_obs), z)
            phi_loss.backward()
            self._optim.phi.step()

            # Update target network
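            # Polyak averaging: target <- polyak * target + (1 - polyak) * online.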
            with th.no_grad():
                for tp, p in zip(
                    self._target.q.parameters(), self._model.q.parameters()
                ):
                    tp.data.lerp_(p.data, 1.0 - self._polyak)

        self.tbw_add_scalar('Loss/Policy', pi_loss.item())
        self.tbw_add_scalar('Loss/QValue', q_loss.item())
        self.tbw_add_scalar('Loss/Discriminator', phi_loss.item())
        self.tbw_add_scalar('Avg Reward', np.mean(rewards))
        self.tbw_add_scalar('Health/Entropy', log_prob.mean().item())
        if self._optim_alpha:
            self.tbw_add_scalar('Health/Alpha', self._log_alpha.exp().item())

        msg = log.debug
        if (self._n_updates * self._num_updates) % 50 == 0:
            msg = log.info
        msg(
            f'Sample {self._n_samples}, up {self._n_updates*self._num_updates}, '
            f'pi loss {pi_loss.item():+.03f}, q loss {q_loss.item():+.03f}, '
            f'phi loss {phi_loss.item():+.03f}, '
            f'avg reward {np.mean(rewards):+.03f}, '
            f'alpha {self._log_alpha.exp().item():.03f}'
        )