in hucc/agents/hsd3.py
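Excerpt: the `_update` method of the high-level policy's SAC-style agent. The names `th` (torch), `nn` (torch.nn), `F` (torch.nn.functional), `np` (numpy), `copy` (copy.copy), `log` (the module logger), and the tensor helper `dim_select` are assumed to be imported at the top of the file, which is not shown here.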
    def _update(self):
        def act_logp_c(obs, mask):
            # Sample a continuous sub-goal action and its log-probability;
            # with a mask, only the active dimensions contribute (normalized
            # by their count).
            dist = self._model_pi_c(obs)
            action = dist.rsample()
            if mask is not None:
                log_prob = (dist.log_prob(action) * mask).sum(
                    dim=-1
                ) / mask.sum(dim=-1)
                action = action * mask * self._action_factor_c
            else:
                log_prob = dist.log_prob(action).sum(dim=-1)
                action = action * self._action_factor_c
            return action, log_prob
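
        # Soft Q-value backup: the bootstrap value V(s') combines the twin
        # target Q-networks (via a min) with entropy terms for both action
        # parts, weighted by the temperatures alpha_c (continuous) and
        # alpha_d (discrete).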
        def q_target(batch):
            reward = batch['reward']
            not_done = batch['not_done']
            obs_p = {k: batch[f'next_obs_{k}'] for k in self._obs_keys}
            alpha_c = self._log_alpha_c.detach().exp()
            alpha_d = self._log_alpha_d.detach().exp()
            bsz = reward.shape[0]
            d_batchin = self._d_batchin.narrow(0, 0, bsz * nd)
            c_batchmask = self._c_batchmask.narrow(0, 0, bsz * nd)

            dist_d = self._model_pi_d(obs_p)
            action_c, log_prob_c = act_logp_c(obs_p, self._action_c_mask)

            if self._expectation_d == -1 and nd > 1:
                # Present interleaved observation so that we can easily
                # reshape the result into BxA1xA2.
                obs_pe = {}
                for k, v in obs_p.items():
                    obs_pe[k] = v.repeat_interleave(nd, dim=0)
                obs_pe[self._dkey] = d_batchin
                obs_pe[self._ckey] = action_c.view(d_batchin.shape[0], -1)
                q_t = th.min(self._target.hi.q(obs_pe), dim=-1).values
                q_t = q_t.view(bsz, nd)
                log_prob_c = log_prob_c.view(bsz, nd)
                v_est = (dist_d.probs * (q_t - log_prob_c * alpha_c)).sum(
                    dim=-1
                ) + alpha_d * (dist_d.entropy() - self._uniform_entropy_d)
            else:
                action_d = th.multinomial(dist_d.probs, nds, replacement=True)
                log_prob_d = dist_d.logits.gather(1, action_d)
                obs_pe = {}
                for k, v in obs_p.items():
                    if nds > 1:
                        obs_pe[k] = v.repeat_interleave(nds, dim=0)
                    else:
                        obs_pe[k] = v
                obs_pe[self._dkey] = self.action_hi_d_qinput(action_d).view(
                    -1, nd
                )
                action_c = dim_select(action_c, 1, action_d).view(
                    -1, action_c.shape[-1]
                )
                log_prob_c = log_prob_c.gather(1, action_d)
                obs_pe[self._ckey] = action_c
                q_t = th.min(self._target.hi.q(obs_pe), dim=-1).values.view(
                    -1, nds
                )
                log_prob_c = log_prob_c.view(-1, nds)
                if self._action_c_mask is not None:
                    ac = alpha_c.index_select(0, action_d.view(-1)).view_as(
                        log_prob_c
                    )
                else:
                    ac = alpha_c
                v_est = (q_t - ac * log_prob_c - alpha_d * log_prob_d).mean(
                    dim=-1
                )

            return reward + batch['gamma_exp'] * not_done * v_est
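
        # Grab the model device and, on the first update, cache inputs for
        # enumerating all nd discrete actions across a full batch
        # (self._d_batchin) together with the matching continuous-action
        # mask (self._c_batchmask).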
        for p in self._model.parameters():
            mdevice = p.device
            break
        bsz = self._bsz
        nd = self._action_space_d.n
        nds = self._expectation_d
        if nd == 1:
            nds = 1
        if self._d_batchin is None:
            self._onehots = F.one_hot(th.arange(nd), nd).float().to(mdevice)
            self._d_batchin = self.action_hi_d_qinput(
                th.arange(bsz * nd).remainder(nd).to(mdevice)
            )
            if self._action_c_mask is not None:
                self._c_batchmask = self._action_c_mask.index_select(
                    1, th.arange(bsz * nd, device=mdevice).remainder(nd)
                ).squeeze(0)
            else:
                self._c_batchmask = None

        if not self._dyne_updates:
            assert (
                self._buffer.start == 0 or self._buffer.size == self._buffer.max
            )
            indices = th.where(
                self._buffer._b['obs_time'][: self._buffer.size] == 0
            )[0]

        gbatch = None
        if self._dyne_updates and self._bsz < 512:
            gbatch = self._buffer.get_batch(
                self._bsz * self._num_updates,
                device=mdevice,
            )
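
        # Run self._num_updates SAC-style updates. With dyne-updates and a
        # small batch size, all mini-batches were pre-fetched above in gbatch
        # and are sliced per iteration; otherwise they are sampled on the fly.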
        for i in range(self._num_updates):
            if self._dyne_updates:
                if gbatch is not None:
                    batch = {
                        k: v.narrow(0, i * self._bsz, self._bsz)
                        for k, v in gbatch.items()
                    }
                else:
                    batch = self._buffer.get_batch(
                        self._bsz,
                        device=mdevice,
                    )
            else:
                batch = self._buffer.get_batch_where(
                    self._bsz, indices=indices, device=mdevice
                )

            obs = {k: batch[f'obs_{k}'] for k in self._obs_keys}
            alpha_c = self._log_alpha_c.detach().exp()
            alpha_d = self._log_alpha_d.detach().exp()

            # Backup for Q-Function
            with th.no_grad():
                backup = q_target(batch)

            # Q-Function update
            q_in = copy(obs)
            q_in[self._dkey] = self.action_hi_d_qinput(
                batch[f'action_hi_{self._dkey}']
            )
            q_in[self._ckey] = batch[f'action_hi_{self._ckey}']
            q = self._q_hi(q_in)
            q1 = q[:, 0]
            q2 = q[:, 1]
            q1_loss = F.mse_loss(q1, backup, reduction='none')
            q2_loss = F.mse_loss(q2, backup, reduction='none')
            q_loss = q1_loss.mean() + q2_loss.mean()
            self._optim.hi.q.zero_grad()
            q_loss.backward()
            if self._clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(
                    self._model.hi.q.parameters(), self._clip_grad_norm
                )
            self._optim.hi.q.step()
            # Policy update
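            # The Q-networks are frozen below so that pi_loss only updates the
            # policy heads; the loss mirrors the backup in q_target (exact
            # expectation over discrete actions when feasible, sampled
            # otherwise).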
            for param in self._model.hi.q.parameters():
                param.requires_grad_(False)

            # No time input for policy, and Q-functions are queried as if step
            # would be 0 (i.e. we would take an action)
            obs['time'] = obs['time'] * 0
            dist_d = self._model_pi_d(obs)
            action_c, log_prob_c = act_logp_c(obs, self._action_c_mask)
            if self._expectation_d == -1 and nd > 1:
                obs_e = {}
                for k, v in obs.items():
                    obs_e[k] = v.repeat_interleave(nd, dim=0)
                obs_e[self._dkey] = self._d_batchin
                obs_e[self._ckey] = action_c.view(self._d_batchin.shape[0], -1)
                q = th.min(self._q_hi(obs_e), dim=-1).values
                q = q.view(bsz, nd)
                log_prob_c = log_prob_c.view(bsz, nd)
                pi_loss = (dist_d.probs * (alpha_c * log_prob_c - q)).sum(
                    dim=-1
                ) - alpha_d * (dist_d.entropy() - self._uniform_entropy_d)
            else:
                action_d = th.multinomial(dist_d.probs, nds, replacement=True)
                log_prob_d = dist_d.logits.gather(1, action_d)
                obs_e = {}
                for k, v in obs.items():
                    if nds > 1:
                        obs_e[k] = v.repeat_interleave(nds, dim=0)
                    else:
                        obs_e[k] = v
                obs_e[self._dkey] = self.action_hi_d_qinput(action_d).view(
                    -1, nd
                )
                action_c = dim_select(action_c, 1, action_d).view(
                    -1, action_c.shape[-1]
                )
                log_prob_co = log_prob_c
                log_prob_c = log_prob_c.gather(1, action_d)
                obs_e[self._ckey] = action_c
                q = th.min(self._q_hi(obs_e), dim=-1).values.view(-1, nds)
                log_prob_c = log_prob_c.view(-1, nds)
                if self._action_c_mask is not None:
                    ac = alpha_c.index_select(0, action_d.view(-1)).view_as(
                        log_prob_c
                    )
                else:
                    ac = alpha_c
                pi_loss = (ac * log_prob_c + alpha_d * log_prob_d - q).mean(
                    dim=-1
                )
            pi_loss = pi_loss.mean()

            self._optim_pi_c.zero_grad()
            self._optim_pi_d.zero_grad()
            pi_loss.backward()
            if self._clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(
                    self._model_pi_c.parameters(), self._clip_grad_norm
                )
                nn.utils.clip_grad_norm_(
                    self._model_pi_d.parameters(), self._clip_grad_norm
                )
            self._optim_pi_c.step()
            self._optim_pi_d.step()

            for param in self._model.hi.q.parameters():
                param.requires_grad_(True)
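
            # Dual updates for the entropy temperatures: alpha_c is driven
            # towards the continuous target entropy, weighted by the discrete
            # policy's probabilities (and is a per-discrete-action vector when
            # a continuous action mask is used); alpha_d is driven towards the
            # discrete target entropy.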
            # Optional temperature update
            if self._optim_alpha_c:
                if self._expectation_d != -1:
                    alpha_loss_c = (
                        -(
                            self._log_alpha_c.exp()
                            * dist_d.probs.detach()
                            * (
                                log_prob_co.detach() + self._target_entropy_c
                            ).view(bsz, nd)
                        )
                        .sum(dim=-1)
                        .mean()
                    )
                else:
                    alpha_loss_c = (
                        -(
                            self._log_alpha_c.exp()
                            * dist_d.probs.detach()
                            * (
                                log_prob_c.detach() + self._target_entropy_c
                            ).view(bsz, nd)
                        )
                        .sum(dim=-1)
                        .mean()
                    )
                self._optim_alpha_c.zero_grad()
                alpha_loss_c.backward()
                self._optim_alpha_c.step()
            if self._optim_alpha_d:
                alpha_loss_d = (
                    self._log_alpha_d.exp()
                    * (
                        dist_d.entropy().mean().cpu() - self._target_entropy_d
                    ).detach()
                )
                self._optim_alpha_d.zero_grad()
                alpha_loss_d.backward()
                self._optim_alpha_d.step()
            # Update target network
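            # Polyak averaging: q_target <- polyak * q_target + (1 - polyak) * q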
            with th.no_grad():
                for tp, p in zip(
                    self._target.hi.q.parameters(),
                    self._model.hi.q.parameters(),
                ):
                    tp.data.lerp_(p.data, 1.0 - self._polyak)
        # These are the stats for the last update
        self.tbw_add_scalar('LossHi/Policy', pi_loss.item())
        self.tbw_add_scalar('LossHi/QValue', q_loss.item())
        with th.no_grad():
            bvar = backup.var()
            resvar1 = (backup - q1).var() / bvar
            resvar2 = (backup - q2).var() / bvar
            self.tbw_add_scalar('HealthHi/ResidualVariance1', resvar1.item())
            self.tbw_add_scalar('HealthHi/ResidualVariance2', resvar2.item())
        self.tbw_add_scalar('HealthHi/EntropyC', -log_prob_c.mean())
        self.tbw_add_scalar('HealthHi/EntropyD', dist_d.entropy().mean())
        if self._optim_alpha_c:
            self.tbw_add_scalar(
                'HealthHi/AlphaC', self._log_alpha_c.exp().mean().item()
            )
        if self._optim_alpha_d:
            self.tbw_add_scalar(
                'HealthHi/AlphaD', self._log_alpha_d.exp().item()
            )
        if self._n_updates % 10 == 1:
            self.tbw.add_histogram(
                'HealthHi/PiD',
                th.multinomial(
                    dist_d.probs,
                    int(np.ceil(1000 / self._bsz)),
                    replacement=True,
                ).view(-1),
                self._n_samples,
                bins=nd,
            )
        if self._n_updates % 100 == 1:
            self.tbw.add_scalars(
                'HealthHi/GradNorms',
                {
                    k: v.grad.norm().item()
                    for k, v in self._model.named_parameters()
                    if v.grad is not None
                },
                self._n_samples,
            )

        td_err1 = q1_loss.sqrt().mean().item()
        td_err2 = q2_loss.sqrt().mean().item()
        td_err = (td_err1 + td_err2) / 2
        self.tbw_add_scalar('HealthHi/AbsTDErrorTrain', td_err)
        self.tbw_add_scalar('HealthHi/AbsTDErrorTrain1', td_err1)
        self.tbw_add_scalar('HealthHi/AbsTDErrorTrain2', td_err2)

        avg_cr = th.cat(self._cur_rewards).mean().item()
        log_stats = [
            ('Sample', f'{self._n_samples}'),
            ('hi: up', f'{self._n_updates*self._num_updates}'),
            ('avg rew', f'{avg_cr:+0.3f}'),
            ('pi loss', f'{pi_loss.item():+.03f}'),
            ('q loss', f'{q_loss.item():+.03f}'),
            (
                'entropy',
                f'{-log_prob_c.mean().item():.03f},{dist_d.entropy().mean().item():.03f}',
            ),
            (
                'alpha',
                f'{self._log_alpha_c.mean().exp().item():.03f},{self._log_alpha_d.exp().item():.03f}',
            ),
        ]
        log.info(', '.join((f'{k} {v}' for k, v in log_stats)))