# hucc/agents/hiro.py
def action_hi(self, env, obs) -> Tuple[th.Tensor, th.Tensor]:
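    # Selects the current high-level action (a goal in the goal space) for
    # each environment. A new goal is sampled every `_action_interval_hi`
    # steps; in between, the goal cached in `env.ctx` is reused after the
    # goal transition below. The excerpt assumes module-level
    # `import torch as th`, `from copy import copy` and
    # `from typing import Tuple`.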
    step = obs['time'].remainder(self._action_interval_hi).long().view(-1)
    keep_action = step != 0
    gstate_new = obs[self._gspace_key][:, self._goal_features]
    action = env.ctx.get('action_hi', None)
    gstate = env.ctx.get('gstate_hi', None)
    if action is None and keep_action.any().item():
        raise RuntimeError('Need to take first action at time=0')
    # Goal transition: re-express the cached goal relative to the new
    # goal-space state so that the absolute target stays fixed,
    # i.e. g <- s_prev + g - s_new.
    if gstate is not None:
        action = gstate + action - gstate_new
    if action is None or not keep_action.all().item():
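        # Random exploration phase: while fewer than `_randexp_samples`
        # samples have been collected during training, draw uniform goals
        # from the high-level action space instead of querying the policy.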
        if self._n_samples < self._randexp_samples and self.training:
            new_action = th.stack(
                [
                    th.from_numpy(self._action_space_hi.sample())
                    for i in range(env.num_envs)
                ]
            ).to(list(self._model.parameters())[0].device)
        else:
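            # The high-level policy does not condition on the raw time
            # step: 'time' is zeroed when dense hi-level updates are
            # enabled (presumably to keep the observation keys intact)
            # and removed from the observation otherwise.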
            obs_wo_time = copy(obs)
            if self._dense_hi_updates:
                obs_wo_time['time'] = th.zeros_like(obs_wo_time['time'])
            else:
                del obs_wo_time['time']
            dist = self._model.hi.pi(obs_wo_time)
            assert (
                dist.has_rsample
            ), 'rsample() required for hi-level policy distribution'
            if self.training:
                new_action = dist.sample()
            else:
                new_action = dist.mean
            new_action = self.scale_action_hi(new_action)
        if action is None:
            action = new_action
        else:
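            # Per-environment blend: keep the cached goal where
            # keep_action is set and take the freshly sampled goal
            # elsewhere.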
            m = keep_action.unsqueeze(1)
            action = m * action + th.logical_not(m) * new_action

    env.ctx['action_hi'] = action
    env.ctx['gstate_hi'] = gstate_new
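    # The second return value flags the environments in which a new
    # high-level action was taken at this step.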
    return action, th.logical_not(keep_action)