def _update()

in hucc/agents/sacmt.py


    def _update(self):
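        """Run `self._num_updates` SAC updates on batches from the replay buffer.

        Each iteration updates the twin Q-functions, the policy, the
        (optionally per-task) temperature, the optional reachability network,
        and finally the Polyak-averaged target network; summary statistics
        are logged once the loop finishes.
        """
        # Place sampled batches on the same device as the model parameters.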
        for p in self._model.parameters():
            mdevice = p.device
            break

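        # Sample a reparameterised action from the policy, returning it scaled
        # to the environment's action range together with its log-probability.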
        def act_logp(obs):
            dist = self._model.pi(obs)
            action = dist.rsample()
            log_prob = dist.log_prob(action).sum(dim=-1)
            action = action * self._action_factor
            return action, log_prob

        def task_map(task_keys):
            # Map each task key in the batch to the list of batch indices
            # at which it occurs
            keys, pos = th.unique(task_keys, return_inverse=True, sorted=False)
            keys = keys.cpu().numpy()
            pos = pos.cpu().numpy()
            key_idx = defaultdict(list)
            for i, j in enumerate(pos):
                key_idx[keys[j]].append(i)
            return key_idx

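        # Look up the (detached) temperature for every sample from the
        # per-task log-alpha parameters; also return the key -> indices map
        # so it can be reused for the temperature update below.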
        def alpha_for_tasks(task_keys):
            key_idx = task_map(task_keys)
            alpha = th.zeros(task_keys.shape[0], dtype=th.float32)
            for key, idx in key_idx.items():
                alpha[idx] = self._log_alpha[key].detach().exp()
            return alpha.to(task_keys.device), key_idx

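        # Task keys seen across this round of updates, used afterwards to
        # track how often each task was sampled.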
        task_keys: List[th.Tensor] = []
        for _ in range(self._num_updates):
            self._ups += 1
            batch = self._buffer.get_batch(self._bsz, device=mdevice)
            obs = {k: batch[f'obs_{k}'] for k in self._obs_keys}
            obs_p = {k: batch[f'next_obs_{k}'] for k in self._obs_keys}
            reward = batch['reward']
            not_done = th.logical_not(batch['terminal'])

            # Per-task temperature lookup and task bookkeeping for this batch
            if self._per_task_alpha:
                with th.no_grad():
                    task_alpha, task_key_idx = alpha_for_tasks(
                        batch['task_key']
                    )
            if 'task_key' in batch:
                task_keys.append(batch['task_key'])

            # Backup for Q-Function
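            # backup = r + gamma * (1 - done)
            #          * (min_i Q_tgt_i(s', a') - alpha * log pi(a'|s'))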
            with th.no_grad():
                a_p, log_prob_p = act_logp(obs_p)

                q_in = dict(action=a_p, **obs_p)
                q_tgt = th.min(self._target.q(q_in), dim=-1).values
                if not self._per_task_alpha:
                    alpha = self._log_alpha['_'].detach().exp()
                else:
                    alpha = task_alpha
                backup = reward + self._gamma * not_done * (
                    q_tgt - alpha * log_prob_p
                )

            # Q-Function update
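            # Both critics are regressed onto the shared backup with an MSE
            # loss; in the distributed setting, gradients are summed across
            # the learner group before the (optionally clipped) optimiser step.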
            q_in = dict(action=batch['action'], **obs)
            q = self._model.q(q_in)
            q1 = q[:, 0]
            q2 = q[:, 1]
            q1_loss = F.mse_loss(q1, backup, reduction='none')
            q2_loss = F.mse_loss(q2, backup, reduction='none')
            self._optim.q.zero_grad()
            q_loss = q1_loss.mean() + q2_loss.mean()
            q_loss.backward()
            if self.learner_group:
                for p in self._model.q.parameters():
                    if p.grad is not None:
                        dist.all_reduce(p.grad, group=self.learner_group)
            if self._clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(
                    self._model.q.parameters(), self._clip_grad_norm
                )
            self._optim.q.step()

            # Policy update
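            # Freeze the critics so the policy loss does not accumulate
            # gradients into their parameters.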
            for param in self._model.q.parameters():
                param.requires_grad_(False)

            a, log_prob = act_logp(obs)
            q_in = dict(action=a, **obs)
            q = th.min(self._model.q(q_in), dim=-1).values
            if not self._per_task_alpha:
                alpha = self._log_alpha['_'].detach().exp()
            else:
                alpha = task_alpha.detach()
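            # Standard SAC policy objective: E[alpha * log pi(a|s) - min_i Q_i(s, a)].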
            pi_loss = alpha * log_prob - q
            pi_loss = pi_loss.mean()
            self._optim.pi.zero_grad()
            pi_loss.backward()
            if self.learner_group:
                for p in self._model.pi.parameters():
                    if p.grad is not None:
                        dist.all_reduce(p.grad, group=self.learner_group)
            if self._clip_grad_norm > 0.0:
                nn.utils.clip_grad_norm_(
                    self._model.pi.parameters(), self._clip_grad_norm
                )
            self._optim.pi.step()

            for param in self._model.q.parameters():
                param.requires_grad_(True)

            # Optional temperature update
            if self._optim_alpha is not None:

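                # Lazily create one optimiser per temperature parameter
                # (a single '_' entry, or one entry per task key).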
                def optim_alpha(key):
                    if key not in self._optim_alpha:
                        self._optim_alpha[key] = hydra.utils.instantiate(
                            self._cfg_optim_alpha, [self._log_alpha[key]]
                        )
                    return self._optim_alpha[key]

                if not self._per_task_alpha:
                    # This is a slight reordering of the formulation in
                    # https://github.com/rail-berkeley/softlearning, mostly so we
                    # don't need to create temporary tensors. log_prob is the only
                    # non-scalar tensor, so we can compute its mean first.
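                    # The resulting loss is -alpha * (E[log pi] + target_entropy),
                    # which raises alpha whenever the policy's entropy falls
                    # below the target and lowers it otherwise.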
                    alpha_loss = -(
                        self._log_alpha['_'].exp()
                        * (
                            log_prob.mean().cpu() + self._target_entropy
                        ).detach()
                    )
                    optim_alpha('_').zero_grad()
                    alpha_loss.backward()
                    optim_alpha('_').step()
                else:
                    for key, idx in task_key_idx.items():
                        alpha_loss = -(
                            self._log_alpha[key].exp()
                            * (
                                log_prob[idx].mean().cpu()
                                + self._target_entropy
                            ).detach()
                        )
                        optim_alpha(key).zero_grad()
                        alpha_loss.backward()
                        optim_alpha(key).step()

            # Update reachability network via TD learning
            if self._update_reachability and hasattr(
                self._model, 'reachability'
            ):
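                # TD target for reachability: reached_goal plus the target
                # network's estimate at the next state-action pair, with
                # bootstrapping cut off at the last step of a task.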
                with th.no_grad():
                    a_p, log_prob_p = act_logp(obs_p)
                    r_in = dict(action=a_p, **obs_p)
                    r_tgt = self._target.reachability(r_in).view(-1)
                    propagate = th.logical_not(batch['last_step_of_task'])
                    r_backup = batch['reached_goal'] + propagate * r_tgt

                r_in = dict(action=batch['action'], **obs)
                r_est = self._model.reachability(r_in).view(-1)
                r_loss = F.mse_loss(r_est, r_backup, reduction='mean')
                self._optim.reachability.zero_grad()
                r_loss.backward()
                if self.learner_group:
                    for p in self._model.reachability.parameters():
                        if p.grad is not None:
                            dist.all_reduce(p.grad, group=self.learner_group)
                if self._clip_grad_norm > 0.0:
                    nn.utils.clip_grad_norm_(
                        self._model.reachability.parameters(),
                        self._clip_grad_norm,
                    )
                self._optim.reachability.step()

            # Update target network
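            # Polyak averaging: tp <- polyak * tp + (1 - polyak) * p.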
            with th.no_grad():
                for tp, p in zip(
                    self._target.parameters(), self._model.parameters()
                ):
                    tp.data.lerp_(p.data, 1.0 - self._polyak)

        if task_keys:
            task_keys_c = th.cat(task_keys)
            tasks, counts = th.unique(
                task_keys_c, return_counts=True, sorted=False
            )
            for t, c in zip(tasks.cpu().numpy(), counts.cpu().numpy()):
                self._sampled_tasks[t] += c

        # These are the stats for the last update
        with th.no_grad():
            mean_alpha = np.mean(
                [a.exp().item() for a in self._log_alpha.values()]
            )
        self.tbw_add_scalar('Loss/Policy', pi_loss.item())
        self.tbw_add_scalar('Loss/QValue', q_loss.item())
        if self._update_reachability and hasattr(self._model, 'reachability'):
            self.tbw_add_scalar('Loss/Reachability', r_loss.item())
        self.tbw_add_scalar('Health/Entropy', -log_prob.mean())
        if self._optim_alpha:
            if not self._per_task_alpha:
                self.tbw_add_scalar(
                    'Health/Alpha', self._log_alpha['_'].exp().item()
                )
            else:
                self.tbw_add_scalar('Health/MeanAlpha', mean_alpha)
        if self._n_updates % 100 == 1:
            self.tbw.add_scalars(
                'Health/GradNorms',
                {
                    k: v.grad.norm().item()
                    for k, v in self._model.named_parameters()
                    if v.grad is not None
                },
                self.n_samples,
            )
            for i in range(a.shape[1]):
                self.tbw.add_histogram(
                    f'Health/PolicyA{i}', a[:100, i], self.n_samples
                )
            self.tbw.add_histogram('Health/Q1', q1[:100], self.n_samples)
            self.tbw.add_histogram('Health/Q2', q2[:100], self.n_samples)

        # Log TD errors per abstraction
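        # The per-task TD error is approximated by the root of the per-sample
        # critic losses, averaged over both critics and smoothed with an
        # exponential moving average.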
        if 'task_key' in batch and self._n_updates % 10 == 1:
            task_key_idx = task_map(task_keys[-1])
            with th.no_grad():
                tderrs = {}
                for key, idx in task_key_idx.items():
                    task = self._key_to_task[key]
                    tde1 = q1_loss[idx].sqrt().mean().item()
                    tde2 = q2_loss[idx].sqrt().mean().item()
                    tderrs[task] = (tde1 + tde2) / 2
                self.tbw.add_scalars(
                    'Health/AbsTDErrorMean', tderrs, self._n_samples
                )
                for task, err in tderrs.items():
                    self._avg_tderr[task] *= 1.0 - self._avg_tderr_alpha
                    self._avg_tderr[task] += self._avg_tderr_alpha * err

        self.tbw.add_scalars(
            'Agent/SampledTasks',
            {self._key_to_task[k]: v for k, v in self._sampled_tasks.items()},
            self._n_samples,
        )
        self.tbw.add_scalars(
            'Agent/SamplesPerTask',
            {self._key_to_task[k]: v for k, v in self._task_samples.items()},
            self._n_samples,
        )

        avg_cr = th.cat(self._cur_rewards).mean().item()
        if self._update_reachability and hasattr(self._model, 'reachability'):
            log.info(
                f'Sample {self._n_samples}, up {self._ups}, avg cur reward {avg_cr:+0.3f}, pi loss {pi_loss.item():+.03f}, q loss {q_loss.item():+.03f}, r loss {r_loss.item():+.03f}, entropy {-log_prob.mean().item():+.03f}, alpha {mean_alpha:.03f}'
            )
        else:
            log.info(
                f'Sample {self._n_samples}, up {self._ups}, avg cur reward {avg_cr:+0.3f}, pi loss {pi_loss.item():+.03f}, q loss {q_loss.item():+.03f}, entropy {-log_prob.mean().item():+.03f}, alpha {mean_alpha:.03f}'
            )