in hucc/agents/hiro.py [0:0]
def _update_lo(self):
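    """Soft actor-critic style update of the low-level, goal-conditioned policy and Q-functions."""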
    model = self._model.lo
    target = self._target.lo
    optim = self._optim.lo
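
    # Sample a reparameterized action from the low-level policy; returns the
    # scaled action and its log-probability (summed over action dimensions).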
    def act_logp(obs):
        dist = model.pi(obs)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(dim=-1)
        action = action * self._action_factor_lo
        return action, log_prob
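
    # Run self._num_updates gradient steps on minibatches from the replay buffer.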
    for _ in range(self._num_updates):
        batch = self._buffer.get_batch(self._bsz)
        reward = batch['reward_lo']
        obs = {k: batch[f'obs_{k}'] for k in self._obs_lo_keys}
        obs_p = {k: batch[f'next_obs_{k}'] for k in self._obs_lo_keys}
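        # Zero out masked observation dimensions; the cast mask is written back
        # so the device/dtype conversion happens only once.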
        for k, m in self._obs_lo_mask.items():
            m = m.to(obs[k])
            obs[k] = obs[k] * m
            obs_p[k] = obs_p[k] * m
            self._obs_lo_mask[k] = m
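        # The stored high-level action serves as the low-level goal; the next
        # observation is paired with the recorded next-step high-level action.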
        obs['desired_goal'] = batch['action_hi']
        obs_p['desired_goal'] = batch['auto_next_action_hi']
        not_fell_over = th.logical_not(batch['fell_over'])

        # Backup for Q-Function
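        # y = r + gamma * not_fell_over * (min_i target_Q_i(s', a') - alpha * log pi(a' | s'))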
        with th.no_grad():
            a_p, log_prob_p = act_logp(obs_p)
            q_in = dict(action=a_p, **obs_p)
            q_tgt = th.min(target.q(q_in), dim=-1).values
            # Assume that low-level episodes don't end
            backup = reward + self._gamma * not_fell_over * (
                q_tgt - self._log_alpha_lo.detach().exp() * log_prob_p
            )

        # Q-Function update
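        # Both heads of the twin critic regress onto the same backup target.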
        q_in = dict(action=batch['action_lo'], **obs)
        q = model.q(q_in)
        q1 = q[:, 0]
        q2 = q[:, 1]
        q1_loss = F.mse_loss(q1, backup)
        q2_loss = F.mse_loss(q2, backup)
        q_loss = q1_loss + q2_loss
        optim.q.zero_grad()
        q_loss.backward()
        if self._clip_grad_norm > 0.0:
            nn.utils.clip_grad_norm_(
                model.q.parameters(), self._clip_grad_norm
            )
        optim.q.step()

        # Policy update
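        # Freeze the critics so that pi_loss gradients only update the actor.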
        for param in model.q.parameters():
            param.requires_grad_(False)
        a, log_prob = act_logp(obs)
        q_in = dict(action=a, **obs)
        q = th.min(model.q(q_in), dim=-1).values
        pi_loss = (self._log_alpha_lo.detach().exp() * log_prob - q).mean()
        optim.pi.zero_grad()
        pi_loss.backward()
        if self._clip_grad_norm > 0.0:
            nn.utils.clip_grad_norm_(
                model.pi.parameters(), self._clip_grad_norm
            )
        optim.pi.step()
        for param in model.q.parameters():
            param.requires_grad_(True)

        # Optional temperature update
if self._optim_alpha_lo:
alpha_loss = -(
self._log_alpha_lo.exp()
* (log_prob.mean().cpu() + self._target_entropy_lo).detach()
)
self._optim_alpha_lo.zero_grad()
alpha_loss.backward()
self._optim_alpha_lo.step()
# Update target network
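        # target_q <- polyak * target_q + (1 - polyak) * q (lerp_ with weight 1 - polyak)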
        with th.no_grad():
            for tp, p in zip(target.q.parameters(), model.q.parameters()):
                tp.data.lerp_(p.data, 1.0 - self._polyak)

    # These are the stats for the last update
    self.tbw_add_scalar('LossLo/Policy', pi_loss.item())
    self.tbw_add_scalar('LossLo/QValue', q_loss.item())
    self.tbw_add_scalar('HealthLo/Entropy', -log_prob.mean())
    if self._optim_alpha_lo:
        self.tbw_add_scalar(
            'HealthLo/Alpha', self._log_alpha_lo.exp().item()
        )
    self.tbw.add_scalars(
        'HealthLo/GradNorms',
        {
            k: v.grad.norm().item()
            for k, v in self._model.named_parameters()
            if v.grad is not None
        },
        self.n_samples,
    )
    avg_cr = th.cat(self._cur_rewards_lo).mean().item()
    log.info(
        f'Sample {self._n_samples} lo: up {self._n_updates*self._num_updates}, '
        f'avg cur reward {avg_cr:+0.3f}, pi loss {pi_loss.item():+.03f}, '
        f'q loss {q_loss.item():+.03f}, entropy {-log_prob.mean().item():+.03f}, '
        f'alpha {self._log_alpha_lo.exp().item():.03f}'
    )