in hucc/agents/hsdb.py [0:0]
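
# NOTE: This excerpt assumes hsdb.py's module-level imports (not shown here),
# roughly: `from copy import copy`, `import numpy as np`, `import torch as th`,
# `import torch.nn as nn`, `import torch.nn.functional as F`, plus a
# `dim_select` helper (row-wise gather along a dimension) and a module-level
# `log` logger.
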
def _update(self):
    def act_logp_c(obs, mask):
        # Sample continuous sub-actions with reparametrization and compute
        # their log-probabilities; `mask` selects the action dimensions that
        # are active for each discrete choice.
        dist = self._model_pi_c(obs)
        action = dist.rsample()
        if mask is not None:
            log_prob = (dist.log_prob(action) * mask).sum(
                dim=-1
            ) / mask.sum(dim=-1)
            action = action * mask * self._action_factor_c
        else:
            log_prob = dist.log_prob(action).sum(dim=-1)
            action = action * self._action_factor_c
        return action, log_prob
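
    # Soft-Q backup for the high-level critic:
    #   r + gamma_exp * not_done * (min_i Q_target_i(s', a') - alpha_c * log pi(a'|s'))
    # where a' is the freshly sampled continuous sub-action corresponding to
    # the stored discrete action.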
    def q_target(batch):
        reward = batch['reward']
        not_done = batch['not_done']
        obs_p = {k: batch[f'next_obs_{k}'] for k in self._obs_keys}
        alpha_c = self._log_alpha_c.detach().exp()
        bsz = reward.shape[0]
        action_c, log_prob_c = act_logp_c(obs_p, self._action_c_mask)
        action_d = batch[f'action_hi_{self._dkey}']
        obs_p[self._dkey] = self.action_hi_d_qinput(action_d).view(-1, nd)
        action_c = dim_select(action_c, 1, action_d).view(
            -1, action_c.shape[-1]
        )
        log_prob_c = dim_select(log_prob_c, 1, action_d)
        obs_p[self._ckey] = action_c
        q_t = th.min(self._target.hi.q(obs_p), dim=-1).values
        if self._action_c_mask is not None:
            ac = alpha_c.index_select(0, action_d)
        else:
            ac = alpha_c
        v_est = q_t - ac * log_prob_c
        return reward + batch['gamma_exp'] * not_done * v_est
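
    # Determine the device the model lives on and cache one-hot encodings of
    # the `nd` discrete sub-actions.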
    for p in self._model.parameters():
        mdevice = p.device
        break
    bsz = self._bsz
    nd = self._action_space_d.n
    if self._onehots is None:
        self._onehots = F.one_hot(th.arange(nd), nd).float().to(mdevice)

    if not self._dyne_updates:
        # Without dyne-style updates, only sample transitions observed at
        # obs_time == 0, i.e. at high-level decision steps.
        assert (
            self._buffer.start == 0 or self._buffer.size == self._buffer.max
        )
        indices = th.where(
            self._buffer._b['obs_time'][: self._buffer.size] == 0
        )[0]

    gbatch = None
    if self._dyne_updates and self._bsz < 512:
        # For small batch sizes, fetch one large batch up front and slice it
        # per update below.
        gbatch = self._buffer.get_batch(
            self._bsz * self._num_updates,
            device=mdevice,
        )
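
    # One SAC-style update per iteration: Q-function update, continuous-policy
    # update, optional temperature (alpha) update, and a Polyak averaging step
    # on the target Q-network.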
    for i in range(self._num_updates):
        if self._dyne_updates:
            if gbatch is not None:
                batch = {
                    k: v.narrow(0, i * self._bsz, self._bsz)
                    for k, v in gbatch.items()
                }
            else:
                batch = self._buffer.get_batch(
                    self._bsz,
                    device=mdevice,
                )
        else:
            batch = self._buffer.get_batch_where(
                self._bsz, indices=indices, device=mdevice
            )

        obs = {k: batch[f'obs_{k}'] for k in self._obs_keys}
        alpha_c = self._log_alpha_c.detach().exp()

        # Backup for Q-Function
        with th.no_grad():
            backup = q_target(batch)

        # Q-Function update
        q_in = copy(obs)
        q_in[self._dkey] = self.action_hi_d_qinput(
            batch[f'action_hi_{self._dkey}']
        )
        q_in[self._ckey] = batch[f'action_hi_{self._ckey}']
        q = self._q_hi(q_in)
        q1 = q[:, 0]
        q2 = q[:, 1]
        q1_loss = F.mse_loss(q1, backup, reduction='none')
        q2_loss = F.mse_loss(q2, backup, reduction='none')
        q_loss = q1_loss.mean() + q2_loss.mean()
        self._optim.hi.q.zero_grad()
        q_loss.backward()
        if self._clip_grad_norm > 0.0:
            nn.utils.clip_grad_norm_(
                self._model.hi.q.parameters(), self._clip_grad_norm
            )
        self._optim.hi.q.step()

        # Policy update
        for param in self._model.hi.q.parameters():
            param.requires_grad_(False)

        # No time input for policy, and Q-functions are queried as if step
        # would be 0 (i.e. we would take an action)
        obs['time'] = obs['time'] * 0
        action_c, log_prob_c = act_logp_c(obs, self._action_c_mask)
        action_d = batch[f'action_hi_{self._dkey}']
        obs[self._dkey] = self.action_hi_d_qinput(action_d).view(-1, nd)
        action_c = dim_select(action_c, 1, action_d).view(
            -1, action_c.shape[-1]
        )
        log_prob_c = dim_select(log_prob_c, 1, action_d)
        obs[self._ckey] = action_c
        q = th.min(self._q_hi(obs), dim=-1).values
        if self._action_c_mask is not None:
            ac = alpha_c.index_select(0, action_d)
        else:
            ac = alpha_c
        pi_loss = ac * log_prob_c - q
        pi_loss = pi_loss.mean()
        self._optim_pi_c.zero_grad()
        pi_loss.backward()
        if self._clip_grad_norm > 0.0:
            nn.utils.clip_grad_norm_(
                self._model_pi_c.parameters(), self._clip_grad_norm
            )
        self._optim_pi_c.step()

        for param in self._model.hi.q.parameters():
            param.requires_grad_(True)

        # Optional temperature update
        if self._optim_alpha_c:
            log_alpha = self._log_alpha_c.index_select(0, action_d)
            alpha_loss_c = -(
                log_alpha.exp()
                * (log_prob_c.view(-1).detach() + self._target_entropy_c)
            ).mean()
            self._optim_alpha_c.zero_grad()
            alpha_loss_c.backward()
            self._optim_alpha_c.step()

        # Update target network
        with th.no_grad():
            for tp, p in zip(
                self._target.hi.q.parameters(),
                self._model.hi.q.parameters(),
            ):
                tp.data.lerp_(p.data, 1.0 - self._polyak)
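
    # The bandit's current distribution over discrete choices; used below for
    # logging only.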
    dist_d = self._bandit_d.dist()

    # These are the stats for the last update
    self.tbw_add_scalar('LossHi/Policy', pi_loss.item())
    self.tbw_add_scalar('LossHi/QValue', q_loss.item())
    with th.no_grad():
        bvar = backup.var()
        resvar1 = (backup - q1).var() / bvar
        resvar2 = (backup - q2).var() / bvar
        self.tbw_add_scalar('HealthHi/ResidualVariance1', resvar1.item())
        self.tbw_add_scalar('HealthHi/ResidualVariance2', resvar2.item())
    self.tbw_add_scalar('HealthHi/EntropyC', -log_prob_c.mean())
    self.tbw_add_scalar('HealthHi/EntropyD', dist_d.entropy())
    if self._optim_alpha_c:
        self.tbw_add_scalar(
            'HealthHi/AlphaC', self._log_alpha_c.exp().mean().item()
        )
    if self._n_updates % 10 == 1:
        self.tbw.add_histogram(
            'HealthHi/PiD',
            th.multinomial(
                dist_d.probs,
                int(np.ceil(1000 / self._bsz)),
                replacement=True,
            ).view(-1),
            self._n_samples,
            bins=nd,
        )
    if self._n_updates % 100 == 1:
        self.tbw.add_scalars(
            'HealthHi/GradNorms',
            {
                k: v.grad.norm().item()
                for k, v in self._model.named_parameters()
                if v.grad is not None
            },
            self._n_samples,
        )
    td_err1 = q1_loss.sqrt().mean().item()
    td_err2 = q2_loss.sqrt().mean().item()
    td_err = (td_err1 + td_err2) / 2
    self.tbw_add_scalar('HealthHi/AbsTDErrorTrain', td_err)
    self.tbw_add_scalar('HealthHi/AbsTDErrorTrain1', td_err1)
    self.tbw_add_scalar('HealthHi/AbsTDErrorTrain2', td_err2)
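
    # Console summary for this round of updates.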
    avg_cr = th.cat(self._cur_rewards).mean().item()
    log_stats = [
        ('Sample', f'{self._n_samples}'),
        ('hi: up', f'{self._n_updates*self._num_updates}'),
        ('avg rew', f'{avg_cr:+0.3f}'),
        ('pi loss', f'{pi_loss.item():+.03f}'),
        ('q loss', f'{q_loss.item():+.03f}'),
        (
            'entropy',
            f'{-log_prob_c.mean().item():.03f},{dist_d.entropy().item():.03f}',
        ),
        ('alpha', f'{self._log_alpha_c.mean().exp().item():.03f}'),
    ]
    log.info(', '.join((f'{k} {v}' for k, v in log_stats)))