def action_hi()

in hucc/agents/hiro.py [0:0]
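
Computes the high-level (goal-setting) action for a batch of vectorized environments. Environments at the start of a high-level action interval receive a fresh goal, either sampled at random during warmup or drawn from the hi-level policy; all other environments keep their previous goal, re-expressed relative to the current state via the HIRO-style goal transition. The excerpt relies on module-level imports of hucc/agents/hiro.py, presumably torch as th, copy.copy as copy, and typing.Tuple.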


    def action_hi(self, env, obs) -> Tuple[th.Tensor, th.Tensor]:
        # Per-env step index within the current high-level action interval;
        # envs with step != 0 keep their previously issued goal.
        step = obs['time'].remainder(self._action_interval_hi).long().view(-1)
        keep_action = step != 0
        # Goal-space features of the current observation
        gstate_new = obs[self._gspace_key][:, self._goal_features]

        action = env.ctx.get('action_hi', None)
        gstate = env.ctx.get('gstate_hi', None)
        if action is None and keep_action.any().item():
            raise RuntimeError('Need to take first action at time=0')
        # Goal transition as in HIRO: h(s, g, s') = s + g - s' keeps the
        # absolute target state fixed while the agent moves towards it.
        if gstate is not None:
            action = gstate + action - gstate_new
        # At least one env needs a fresh high-level action (goal)
        if action is None or not keep_action.all().item():
            # Random goals from the hi-level action space during warmup
            if self._n_samples < self._randexp_samples and self.training:
                new_action = th.stack(
                    [
                        th.from_numpy(self._action_space_hi.sample())
                        for i in range(env.num_envs)
                    ]
                ).to(list(self._model.parameters())[0].device)
            else:
                # The hi-level policy does not see the raw time feature:
                # zero it out for dense updates, drop it otherwise.
                obs_wo_time = copy(obs)
                if self._dense_hi_updates:
                    obs_wo_time['time'] = th.zeros_like(obs_wo_time['time'])
                else:
                    del obs_wo_time['time']
                dist = self._model.hi.pi(obs_wo_time)
                assert (
                    dist.has_rsample
                ), 'rsample() required for hi-level policy distribution'
                # Stochastic sample during training, distribution mean for eval
                if self.training:
                    new_action = dist.sample()
                else:
                    new_action = dist.mean
                # Scale the policy output to the hi-level action (goal) space
                new_action = self.scale_action_hi(new_action)
            if action is None:
                action = new_action
            else:
                # Per-env merge: keep the transitioned goal where the interval
                # has not elapsed yet, take the freshly sampled goal elsewhere.
                m = keep_action.unsqueeze(1)
                action = m * action + th.logical_not(m) * new_action
        # Cache the goal and the goal-space state it is relative to, for the
        # next call's goal transition
        env.ctx['action_hi'] = action
        env.ctx['gstate_hi'] = gstate_new

        # The second return value marks envs that received a new hi-level action
        return action, th.logical_not(keep_action)
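
The goal-transition line above follows the relabeling used in HIRO, h(s, g, s') = s + g - s': the relative goal is rewritten so that the absolute target s + g stays fixed while the low-level agent moves. Below is a minimal standalone sketch of that transition and the per-environment masking, using made-up tensors rather than code from the repository:

    import torch as th

    # Hypothetical batch of 3 envs with 2 goal features each
    gstate     = th.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])  # state when the goal was issued
    action     = th.tensor([[3.0, 0.0], [1.0, 2.0], [0.0, 1.0]])  # relative goals
    gstate_new = th.tensor([[1.0, 0.0], [1.0, 1.5], [2.0, 0.0]])  # state after some low-level steps

    # Goal transition: the absolute target gstate + action is unchanged,
    # only its expression relative to the new state changes.
    action = gstate + action - gstate_new
    # -> [[2.0, 0.0], [1.0, 1.5], [0.0, 1.0]]

    # Envs whose interval just elapsed (keep_action == False) get a fresh
    # goal; the others keep the transitioned one.
    keep_action = th.tensor([True, True, False])
    new_action = th.tensor([[9.0, 9.0], [9.0, 9.0], [9.0, 9.0]])
    m = keep_action.unsqueeze(1)
    action = m * action + th.logical_not(m) * new_action
    # -> [[2.0, 0.0], [1.0, 1.5], [9.0, 9.0]]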