Excerpt: `action()` method from hucc/agents/hsd3.py


    def action(self, env, obs) -> Tuple[th.Tensor, Any]:
        """Compute the low-level action for the current observation batch.

        The high-level action is resampled every ``self._action_interval``
        steps (per environment, based on ``obs['time']``); in between it is
        carried over from ``env.ctx``, with its translated subgoal updated
        against the moving goal-space observation.

        Args:
            env: Vectorized environment. ``env.ctx`` is used as a
                per-rollout cache for ``'action_hi'``, ``'tr_action_hi'``
                and the previous goal-space observation ``'gs_obs'``.
            obs: Observation dict; ``'time'`` and ``'observation'`` are
                read directly, the rest is forwarded to the interfaces.

        Returns:
            Tuple of the low-level action tensor and an extra-info dict
            (augmented with visualization entries when not in training).
        """
        # Envs whose step counter is not at an action-interval boundary
        # keep their current high-level action.
        step = obs['time'].remainder(self._action_interval).long().view(-1)
        keep_action_hi = step != 0

        def retain(x, y, mask):
            # Element-wise select: keep x where mask is set, else take y.
            return mask * x + th.logical_not(mask) * y

        prev_gs_obs = env.ctx.get('gs_obs', None)
        action_hi = env.ctx.get('action_hi', None)
        # Shallow copy: the high-level policy may see a modified view of
        # the observation without mutating the caller's dict.
        obs_hi = copy(obs)

        tr_action_hi = env.ctx.get('tr_action_hi', None)
        if action_hi is None or not keep_action_hi.all().item():
            # At least one env needs a fresh high-level action.
            with th.no_grad():
                new_action_hi = self.action_hi(env, obs_hi, action_hi)
            # Map the raw high-level action (discrete task choice under
            # self._dkey, continuous subgoal under self._ckey) to
            # goal-space targets ('task', 'desired_goal').
            tr_new_action_hi = self._iface.translate(
                self._iface.gs_obs(obs),
                new_action_hi[self._dkey],
                new_action_hi[self._ckey],
            )
            if action_hi is None:
                # First step of the rollout: adopt the new action wholesale.
                action_hi = deepcopy(new_action_hi)
                tr_action_hi = deepcopy(tr_new_action_hi)
            else:
                c = self._ckey
                d = self._dkey
                # Replace raw actions
                action_hi[d] = retain(
                    action_hi[d], new_action_hi[d], keep_action_hi
                )
                action_hi[c] = retain(
                    action_hi[c],
                    new_action_hi[c],
                    keep_action_hi.unsqueeze(1).expand_as(action_hi[c]),
                )
                # Replace translated actions
                tr_action_hi['task'] = retain(
                    tr_action_hi['task'],
                    tr_new_action_hi['task'],
                    keep_action_hi.unsqueeze(1).expand_as(tr_action_hi['task']),
                )
                # Update the retained subgoal against the change in
                # goal-space observation (presumably re-anchoring it to the
                # agent's movement — see update_bp_subgoal) BEFORE
                # overwriting the entries that were resampled this step.
                tr_action_hi['desired_goal'] = self._iface.update_bp_subgoal(
                    prev_gs_obs, self._iface.gs_obs(obs), tr_action_hi
                )
                tr_action_hi['desired_goal'] = retain(
                    tr_action_hi['desired_goal'],
                    tr_new_action_hi['desired_goal'],
                    keep_action_hi.unsqueeze(1).expand_as(
                        tr_action_hi['desired_goal']
                    ),
                )
        else:
            # Every env keeps its high-level action; only track the
            # goal-space observation.
            tr_action_hi['desired_goal'] = self._iface.update_bp_subgoal(
                prev_gs_obs, self._iface.gs_obs(obs), tr_action_hi
            )

        env.ctx['action_hi'] = action_hi
        env.ctx['tr_action_hi'] = tr_action_hi
        # Cache the current goal-space observation for the next call;
        # reuse the existing buffer in-place when possible.
        if 'gs_obs' not in env.ctx:
            env.ctx['gs_obs'] = self._iface.gs_obs(obs).clone()
        else:
            env.ctx['gs_obs'].copy_(self._iface.gs_obs(obs))

        # Low-level action is conditioned on the (translated) high-level
        # action; no gradients flow through action selection here.
        with th.no_grad():
            obs_lo = self._iface.observation_lo(
                obs['observation'], tr_action_hi
            )
            action_lo = self.action_lo(env, obs_lo)

        if self.training:
            return action_lo, {
                'action_hi': action_hi,
                'tr_action_hi': tr_action_hi,
                'obs_hi': obs_hi,
            }

        # Additional visualization info for evals
        subsets = [
            self._iface.subsets[i.item()] for i in action_hi['task'].cpu()
        ]
        sg_cpu = action_hi['subgoal'].cpu().numpy()
        sgd_cpu = tr_action_hi['desired_goal'].cpu().numpy()
        subgoals = []
        subgoals_d = []
        for i in range(env.num_envs):
            # Each subset is a comma-separated list of feature names; the
            # raw subgoal covers only that many leading dimensions, while
            # the translated goal is indexed via the task feature map.
            n = len(subsets[i].split(','))
            subgoals.append(sg_cpu[i, :n])
            feats = [self._iface.task_map[f] for f in subsets[i].split(',')]
            subgoals_d.append(sgd_cpu[i, feats])
        return action_lo, {
            'action_hi': action_hi,
            'tr_action_hi': tr_action_hi,
            'obs_hi': obs_hi,
            'st': subsets,
            'sg': subgoals,
            'sgd': subgoals_d,
            'viz': ['st', 'sg', 'sgd'],
        }