def _update()

in hucc/agents/hsdb.py


    def _update(self):
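        # Helper: sample a continuous action from the current high-level policy
        # and compute its log-probability. With a mask, per-dimension log-probs
        # are averaged over the active dimensions and masked-out action entries
        # are zeroed; without one, log-probs are summed over all dimensions.
        # Actions are scaled by the continuous action factor in both cases.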
        def act_logp_c(obs, mask):
            dist = self._model_pi_c(obs)
            action = dist.rsample()
            if mask is not None:
                log_prob = (dist.log_prob(action) * mask).sum(
                    dim=-1
                ) / mask.sum(dim=-1)
                action = action * mask * self._action_factor_c
            else:
                log_prob = dist.log_prob(action).sum(dim=-1)
                action = action * self._action_factor_c
            return action, log_prob

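        # Helper: SAC-style backup for the high-level Q-functions. For the next
        # observation, sample a continuous action per discrete option, select
        # the entry matching the stored discrete action, and bootstrap with the
        # minimum of the twin target Q-values minus the entropy term:
        #   backup = reward + gamma_exp * not_done * (min(Q_target) - alpha_c * log_prob_c)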
        def q_target(batch):
            reward = batch['reward']
            not_done = batch['not_done']
            obs_p = {k: batch[f'next_obs_{k}'] for k in self._obs_keys}
            alpha_c = self._log_alpha_c.detach().exp()

            action_c, log_prob_c = act_logp_c(obs_p, self._action_c_mask)

            action_d = batch[f'action_hi_{self._dkey}']
            obs_p[self._dkey] = self.action_hi_d_qinput(action_d).view(-1, nd)

            action_c = dim_select(action_c, 1, action_d).view(
                -1, action_c.shape[-1]
            )
            log_prob_c = dim_select(log_prob_c, 1, action_d)
            obs_p[self._ckey] = action_c

            q_t = th.min(self._target.hi.q(obs_p), dim=-1).values
            if self._action_c_mask is not None:
                ac = alpha_c.index_select(0, action_d)
            else:
                ac = alpha_c
            v_est = q_t - ac * log_prob_c

            return reward + batch['gamma_exp'] * not_done * v_est

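        # Infer the device to sample batches on from the first model parameter.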
        for p in self._model.parameters():
            mdevice = p.device
            break
        nd = self._action_space_d.n
        if self._onehots is None:
            self._onehots = F.one_hot(th.arange(nd), nd).float().to(mdevice)

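        # Without dyne-style updates, sample only from transitions with
        # obs_time == 0 (i.e. the start of a high-level step); the assert
        # guards against a partially wrapped-around buffer, for which these
        # flat indices would not be meaningful.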
        if not self._dyne_updates:
            assert (
                self._buffer.start == 0 or self._buffer.size == self._buffer.max
            )
            indices = th.where(
                self._buffer._b['obs_time'][: self._buffer.size] == 0
            )[0]
        gbatch = None
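        # For small batch sizes, prefetch one large batch for all updates and
        # slice it per iteration to amortize the sampling overhead.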
        if self._dyne_updates and self._bsz < 512:
            gbatch = self._buffer.get_batch(
                self._bsz * self._num_updates,
                device=mdevice,
            )

        for i in range(self._num_updates):
            if self._dyne_updates:
                if gbatch is not None:
                    batch = {
                        k: v.narrow(0, i * self._bsz, self._bsz)
                        for k, v in gbatch.items()
                    }
                else:
                    batch = self._buffer.get_batch(
                        self._bsz,
                        device=mdevice,
                    )
            else:
                batch = self._buffer.get_batch_where(
                    self._bsz, indices=indices, device=mdevice
                )

            obs = {k: batch[f'obs_{k}'] for k in self._obs_keys}
            alpha_c = self._log_alpha_c.detach().exp()

            # Backup for Q-Function
            with th.no_grad():
                backup = q_target(batch)

            # Q-Function update
            q_in = copy(obs)
            q_in[self._dkey] = self.action_hi_d_qinput(
                batch[f'action_hi_{self._dkey}']
            )
            q_in[self._ckey] = batch[f'action_hi_{self._ckey}']
            q = self._q_hi(q_in)
            q1 = q[:, 0]
            q2 = q[:, 1]
            q1_loss = F.mse_loss(q1, backup, reduction='none')
            q2_loss = F.mse_loss(q2, backup, reduction='none')
            q_loss = q1_loss.mean() + q2_loss.mean()
            self._optim.hi.q.zero_grad()
            q_loss.backward()
            if self._clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(
                    self._model.hi.q.parameters(), self._clip_grad_norm
                )
            self._optim.hi.q.step()

            # Policy update
            for param in self._model.hi.q.parameters():
                param.requires_grad_(False)

            # No time input for the policy, and the Q-functions are queried as
            # if the step were 0 (i.e. an action would be taken now).
            obs['time'] = obs['time'] * 0
            action_c, log_prob_c = act_logp_c(obs, self._action_c_mask)
            action_d = batch[f'action_hi_{self._dkey}']
            obs[self._dkey] = self.action_hi_d_qinput(action_d).view(-1, nd)

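            # Pick, for every sample, the continuous action and log-prob that
            # belong to the chosen discrete option (selection along dim 1).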
            action_c = dim_select(action_c, 1, action_d).view(
                -1, action_c.shape[-1]
            )
            log_prob_c = dim_select(log_prob_c, 1, action_d)
            obs[self._ckey] = action_c

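            # Policy loss: alpha * log_prob - min(Q); minimizing it maximizes
            # the entropy-regularized Q-value. With an action mask, a separate
            # temperature is looked up per discrete option.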
            q = th.min(self._q_hi(obs), dim=-1).values
            if self._action_c_mask is not None:
                ac = alpha_c.index_select(0, action_d)
            else:
                ac = alpha_c
            pi_loss = ac * log_prob_c - q

            pi_loss = pi_loss.mean()
            self._optim_pi_c.zero_grad()
            pi_loss.backward()
            if self._clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(
                    self._model_pi_c.parameters(), self._clip_grad_norm
                )
            self._optim_pi_c.step()

            for param in self._model.hi.q.parameters():
                param.requires_grad_(True)

            # Optional temperature update
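            # (dual gradient step on the per-option log-alpha: the temperature
            # grows when the policy entropy falls below the target and shrinks
            # otherwise)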
            if self._optim_alpha_c:
                log_alpha = self._log_alpha_c.index_select(0, action_d)
                alpha_loss_c = -(
                    log_alpha.exp()
                    * (log_prob_c.view(-1).detach() + self._target_entropy_c)
                ).mean()
                self._optim_alpha_c.zero_grad()
                alpha_loss_c.backward()
                self._optim_alpha_c.step()

            # Update target network
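            # (Polyak averaging: target <- polyak * target + (1 - polyak) * online)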
            with th.no_grad():
                for tp, p in zip(
                    self._target.hi.q.parameters(),
                    self._model.hi.q.parameters(),
                ):
                    tp.data.lerp_(p.data, 1.0 - self._polyak)

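        # Discrete option distribution of the bandit; only used for logging.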
        dist_d = self._bandit_d.dist()

        # These are the stats for the last update
        self.tbw_add_scalar('LossHi/Policy', pi_loss.item())
        self.tbw_add_scalar('LossHi/QValue', q_loss.item())
        with th.no_grad():
            bvar = backup.var()
            resvar1 = (backup - q1).var() / bvar
            resvar2 = (backup - q2).var() / bvar
        self.tbw_add_scalar('HealthHi/ResidualVariance1', resvar1.item())
        self.tbw_add_scalar('HealthHi/ResidualVariance2', resvar2.item())
        self.tbw_add_scalar('HealthHi/EntropyC', -log_prob_c.mean())
        self.tbw_add_scalar('HealthHi/EntropyD', dist_d.entropy())
        if self._optim_alpha_c:
            self.tbw_add_scalar(
                'HealthHi/AlphaC', self._log_alpha_c.exp().mean().item()
            )
        if self._n_updates % 10 == 1:
            self.tbw.add_histogram(
                'HealthHi/PiD',
                th.multinomial(
                    dist_d.probs,
                    int(np.ceil(1000 / self._bsz)),
                    replacement=True,
                ).view(-1),
                self._n_samples,
                bins=nd,
            )
        if self._n_updates % 100 == 1:
            self.tbw.add_scalars(
                'HealthHi/GradNorms',
                {
                    k: v.grad.norm().item()
                    for k, v in self._model.named_parameters()
                    if v.grad is not None
                },
                self._n_samples,
            )

        td_err1 = q1_loss.sqrt().mean().item()
        td_err2 = q2_loss.sqrt().mean().item()
        td_err = (td_err1 + td_err2) / 2
        self.tbw_add_scalar('HealthHi/AbsTDErrorTrain', td_err)
        self.tbw_add_scalar('HealthHi/AbsTDErrorTrain1', td_err1)
        self.tbw_add_scalar('HealthHi/AbsTDErrorTrain2', td_err2)

        avg_cr = th.cat(self._cur_rewards).mean().item()
        log_stats = [
            ('Sample', f'{self._n_samples}'),
            ('hi: up', f'{self._n_updates*self._num_updates}'),
            ('avg rew', f'{avg_cr:+0.3f}'),
            ('pi loss', f'{pi_loss.item():+.03f}'),
            ('q loss', f'{q_loss.item():+.03f}'),
            (
                'entropy',
                f'{-log_prob_c.mean().item():.03f},{dist_d.entropy().item():.03f}',
            ),
            ('alpha', f'{self._log_alpha_c.mean().exp().item():.03f}'),
        ]
        log.info(', '.join(f'{k} {v}' for k, v in log_stats))
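
For reference, the backup computed by q_target and the target-network update at
the end of the loop follow the standard SAC recipe: clipped double Q-learning
with an entropy bonus, plus Polyak-averaged target networks. The snippet below
is a minimal, self-contained sketch of that pattern, not code from this
repository; sac_backup, polyak_update, and their arguments are hypothetical
stand-ins.

    import torch

    def sac_backup(reward, not_done, next_q1, next_q2, log_prob, alpha, gamma):
        # Clipped double-Q target with an entropy bonus, as in q_target above.
        q_min = torch.min(next_q1, next_q2)
        v_est = q_min - alpha * log_prob
        return reward + gamma * not_done * v_est

    def polyak_update(target_net, online_net, polyak=0.995):
        # Soft target update: target <- polyak * target + (1 - polyak) * online.
        with torch.no_grad():
            for tp, p in zip(target_net.parameters(), online_net.parameters()):
                tp.data.lerp_(p.data, 1.0 - polyak)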