in hucc/agents/sacmt.py [0:0]
def _update(self):
# Replay batches below are sampled directly onto the device holding the model.
mdevice = next(self._model.parameters()).device
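# Sample an action from the current policy with rsample() (reparameterization
# trick, so gradients flow back into the policy) and return it together with
# its log-probability summed over action dimensions; the action is scaled by
# self._action_factor to match the environment's action range.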
def act_logp(obs):
dist = self._model.pi(obs)
action = dist.rsample()
log_prob = dist.log_prob(action).sum(dim=-1)
action = action * self._action_factor
return action, log_prob
def task_map(task_keys):
# Map each task key present in the batch to the list of batch indices at which it occurs
keys, pos = th.unique(task_keys, return_inverse=True, sorted=False)
keys = keys.cpu().numpy()
pos = pos.cpu().numpy()
key_idx = defaultdict(list)
for i, j in enumerate(pos):
key_idx[keys[j]].append(i)
return key_idx
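# Resolve the detached temperature value for every transition in the batch
# from the per-task log-alpha parameters.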
def alpha_for_tasks(task_keys):
key_idx = task_map(task_keys)
alpha = th.zeros(task_keys.shape[0], dtype=th.float32)
for key, idx in key_idx.items():
alpha[idx] = self._log_alpha[key].detach().exp()
return alpha.to(task_keys.device), key_idx
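# Task keys seen across this round of updates, for the sampling statistics
# logged at the end.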
task_keys: List[th.Tensor] = []
for _ in range(self._num_updates):
self._ups += 1
batch = self._buffer.get_batch(self._bsz, device=mdevice)
obs = {k: batch[f'obs_{k}'] for k in self._obs_keys}
obs_p = {k: batch[f'next_obs_{k}'] for k in self._obs_keys}
reward = batch['reward']
not_done = th.logical_not(batch['terminal'])
# Look up per-task temperatures and the batch indices belonging to each task
if self._per_task_alpha:
with th.no_grad():
task_alpha, task_key_idx = alpha_for_tasks(
batch['task_key']
)
if 'task_key' in batch:
task_keys.append(batch['task_key'])
# Backup for Q-Function
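# Soft Bellman target: r + gamma * (1 - done) * (min_i Q_tgt_i(s', a') - alpha * log pi(a'|s'))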
with th.no_grad():
a_p, log_prob_p = act_logp(obs_p)
q_in = dict(action=a_p, **obs_p)
q_tgt = th.min(self._target.q(q_in), dim=-1).values
if not self._per_task_alpha:
alpha = self._log_alpha['_'].detach().exp()
else:
alpha = task_alpha
backup = reward + self._gamma * not_done * (
q_tgt - alpha * log_prob_p
)
# Q-Function update
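# Both critic heads regress onto the same target; per-sample losses are kept
# (reduction='none') so that per-task TD errors can be logged further below.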
q_in = dict(action=batch['action'], **obs)
q = self._model.q(q_in)
q1 = q[:, 0]
q2 = q[:, 1]
q1_loss = F.mse_loss(q1, backup, reduction='none')
q2_loss = F.mse_loss(q2, backup, reduction='none')
self._optim.q.zero_grad()
q_loss = q1_loss.mean() + q2_loss.mean()
q_loss.backward()
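# With multiple learner processes, critic gradients are summed across the
# group before the (optionally clipped) optimizer step.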
if self.learner_group:
for p in self._model.q.parameters():
if p.grad is not None:
dist.all_reduce(p.grad, group=self.learner_group)
if self._clip_grad_norm > 0.0:
nn.utils.clip_grad_norm_(
self._model.q.parameters(), self._clip_grad_norm
)
self._optim.q.step()
# Policy update
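# Critic parameters are frozen during the actor step so that pi_loss.backward()
# only populates policy gradients; requires_grad is restored afterwards.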
for param in self._model.q.parameters():
param.requires_grad_(False)
a, log_prob = act_logp(obs)
q_in = dict(action=a, **obs)
q = th.min(self._model.q(q_in), dim=-1).values
if not self._per_task_alpha:
alpha = self._log_alpha['_'].detach().exp()
else:
alpha = task_alpha.detach()
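# Actor loss: E[alpha * log pi(a|s) - min_i Q_i(s, a)], i.e. maximize the
# entropy-regularized soft Q-value of on-policy actions.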
pi_loss = alpha * log_prob - q
pi_loss = pi_loss.mean()
self._optim.pi.zero_grad()
pi_loss.backward()
if self.learner_group:
for p in self._model.pi.parameters():
if p.grad is not None:
dist.all_reduce(p.grad, group=self.learner_group)
if self._clip_grad_norm > 0.0:
nn.utils.clip_grad_norm_(
self._model.pi.parameters(), self._clip_grad_norm
)
self._optim.pi.step()
for param in self._model.q.parameters():
param.requires_grad_(True)
# Optional temperature update
if self._optim_alpha is not None:
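# Alpha optimizers are created lazily per key: a single '_' entry unless
# per-task temperatures are enabled, in which case one per task key. The loss
# below raises alpha when the mean policy entropy falls below
# self._target_entropy and lowers it otherwise.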
def optim_alpha(key):
if key not in self._optim_alpha:
self._optim_alpha[key] = hydra.utils.instantiate(
self._cfg_optim_alpha, [self._log_alpha[key]]
)
return self._optim_alpha[key]
if not self._per_task_alpha:
# This is a slight reordering of the formulation in
# https://github.com/rail-berkeley/softlearning, mostly so we
# don't need to create temporary tensors. log_prob is the only
# non-scalar tensor, so we can compute its mean first.
alpha_loss = -(
self._log_alpha['_'].exp()
* (
log_prob.mean().cpu() + self._target_entropy
).detach()
)
optim_alpha('_').zero_grad()
alpha_loss.backward()
optim_alpha('_').step()
else:
for key, idx in task_key_idx.items():
alpha_loss = -(
self._log_alpha[key].exp()
* (
log_prob[idx].mean().cpu()
+ self._target_entropy
).detach()
)
optim_alpha(key).zero_grad()
alpha_loss.backward()
optim_alpha(key).step()
# Update reachability network via TD learning
if self._update_reachability and hasattr(
self._model, 'reachability'
):
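# TD-style target for the reachability head: the observed goal-reached signal
# plus the target network's estimate for the next state-action pair, with
# bootstrapping cut off on the last step of a task.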
with th.no_grad():
a_p, log_prob_p = act_logp(obs_p)
r_in = dict(action=a_p, **obs_p)
r_tgt = self._target.reachability(r_in).view(-1)
propagate = th.logical_not(batch['last_step_of_task'])
r_backup = batch['reached_goal'] + propagate * r_tgt
r_in = dict(action=batch['action'], **obs)
r_est = self._model.reachability(r_in).view(-1)
r_loss = F.mse_loss(r_est, r_backup, reduction='mean')
self._optim.reachability.zero_grad()
r_loss.backward()
if self.learner_group:
for p in self._model.reachability.parameters():
if p.grad is not None:
dist.all_reduce(p.grad, group=self.learner_group)
if self._clip_grad_norm > 0.0:
nn.utils.clip_grad_norm_(
self._model.reachability.parameters(),
self._clip_grad_norm,
)
self._optim.reachability.step()
# Update target network
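# Polyak averaging: target <- polyak * target + (1 - polyak) * model.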
with th.no_grad():
for tp, p in zip(
self._target.parameters(), self._model.parameters()
):
tp.data.lerp_(p.data, 1.0 - self._polyak)
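# Book-keeping: count how often each task appeared in this round's batches.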
if task_keys:
task_keys_c = th.cat(task_keys)
tasks, counts = th.unique(
task_keys_c, return_counts=True, sorted=False
)
for t, c in zip(tasks.cpu().numpy(), counts.cpu().numpy()):
self._sampled_tasks[t] += c
# These are the stats for the last update
with th.no_grad():
mean_alpha = np.mean(
[a.exp().item() for a in self._log_alpha.values()]
)
self.tbw_add_scalar('Loss/Policy', pi_loss.item())
self.tbw_add_scalar('Loss/QValue', q_loss.item())
if self._update_reachability and hasattr(self._model, 'reachability'):
self.tbw_add_scalar('Loss/Reachability', r_loss.item())
self.tbw_add_scalar('Health/Entropy', -log_prob.mean().item())
if self._optim_alpha:
if not self._per_task_alpha:
self.tbw_add_scalar(
'Health/Alpha', self._log_alpha['_'].exp().item()
)
else:
self.tbw_add_scalar('Health/MeanAlpha', mean_alpha)
if self._n_updates % 100 == 1:
self.tbw.add_scalars(
'Health/GradNorms',
{
k: v.grad.norm().item()
for k, v in self._model.named_parameters()
if v.grad is not None
},
self.n_samples,
)
for i in range(a.shape[1]):
self.tbw.add_histogram(
f'Health/PolicyA{i}', a[:100, i], self.n_samples
)
self.tbw.add_histogram('Health/Q1', q1[:100], self.n_samples)
self.tbw.add_histogram('Health/Q2', q2[:100], self.n_samples)
# Log TD errors per abstraction
if 'task_key' in batch and self._n_updates % 10 == 1:
task_key_idx = task_map(task_keys[-1])
with th.no_grad():
tderrs = {}
for key, idx in task_key_idx.items():
task = self._key_to_task[key]
tde1 = q1_loss[idx].sqrt().mean().item()
tde2 = q2_loss[idx].sqrt().mean().item()
tderrs[task] = (tde1 + tde2) / 2
self.tbw.add_scalars(
'Health/AbsTDErrorMean', tderrs, self._n_samples
)
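# Maintain an exponential moving average of per-task TD errors with
# smoothing factor self._avg_tderr_alpha.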
for task, err in tderrs.items():
self._avg_tderr[task] *= 1.0 - self._avg_tderr_alpha
self._avg_tderr[task] += self._avg_tderr_alpha * err
self.tbw.add_scalars(
'Agent/SampledTasks',
{self._key_to_task[k]: v for k, v in self._sampled_tasks.items()},
self._n_samples,
)
self.tbw.add_scalars(
'Agent/SamplesPerTask',
{self._key_to_task[k]: v for k, v in self._task_samples.items()},
self._n_samples,
)
avg_cr = th.cat(self._cur_rewards).mean().item()
if self._update_reachability and hasattr(self._model, 'reachability'):
log.info(
f'Sample {self._n_samples}, up {self._ups}, avg cur reward {avg_cr:+0.3f}, pi loss {pi_loss.item():+.03f}, q loss {q_loss.item():+.03f}, r loss {r_loss.item():+.03f}, entropy {-log_prob.mean().item():+.03f}, alpha {mean_alpha:.03f}'
)
else:
log.info(
f'Sample {self._n_samples}, up {self._ups}, avg cur reward {avg_cr:+0.3f}, pi loss {pi_loss.item():+.03f}, q loss {q_loss.item():+.03f}, entropy {-log_prob.mean().item():+.03f}, alpha {mean_alpha:.03f}'
)