def _update_hi()

in hucc/agents/hiro.py [0:0]
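
SAC-style update of the high-level (goal-setting) policy and its Q-functions. High-level transitions are reconstructed from trajectory windows sampled from the replay buffer; goals can optionally be relabeled via _relabel_goal, and with dense updates the Q-functions are additionally trained on intermediate steps of each transition.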


    def _update_hi(self):
        model = self._model.hi
        target = self._target.hi
        optim = self._optim.hi

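        # Lazily cache the goal-space bounds as tensors on the buffer device
        # (done once, on the first update).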
        if self._gspace_min_th is None:
            self._gspace_min_th = th.tensor(
                self._gspace_min, device=self._buffer.device
            )
            self._gspace_max_th = th.tensor(
                self._gspace_max, device=self._buffer.device
            )
        obs_keys = self._obs_keys

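        # Sample a (reparameterized) action from the current high-level policy
        # and return the scaled action together with its log-probability.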
        def act_logp(obs):
            dist = model.pi(obs)
            action = dist.rsample()
            log_prob = dist.log_prob(action).sum(dim=-1)
            action = self.scale_action_hi(action)
            return action, log_prob

        bsz = self._bsz
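        # With dense updates, perform the full number of updates; otherwise the
        # high-level policy is updated once per high-level action interval.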
        if self._dense_hi_updates:
            n = self._num_updates
        else:
            n = int(np.ceil(self._num_updates / self._action_interval_hi))
        it = 0
        while it < n:
            c = self._action_interval_hi
            k = c * 2 - 1
            batch = self._buffer.get_trajs(bsz, k)

            # Grab transitions from step 0 to self._action_interval_hi - 1 or
            # until a terminal state
            step = batch['step']
            acc = th.zeros_like(step)
            acc[:, 0] = step[:, 0] == 0
            for i in range(1, k):
                acc[:, i] = acc[:, i - 1] + (step[:, i] == 0)
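            # acc counts how many high-level actions have started at each
            # position; acc == 1 selects the first high-level segment in the
            # window.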
            mask = acc == 1
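            # obs_idx: position where the high-level action was taken;
            # obs_p_idx: end of the transition (terminal state or last step of
            # the action interval).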
            obs_idx = th.where(th.logical_and(step == 0, mask))[1]
            m_terminal = th.logical_and(batch['terminal'], mask)
            m_last_step = th.logical_and(step == c - 1, mask)
            obs_p_idx = th.where(th.logical_or(m_terminal, m_last_step))[1]
            if obs_idx.shape != obs_p_idx.shape:
                # We might run into this condition when continuing from a
                # checkpoint. Since environments will be reset, the window of
                # c*2 - 1 steps may not contain a full high-level transition.
                # This is quite a hotfix; another solution would be to use some
                # staging logic for transitions, but since this should happen
                # rarely, just resample another batch.
                continue

            not_done = th.logical_not(
                dim_select(batch['terminal'], 1, obs_p_idx)
            )

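            # Optionally relabel the goal (cf. HIRO's off-policy goal
            # correction); otherwise use the goal that was actually issued.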
            if self._relabel_goals:
                action_hi = self._relabel_goal(batch, mask, obs_idx, obs_p_idx)
            else:
                action_hi = dim_select(batch['action_hi'], 1, obs_idx)

            if self._dense_hi_updates:
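                # Sample a random offset inside the transition so that the
                # Q-functions also see intermediate steps.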
                off = th.randint(
                    c, obs_idx.shape, device=obs_idx.device
                ).remainder(obs_p_idx - obs_idx + 1)
                obs_idx_off = obs_idx + off
                obs = {
                    k: dim_select(batch[f'obs_{k}'], 1, obs_idx_off)
                    for k in obs_keys
                }
                obs['time'] = dim_select(batch['step'], 1, obs_idx_off)
                obs_p = {
                    k: dim_select(batch[f'next_obs_{k}'], 1, obs_p_idx)
                    for k in obs_keys
                }
                obs_p['time'] = th.zeros_like(obs['time'])

                # Discounted sum of low-level rewards from the offset position
                # to the end of the high-level transition.
                reward = dim_select(batch['reward'], 1, obs_idx_off)
                for i in range(1, c):
                    obs_idx_off_i = obs_idx_off + i
                    reward += (
                        self._gamma ** i
                        * dim_select(
                            batch['reward'], 1, obs_idx_off_i.min(obs_p_idx)
                        )
                        * (obs_idx_off_i <= obs_p_idx)
                    )
                # Per-sample discount spanning the remaining steps of the
                # transition.
                gamma = th.zeros_like(reward) + self._gamma
                gamma.pow_(obs_p_idx - obs_idx_off + 1)
            else:
                obs = {
                    k: dim_select(batch[f'obs_{k}'], 1, obs_idx)
                    for k in obs_keys
                }
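                # High-level reward: average of the low-level rewards over the
                # masked segment.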
                reward = (batch['reward'] * mask).sum(dim=1) / mask.sum(dim=1)
                gamma = self._gamma
                obs_p = {
                    k: dim_select(batch[f'next_obs_{k}'], 1, obs_p_idx)
                    for k in obs_keys
                }

            # Backup for the Q-function (soft Bellman target)
            with th.no_grad():
                a_p, log_prob_p = act_logp(obs_p)
                q_in = dict(action=a_p, **obs_p)
                q_tgt = th.min(target.q(q_in), dim=-1).values
                backup = reward + gamma * not_done * (
                    q_tgt - self._log_alpha_hi.detach().exp() * log_prob_p
                )

            # Q-Function update
            q_in = dict(action=action_hi, **obs)
            q = model.q(q_in)
            q1 = q[:, 0]
            q2 = q[:, 1]
            q1_loss = F.mse_loss(q1, backup)
            q2_loss = F.mse_loss(q2, backup)
            q_loss = q1_loss + q2_loss
            optim.q.zero_grad()
            q_loss.backward()
            if self._clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(
                    model.q.parameters(), self._clip_grad_norm
                )
            optim.q.step()

            # Policy update
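            # Freeze the Q-networks so the policy gradient step does not
            # update them.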
            for param in model.q.parameters():
                param.requires_grad_(False)

            if self._dense_hi_updates:
                # No time input for the policy; the Q-functions are queried as
                # if the step were 0, i.e. as if a high-level action were being
                # taken.
                obs['time'] = obs['time'] * 0
            a, log_prob = act_logp(obs)
            q_in = dict(action=a, **obs)
            q = th.min(model.q(q_in), dim=-1).values
            pi_loss = (self._log_alpha_hi.detach().exp() * log_prob - q).mean()
            optim.pi.zero_grad()
            pi_loss.backward()
            if self._clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(
                    model.pi.parameters(), self._clip_grad_norm
                )
            optim.pi.step()

            for param in model.q.parameters():
                param.requires_grad_(True)

            # Optional temperature update
            if self._optim_alpha_hi:
                alpha_loss = -(
                    self._log_alpha_hi.exp()
                    * (log_prob.mean().cpu() + self._target_entropy_hi).detach()
                )
                self._optim_alpha_hi.zero_grad()
                alpha_loss.backward()
                self._optim_alpha_hi.step()

            # Update target network (Polyak averaging)
            with th.no_grad():
                for tp, p in zip(target.q.parameters(), model.q.parameters()):
                    tp.data.lerp_(p.data, 1.0 - self._polyak)

            it += 1

        # These are the stats for the last update
        self.tbw_add_scalar('LossHi/Policy', pi_loss.item())
        self.tbw_add_scalar('LossHi/QValue', q_loss.item())
        self.tbw_add_scalar('HealthHi/Entropy', -log_prob.mean())
        if self._optim_alpha_hi:
            self.tbw_add_scalar(
                'HealthHi/Alpha', self._log_alpha_hi.exp().item()
            )
        self.tbw.add_scalars(
            'HealthHi/GradNorms',
            {
                k: v.grad.norm().item()
                for k, v in self._model.named_parameters()
                if v.grad is not None
            },
            self.n_samples,
        )

        avg_cr = th.cat(self._cur_rewards).mean().item()
        log.info(
            f'Sample {self._n_samples} hi: up {self._n_updates*n}, '
            f'avg cur reward {avg_cr:+0.3f}, '
            f'pi loss {pi_loss.item():+.03f}, '
            f'q loss {q_loss.item():+.03f}, '
            f'entropy {-log_prob.mean().item():+.03f}, '
            f'alpha {self._log_alpha_hi.exp().item():.03f}'
        )
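
For intuition, here is a minimal, self-contained sketch of the index logic above: recovering the start (obs_idx) and end (obs_p_idx) of the first high-level transition in a sampled window from the per-step counters. The toy step and terminal tensors are made up, and th.cumsum stands in for the explicit accumulation loop; this is an illustration, not code from hiro.py.

    import torch as th

    c = 4                                           # high-level action interval
    # Window of length k = 2*c - 1 = 7; a new high-level action starts where
    # step == 0 (here at index 2).
    step = th.tensor([[2, 3, 0, 1, 2, 3, 0]])
    terminal = th.zeros_like(step, dtype=th.bool)   # no terminal states here

    # Count how many high-level actions have started at each position and keep
    # only the first segment (acc == 1).
    acc = th.cumsum((step == 0).long(), dim=1)      # [[0, 0, 1, 1, 1, 1, 2]]
    mask = acc == 1                                 # [[F, F, T, T, T, T, F]]

    obs_idx = th.where((step == 0) & mask)[1]       # tensor([2]): action taken here
    end = ((step == c - 1) | terminal) & mask
    obs_p_idx = th.where(end)[1]                    # tensor([5]): last step of segment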