in hucc/agents/hiro.py [0:0]
def _update_hi(self):
model = self._model.hi
target = self._target.hi
optim = self._optim.hi
if self._gspace_min_th is None:
self._gspace_min_th = th.tensor(
self._gspace_min, device=self._buffer.device
)
self._gspace_max_th = th.tensor(
self._gspace_max, device=self._buffer.device
)
obs_keys = self._obs_keys
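# Helper: draw a reparameterized action from the current high-level
# policy and return it together with its log-probability; the raw sample
# is mapped into the goal space via scale_action_hi. Used for the
# SAC-style critic backup, policy loss, and temperature update below.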
def act_logp(obs):
dist = model.pi(obs)
action = dist.rsample()
log_prob = dist.log_prob(action).sum(dim=-1)
action = self.scale_action_hi(action)
return action, log_prob
bsz = self._bsz
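# With dense updates we perform the full self._num_updates updates per
# call; otherwise the count is scaled down by the high-level action
# interval, i.e. roughly one update per high-level action taken.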
if self._dense_hi_updates:
n = self._num_updates
else:
n = int(np.ceil(self._num_updates / self._action_interval_hi))
it = 0
while it < n:
c = self._action_interval_hi
k = c * 2 - 1
batch = self._buffer.get_trajs(bsz, k)
# Grab transitions from step 0 to self._action_interval_hi - 1 or
# until a terminal state
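# acc counts how many segment starts (step == 0) have been seen so far,
# so mask == 1 selects the first high-level segment in each snippet.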
step = batch['step']
acc = th.zeros_like(step)
acc[:, 0] = step[:, 0] == 0
for i in range(1, k):
acc[:, i] = acc[:, i - 1] + (step[:, i] == 0)
mask = acc == 1
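# obs_idx is the index of the high-level action (step == 0) within the
# masked segment; obs_p_idx is the segment's end, i.e. a terminal state
# or the last step of the action interval.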
obs_idx = th.where(th.logical_and(step == 0, mask))[1]
m_terminal = th.logical_and(batch['terminal'], mask)
m_last_step = th.logical_and(step == c - 1, mask)
obs_p_idx = th.where(th.logical_or(m_terminal, m_last_step))[1]
if obs_idx.shape != obs_p_idx.shape:
# We might run into this condition when continuing from a checkpoint:
# since environments are reset, a window of c*2 - 1 steps may not
# contain a full high-level transition. This is a bit of a hotfix; a
# cleaner solution would be staging logic for transitions, but since
# this should happen rarely, simply redraw a batch and retry. The
# 'continue' skips the 'it += 1' at the end of the loop, so the total
# number of updates is unaffected.
continue
not_done = th.logical_not(
dim_select(batch['terminal'], 1, obs_p_idx)
)
if self._relabel_goals:
action_hi = self._relabel_goal(batch, mask, obs_idx, obs_p_idx)
else:
action_hi = dim_select(batch['action_hi'], 1, obs_idx)
if self._dense_hi_updates:
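# For dense updates, sample a random offset within the segment so that
# the transition can start at any step between the high-level action
# (obs_idx) and the segment end (obs_p_idx).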
off = th.randint(
c, obs_idx.shape, device=obs_idx.device
).remainder(obs_p_idx - obs_idx + 1)
obs_idx_off = obs_idx + off
obs = {
k: dim_select(batch[f'obs_{k}'], 1, obs_idx_off)
for k in obs_keys
}
obs['time'] = dim_select(batch['step'], 1, obs_idx_off)
obs_p = {
k: dim_select(batch[f'next_obs_{k}'], 1, obs_p_idx)
for k in obs_keys
}
obs_p['time'] = th.zeros_like(obs['time'])
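# n-step return: accumulate the discounted environment reward from the
# sampled offset to the segment end (steps past obs_p_idx are masked
# out); the bootstrap discount below is gamma raised to the number of
# steps covered.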
reward = dim_select(batch['reward'], 1, obs_idx_off)
for i in range(1, c):
obs_idx_off_i = obs_idx_off + i
reward += (
self._gamma ** i
* dim_select(
batch['reward'], 1, (obs_idx_off_i).min(obs_p_idx)
)
* (obs_idx_off_i <= obs_p_idx)
)
gamma = th.zeros_like(reward) + self._gamma
gamma.pow_(obs_p_idx - obs_idx_off + 1)
else:
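# Without dense updates the whole segment is treated as a single
# transition: the observation at the high-level action, the mean
# environment reward over the segment, and a one-step discount.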
obs = {
k: dim_select(batch[f'obs_{k}'], 1, obs_idx)
for k in obs_keys
}
reward = (batch['reward'] * mask).sum(dim=1) / mask.sum(dim=1)
gamma = self._gamma
obs_p = {
k: dim_select(batch[f'next_obs_{k}'], 1, obs_p_idx)
for k in obs_keys
}
# Backup for Q-Function
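# backup = r + gamma * (1 - done)
#          * (min_i Q_target_i(s', a') - alpha * log pi(a' | s'))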
with th.no_grad():
a_p, log_prob_p = act_logp(obs_p)
q_in = dict(action=a_p, **obs_p)
q_tgt = th.min(target.q(q_in), dim=-1).values
backup = reward + gamma * not_done * (
q_tgt - self._log_alpha_hi.detach().exp() * log_prob_p
)
# Q-Function update
q_in = dict(action=action_hi, **obs)
q = model.q(q_in)
q1 = q[:, 0]
q2 = q[:, 1]
q1_loss = F.mse_loss(q1, backup)
q2_loss = F.mse_loss(q2, backup)
q_loss = q1_loss + q2_loss
optim.q.zero_grad()
q_loss.backward()
if self._clip_grad_norm > 0.0:
nn.utils.clip_grad_norm_(
model.q.parameters(), self._clip_grad_norm
)
optim.q.step()
# Policy update
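# Freeze the Q-networks so the policy loss only produces gradients for
# the policy parameters; they are unfrozen again after the step.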
for param in model.q.parameters():
param.requires_grad_(False)
if self._dense_hi_updates:
# The policy receives no time input, and the Q-functions are queried
# as if the step were 0, i.e. as if a high-level action were being
# taken now.
obs['time'] = obs['time'] * 0
a, log_prob = act_logp(obs)
q_in = dict(action=a, **obs)
q = th.min(model.q(q_in), dim=-1).values
pi_loss = (self._log_alpha_hi.detach().exp() * log_prob - q).mean()
optim.pi.zero_grad()
pi_loss.backward()
if self._clip_grad_norm > 0.0:
nn.utils.clip_grad_norm_(
model.pi.parameters(), self._clip_grad_norm
)
optim.pi.step()
for param in model.q.parameters():
param.requires_grad_(True)
# Optional temperature update
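# Standard SAC temperature adjustment: alpha grows when the policy
# entropy falls below the target entropy and shrinks otherwise.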
if self._optim_alpha_hi:
alpha_loss = -(
self._log_alpha_hi.exp()
* (log_prob.mean().cpu() + self._target_entropy_hi).detach()
)
self._optim_alpha_hi.zero_grad()
alpha_loss.backward()
self._optim_alpha_hi.step()
# Update target network
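# Polyak averaging: target <- polyak * target + (1 - polyak) * model.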
with th.no_grad():
for tp, p in zip(target.q.parameters(), model.q.parameters()):
tp.data.lerp_(p.data, 1.0 - self._polyak)
it += 1
# These are the stats for the last update
self.tbw_add_scalar('LossHi/Policy', pi_loss.item())
self.tbw_add_scalar('LossHi/QValue', q_loss.item())
self.tbw_add_scalar('HealthHi/Entropy', -log_prob.mean().item())
if self._optim_alpha_hi:
self.tbw_add_scalar(
'HealthHi/Alpha', self._log_alpha_hi.exp().item()
)
self.tbw.add_scalars(
'HealthHi/GradNorms',
{
k: v.grad.norm().item()
for k, v in self._model.named_parameters()
if v.grad is not None
},
self.n_samples,
)
avg_cr = th.cat(self._cur_rewards).mean().item()
log.info(
f'Sample {self._n_samples} hi: up {self._n_updates*n}, avg cur reward {avg_cr:+0.3f}, pi loss {pi_loss.item():+.03f}, q loss {q_loss.item():+.03f}, entropy {-log_prob.mean().item():+.03f}, alpha {self._log_alpha_hi.exp().item():.03f}'
)