def _update()

in hucc/agents/hsd3.py

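One SAC-style optimization pass for the high-level policy: minibatches are drawn from the replay buffer, the twin Q-functions are regressed onto a soft Bellman backup (computed in q_target below), the discrete (pi_d) and continuous (pi_c) policy heads are updated against the frozen Q-network, the entropy temperatures alpha_c / alpha_d are optionally adapted, and the target Q-network is soft-updated by Polyak averaging. Training statistics are written to TensorBoard at the end.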

    def _update(self):
        def act_logp_c(obs, mask):
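            # Sample a continuous subgoal from pi_c via the reparameterization
            # trick and return it (scaled by the action factor) together with
            # its log-probability. With a mask, only the masked dimensions
            # contribute, and the log-prob is averaged over the active ones.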
            dist = self._model_pi_c(obs)
            action = dist.rsample()
            if mask is not None:
                log_prob = (dist.log_prob(action) * mask).sum(
                    dim=-1
                ) / mask.sum(dim=-1)
                action = action * mask * self._action_factor_c
            else:
                log_prob = dist.log_prob(action).sum(dim=-1)
                action = action * self._action_factor_c
            return action, log_prob

        def q_target(batch):
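            # Soft Bellman backup for the twin Q-functions: sample next-step
            # actions from the current policies, evaluate the *target*
            # Q-network, and include the entropy terms weighted by alpha_c and
            # alpha_d. Depending on self._expectation_d, the discrete option
            # is either enumerated exactly or sampled nds times (Monte Carlo).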
            reward = batch['reward']
            not_done = batch['not_done']
            obs_p = {k: batch[f'next_obs_{k}'] for k in self._obs_keys}
            alpha_c = self._log_alpha_c.detach().exp()
            alpha_d = self._log_alpha_d.detach().exp()
            bsz = reward.shape[0]
            d_batchin = self._d_batchin.narrow(0, 0, bsz * nd)
            c_batchmask = self._c_batchmask.narrow(0, 0, bsz * nd)

            dist_d = self._model_pi_d(obs_p)
            action_c, log_prob_c = act_logp_c(obs_p, self._action_c_mask)

            if self._expectation_d == -1 and nd > 1:
                # Present the observations interleaved so that the result can
                # easily be reshaped into B x A1 x A2.
                obs_pe = {}
                for k, v in obs_p.items():
                    obs_pe[k] = v.repeat_interleave(nd, dim=0)
                obs_pe[self._dkey] = d_batchin
                obs_pe[self._ckey] = action_c.view(d_batchin.shape[0], -1)
                q_t = th.min(self._target.hi.q(obs_pe), dim=-1).values

                q_t = q_t.view(bsz, nd)
                log_prob_c = log_prob_c.view(bsz, nd)
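                # Exact soft state value:
                # V(s') = sum_d pi_d(d|s') * (min_i Q_i(s',d,c_d) - alpha_c * log pi_c(c_d|s',d))
                #         + alpha_d * (H(pi_d) - H_uniform)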
                v_est = (dist_d.probs * (q_t - log_prob_c * alpha_c)).sum(
                    dim=-1
                ) + alpha_d * (dist_d.entropy() - self._uniform_entropy_d)
            else:
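                # Monte-Carlo estimate: draw nds discrete options (with
                # replacement), pick the matching continuous subgoal for each,
                # and average the soft value over the samples.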
                action_d = th.multinomial(dist_d.probs, nds, replacement=True)
                log_prob_d = dist_d.logits.gather(1, action_d)

                obs_pe = {}
                for k, v in obs_p.items():
                    if nds > 1:
                        obs_pe[k] = v.repeat_interleave(nds, dim=0)
                    else:
                        obs_pe[k] = v
                obs_pe[self._dkey] = self.action_hi_d_qinput(action_d).view(
                    -1, nd
                )

                action_c = dim_select(action_c, 1, action_d).view(
                    -1, action_c.shape[-1]
                )
                log_prob_c = log_prob_c.gather(1, action_d)
                obs_pe[self._ckey] = action_c

                q_t = th.min(self._target.hi.q(obs_pe), dim=-1).values.view(
                    -1, nds
                )
                log_prob_c = log_prob_c.view(-1, nds)
                if self._action_c_mask is not None:
                    ac = alpha_c.index_select(0, action_d.view(-1)).view_as(
                        log_prob_c
                    )
                else:
                    ac = alpha_c
                v_est = (q_t - ac * log_prob_c - alpha_d * log_prob_d).mean(
                    dim=-1
                )

            return reward + batch['gamma_exp'] * not_done * v_est

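        # Grab the device the model lives on from its first parameter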
        for p in self._model.parameters():
            mdevice = p.device
            break
        bsz = self._bsz
        nd = self._action_space_d.n
        nds = self._expectation_d
        if nd == 1:
            nds = 1
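        # Lazily build cached tensors for enumerating all nd discrete options
        # across a batch: one-hot encodings of every option (tiled over the
        # batch) and, if masking is used, the matching continuous-action masks.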
        if self._d_batchin is None:
            self._onehots = F.one_hot(th.arange(nd), nd).float().to(mdevice)
            self._d_batchin = self.action_hi_d_qinput(
                th.arange(bsz * nd).remainder(nd).to(mdevice)
            )
            if self._action_c_mask is not None:
                self._c_batchmask = self._action_c_mask.index_select(
                    1, th.arange(bsz * nd, device=mdevice).remainder(nd)
                ).squeeze(0)
            else:
                self._c_batchmask = None

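        # Without dyne-style updates, train only on transitions taken at the
        # start of a high-level step (obs_time == 0).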
        if not self._dyne_updates:
            assert (
                self._buffer.start == 0 or self._buffer.size == self._buffer.max
            )
            indices = th.where(
                self._buffer._b['obs_time'][: self._buffer.size] == 0
            )[0]
        gbatch = None
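        # For small batch sizes, fetch one large batch up front and slice it
        # per update instead of querying the buffer every iteration.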
        if self._dyne_updates and self._bsz < 512:
            gbatch = self._buffer.get_batch(
                self._bsz * self._num_updates,
                device=mdevice,
            )

        for i in range(self._num_updates):
            if self._dyne_updates:
                if gbatch is not None:
                    batch = {
                        k: v.narrow(0, i * self._bsz, self._bsz)
                        for k, v in gbatch.items()
                    }
                else:
                    batch = self._buffer.get_batch(
                        self._bsz,
                        device=mdevice,
                    )
            else:
                batch = self._buffer.get_batch_where(
                    self._bsz, indices=indices, device=mdevice
                )

            obs = {k: batch[f'obs_{k}'] for k in self._obs_keys}
            alpha_c = self._log_alpha_c.detach().exp()
            alpha_d = self._log_alpha_d.detach().exp()

            # Backup for Q-Function
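            # The bootstrap value comes from the target network and the current
            # policies; it must not receive gradients.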
            with th.no_grad():
                backup = q_target(batch)

            # Q-Function update
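            # Regress both Q heads onto the shared backup; the stored
            # high-level action is re-encoded (one-hot option plus continuous
            # subgoal) as Q-function input.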
            q_in = copy(obs)
            q_in[self._dkey] = self.action_hi_d_qinput(
                batch[f'action_hi_{self._dkey}']
            )
            q_in[self._ckey] = batch[f'action_hi_{self._ckey}']
            q = self._q_hi(q_in)
            q1 = q[:, 0]
            q2 = q[:, 1]
            q1_loss = F.mse_loss(q1, backup, reduction='none')
            q2_loss = F.mse_loss(q2, backup, reduction='none')
            q_loss = q1_loss.mean() + q2_loss.mean()
            self._optim.hi.q.zero_grad()
            q_loss.backward()
            if self._clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(
                    self._model.q.parameters(), self._clip_grad_norm
                )
            self._optim.hi.q.step()

            # Policy update
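            # Freeze the Q-network so the policy loss updates pi_c/pi_d only.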
            for param in self._model.hi.q.parameters():
                param.requires_grad_(False)

            # The policy receives no time input, and the Q-functions are
            # queried as if the step were 0 (i.e. as if we were about to take
            # a new high-level action).
            obs['time'] = obs['time'] * 0
            dist_d = self._model_pi_d(obs)
            action_c, log_prob_c = act_logp_c(obs, self._action_c_mask)

            if self._expectation_d == -1 and nd > 1:
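                # Exact expectation over all nd options: evaluate Q for every
                # (option, subgoal) pair and weight by pi_d's probabilities;
                # pi_d's entropy is encouraged via alpha_d.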
                obs_e = {}
                for k, v in obs.items():
                    obs_e[k] = v.repeat_interleave(nd, dim=0)
                obs_e[self._dkey] = self._d_batchin
                obs_e[self._ckey] = action_c.view(self._d_batchin.shape[0], -1)
                q = th.min(self._q_hi(obs_e), dim=-1).values

                q = q.view(bsz, nd)
                log_prob_c = log_prob_c.view(bsz, nd)
                pi_loss = (dist_d.probs * (alpha_c * log_prob_c - q)).sum(
                    dim=-1
                ) - alpha_d * (dist_d.entropy() - self._uniform_entropy_d)
            else:
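                # Monte-Carlo variant: average the soft policy loss over nds
                # sampled options and their matching subgoals.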
                action_d = th.multinomial(dist_d.probs, nds, replacement=True)
                log_prob_d = dist_d.logits.gather(1, action_d)

                obs_e = {}
                for k, v in obs.items():
                    if nds > 1:
                        obs_e[k] = v.repeat_interleave(nds, dim=0)
                    else:
                        obs_e[k] = v
                obs_e[self._dkey] = self.action_hi_d_qinput(action_d).view(
                    -1, nd
                )

                action_c = dim_select(action_c, 1, action_d).view(
                    -1, action_c.shape[-1]
                )
                log_prob_co = log_prob_c
                log_prob_c = log_prob_c.gather(1, action_d)
                obs_e[self._ckey] = action_c

                q = th.min(self._q_hi(obs_e), dim=-1).values.view(-1, nds)
                log_prob_c = log_prob_c.view(-1, nds)
                if self._action_c_mask is not None:
                    ac = alpha_c.index_select(0, action_d.view(-1)).view_as(
                        log_prob_c
                    )
                else:
                    ac = alpha_c
                pi_loss = (ac * log_prob_c + alpha_d * log_prob_d - q).mean(
                    dim=-1
                )

            pi_loss = pi_loss.mean()
            self._optim_pi_c.zero_grad()
            self._optim_pi_d.zero_grad()
            pi_loss.backward()
            if self._clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(
                    self._model_pi_c.parameters(), self._clip_grad_norm
                )
                nn.utils.clip_grad_norm_(
                    self._model_pi_d.parameters(), self._clip_grad_norm
                )
            self._optim_pi_c.step()
            self._optim_pi_d.step()

            for param in self._model.hi.q.parameters():
                param.requires_grad_(True)

            # Optional temperature update
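            # Standard SAC temperature tuning, done separately for the
            # continuous and discrete parts: alpha_c is driven by the
            # (per-option) continuous log-probs weighted by pi_d, alpha_d by
            # pi_d's entropy relative to its target.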
            if self._optim_alpha_c:
                if self._expectation_d != -1:
                    alpha_loss_c = (
                        -(
                            self._log_alpha_c.exp()
                            * dist_d.probs.detach()
                            * (
                                log_prob_co.detach() + self._target_entropy_c
                            ).view(bsz, nd)
                        )
                        .sum(dim=-1)
                        .mean()
                    )
                else:
                    alpha_loss_c = (
                        -(
                            self._log_alpha_c.exp()
                            * dist_d.probs.detach()
                            * (
                                log_prob_c.detach() + self._target_entropy_c
                            ).view(bsz, nd)
                        )
                        .sum(dim=-1)
                        .mean()
                    )
                self._optim_alpha_c.zero_grad()
                alpha_loss_c.backward()
                self._optim_alpha_c.step()
            if self._optim_alpha_d:
                alpha_loss_d = (
                    self._log_alpha_d.exp()
                    * (
                        dist_d.entropy().mean().cpu() - self._target_entropy_d
                    ).detach()
                )
                self._optim_alpha_d.zero_grad()
                alpha_loss_d.backward()
                self._optim_alpha_d.step()

            # Update target network
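            # Polyak averaging: tp <- polyak * tp + (1 - polyak) * p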
            with th.no_grad():
                for tp, p in zip(
                    self._target.hi.q.parameters(),
                    self._model.hi.q.parameters(),
                ):
                    tp.data.lerp_(p.data, 1.0 - self._polyak)

        # These are the stats for the last update
        self.tbw_add_scalar('LossHi/Policy', pi_loss.item())
        self.tbw_add_scalar('LossHi/QValue', q_loss.item())
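        # Residual variance: the fraction of the backup's variance left
        # unexplained by each Q head; values near 0 indicate a good critic fit.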
        with th.no_grad():
            bvar = backup.var()
            resvar1 = (backup - q1).var() / bvar
            resvar2 = (backup - q2).var() / bvar
        self.tbw_add_scalar('HealthHi/ResidualVariance1', resvar1.item())
        self.tbw_add_scalar('HealthHi/ResidualVariance2', resvar2.item())
        self.tbw_add_scalar('HealthHi/EntropyC', -log_prob_c.mean())
        self.tbw_add_scalar('HealthHi/EntropyD', dist_d.entropy().mean())
        if self._optim_alpha_c:
            self.tbw_add_scalar(
                'HealthHi/AlphaC', self._log_alpha_c.exp().mean().item()
            )
        if self._optim_alpha_d:
            self.tbw_add_scalar(
                'HealthHi/AlphaD', self._log_alpha_d.exp().item()
            )
        if self._n_updates % 10 == 1:
            self.tbw.add_histogram(
                'HealthHi/PiD',
                th.multinomial(
                    dist_d.probs,
                    int(np.ceil(1000 / self._bsz)),
                    replacement=True,
                ).view(-1),
                self._n_samples,
                bins=nd,
            )
        if self._n_updates % 100 == 1:
            self.tbw.add_scalars(
                'HealthHi/GradNorms',
                {
                    k: v.grad.norm().item()
                    for k, v in self._model.named_parameters()
                    if v.grad is not None
                },
                self._n_samples,
            )

        td_err1 = q1_loss.sqrt().mean().item()
        td_err2 = q2_loss.sqrt().mean().item()
        td_err = (td_err1 + td_err2) / 2
        self.tbw_add_scalar('HealthHi/AbsTDErrorTrain', td_err)
        self.tbw_add_scalar('HealthHi/AbsTDErrorTrain1', td_err1)
        self.tbw_add_scalar('HealthHi/AbsTDErrorTrain2', td_err2)

        avg_cr = th.cat(self._cur_rewards).mean().item()
        log_stats = [
            ('Sample', f'{self._n_samples}'),
            ('hi: up', f'{self._n_updates*self._num_updates}'),
            ('avg rew', f'{avg_cr:+0.3f}'),
            ('pi loss', f'{pi_loss.item():+.03f}'),
            ('q loss', f'{q_loss.item():+.03f}'),
            (
                'entropy',
                f'{-log_prob_c.mean().item():.03f},{dist_d.entropy().mean().item():.03f}',
            ),
            (
                'alpha',
                f'{self._log_alpha_c.mean().exp().item():.03f},{self._log_alpha_d.exp().item():.03f}',
            ),
        ]
        log.info(', '.join((f'{k} {v}' for k, v in log_stats)))
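For reference, the snippet below is a minimal, self-contained sketch (not part of hsd3.py) of the exact-expectation soft backup computed in q_target above. Random tensors stand in for the policy and Q-network outputs, and a plain gamma stands in for the per-transition batch['gamma_exp']; all names and values are illustrative.

import torch as th

bsz, nd = 4, 3                       # batch size, number of discrete options
gamma, alpha_c, alpha_d = 0.99, 0.1, 0.05
uniform_entropy_d = th.log(th.tensor(float(nd)))

probs_d = th.softmax(th.randn(bsz, nd), dim=-1)    # stand-in for pi_d(.|s')
q_min = th.randn(bsz, nd)                          # stand-in for min over twin Q heads
log_prob_c = th.randn(bsz, nd)                     # stand-in for log pi_c(c_d|s', d)
entropy_d = -(probs_d * probs_d.log()).sum(dim=-1)

# V(s') = E_d[min Q - alpha_c * log pi_c] + alpha_d * (H(pi_d) - H_uniform)
v_est = (probs_d * (q_min - alpha_c * log_prob_c)).sum(dim=-1) + alpha_d * (
    entropy_d - uniform_entropy_d
)

reward = th.randn(bsz)
not_done = th.ones(bsz)
backup = reward + gamma * not_done * v_est         # soft Bellman target, shape (bsz,)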