def update_parameters()

in activemri/baselines/ddqn.py [0:0]
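
Performs one double-DQN parameter update: sample a batch of transitions from the replay memory, compute Q-values for the taken actions with the online network, build bootstrapped targets by letting the online network choose the best next actions and the target network evaluate them, and take a single gradient-clipped optimizer step on the Huber loss. Returns None when the memory cannot yet provide a batch, otherwise a dict of logging statistics (loss, gradient norm, and Q-value mean/std).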


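    # Assumed module-level imports for this method (not shown in the excerpt):
    #   from typing import Any, Dict, Optional
    #   import torch
    #   import torch.nn as nn
    #   import torch.nn.functional as F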
    def update_parameters(self, target_net: nn.Module) -> Optional[Dict[str, Any]]:
        self.model.train()
        batch = self.memory.sample()
        if batch is None:
            return None
        observations = batch["observations"].to(self.device)
        next_observations = batch["next_observations"].to(self.device)
        actions = batch["actions"].to(self.device)
        rewards = batch["rewards"].to(self.device).squeeze()
        dones = batch["dones"].to(self.device)

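        # Mask selecting transitions whose episode has not yet terminated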
        not_done_mask = dones.logical_not().squeeze()

        # Compute Q-values for the current observations with the online network
        all_q_values_cur = self.forward(observations)
        q_values = all_q_values_cur.gather(1, actions.unsqueeze(1))

        # Compute target values: the online network selects the best next
        # actions and the target network evaluates them (double DQN)
        if self.opts.gamma == 0.0:
            target_values = rewards
        else:
            with torch.no_grad():
                all_q_values_next = self.forward(next_observations)
                target_values = torch.zeros(observations.shape[0], device=self.device)
                del observations  # no longer needed; free memory early
                if not_done_mask.any():
                    best_actions = all_q_values_next.max(1)[1]
                    target_values[not_done_mask] = (
                        target_net.forward(next_observations)
                        .gather(1, best_actions.unsqueeze(1))[not_done_mask]
                        .squeeze()
                    )

                target_values = self.opts.gamma * target_values + rewards

        # Huber (smooth L1) loss; an MSE alternative is left commented out
        # loss = F.mse_loss(q_values, target_values.unsqueeze(1))
        loss = F.smooth_l1_loss(q_values, target_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()

        # Compute total gradient norm (for logging purposes) and then clip gradients
        grad_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                grad_norm += p.grad.data.norm(2).item() ** 2
        grad_norm = grad_norm ** 0.5
        torch.nn.utils.clip_grad_value_(self.parameters(), 1)

        self.optimizer.step()

        torch.cuda.empty_cache()

        return {
            "loss": loss,
            "grad_norm": grad_norm,
            "q_values_mean": q_values.detach().mean().cpu().numpy(),
            "q_values_std": q_values.detach().std().cpu().numpy(),
        }
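
Below is a minimal, self-contained sketch of the double-DQN target computation used above, with toy tensors standing in for the online and target network outputs; the batch size, action count, and gamma are illustrative values, not ones from the repository.

    import torch

    batch_size, num_actions, gamma = 4, 3, 0.5
    rewards = torch.rand(batch_size)
    dones = torch.tensor([False, False, True, False])
    not_done_mask = dones.logical_not()

    # Toy stand-ins for self.forward(next_observations) and
    # target_net.forward(next_observations)
    all_q_values_next = torch.rand(batch_size, num_actions)
    target_q_values_next = torch.rand(batch_size, num_actions)

    target_values = torch.zeros(batch_size)
    if not_done_mask.any():
        # Online network selects the action; target network evaluates it
        best_actions = all_q_values_next.max(1)[1]
        target_values[not_done_mask] = (
            target_q_values_next
            .gather(1, best_actions.unsqueeze(1))[not_done_mask]
            .squeeze()
        )
    target_values = gamma * target_values + rewards  # bootstrapped targets

For terminal transitions the bootstrap term stays zero, so the target reduces to the immediate reward, matching the masked assignment in the method.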